Source code for mbrs.utils

 1from typing import Any
 2
 3import torch
 4
 5
def to_device(sample: Any, device: torch.device) -> Any:
    """Recursively move every tensor contained in ``sample`` to ``device``.

    Tensors are transferred with ``non_blocking=True``. Dicts, lists, tuples
    and sets are traversed recursively and rebuilt with their moved contents;
    any other object is returned unchanged.

    Dicts, lists and sets are rebuilt as the plain built-in type. Named tuples
    are rebuilt with their own type so that field access keeps working; other
    tuples (including tuple subclasses without ``_fields``) become plain tuples.

    Args:
        sample (Any): A tensor, or an arbitrarily nested container of tensors.
        device (torch.device): The destination device.

    Returns:
        Any: A structure mirroring ``sample`` with all tensors moved to ``device``.
    """

    def _to_device(x: Any) -> Any:
        if torch.is_tensor(x):
            return x.to(device=device, non_blocking=True)
        if isinstance(x, dict):
            return {key: _to_device(value) for key, value in x.items()}
        if isinstance(x, list):
            return [_to_device(item) for item in x]
        if isinstance(x, tuple):
            moved = (_to_device(item) for item in x)
            # Named tuples take their fields positionally; rebuilding them via
            # ``tuple(...)`` would silently drop the field names.
            if hasattr(x, "_fields"):
                return type(x)(*moved)
            return tuple(moved)
        if isinstance(x, set):
            return {_to_device(item) for item in x}
        return x

    return _to_device(sample)