Source code for mbrs.utils
1from typing import Any
2
3import torch
4
5
[docs]
def to_device(sample: Any, device: torch.device, *, non_blocking: bool = True) -> Any:
    """Recursively move every tensor in ``sample`` to ``device``.

    Tensors nested inside dicts, lists, tuples (including namedtuples) and
    sets are moved; any other object is returned unchanged.

    Args:
        sample: A tensor, a (possibly nested) container of tensors, or any
            other object.
        device: Target device passed to :meth:`torch.Tensor.to`.
        non_blocking: Forwarded to :meth:`torch.Tensor.to`. Defaults to
            ``True``, matching the previously hard-coded behaviour.

    Returns:
        A structure mirroring ``sample`` with its tensors moved to
        ``device``. Dicts are rebuilt as plain ``dict`` objects, lists as
        ``list`` and sets as ``set``; namedtuples keep their own type, other
        tuples become plain ``tuple``.
    """

    def _to_device(obj: Any) -> Any:
        if torch.is_tensor(obj):
            return obj.to(device=device, non_blocking=non_blocking)
        if isinstance(obj, dict):
            return {key: _to_device(value) for key, value in obj.items()}
        if isinstance(obj, list):
            return [_to_device(item) for item in obj]
        if isinstance(obj, tuple):
            items = [_to_device(item) for item in obj]
            # Namedtuples take positional fields rather than a single
            # iterable; rebuild them with their own type so attribute
            # access (``obj.field``) still works for the caller.
            if hasattr(obj, "_fields"):
                return type(obj)(*items)
            return tuple(items)
        if isinstance(obj, set):
            return {_to_device(item) for item in obj}
        return obj

    return _to_device(sample)