Description
It is often necessary to move model results to the CPU (or inputs to the GPU). Once the data structures get a bit more complicated, dicts and lists often appear in model results, and we end up rolling a small utility like the one below. If other people have had to write this sort of utility too, it may make sense to include something like it in core.
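```python
import torch


def to(obj, device):
    """Recursively move every tensor in nested dicts/tuples/lists to `device`."""
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: to(v, device) for k, v in obj.items()}
    if isinstance(obj, tuple):
        return tuple(to(v, device) for v in obj)
    if isinstance(obj, list):
        return [to(v, device) for v in obj]
    # Anything else (ints, strings, None, ...) is returned unchanged.
    return obj
```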
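One limitation of the utility above: if namedtuples show up in model results, the `tuple(...)` branch silently turns them into plain tuples. Below is a minimal sketch of one way around that, assuming a hypothetical `map_tensors(obj, fn)` helper (not an existing PyTorch function) that walks the structure and applies `fn` to each tensor; `to` is then a one-liner on top of it. The `Output` namedtuple is likewise just an illustrative example.

```python
from collections import namedtuple

import torch


def map_tensors(obj, fn):
    """Apply `fn` to every tensor inside nested dicts, lists and tuples.

    Hypothetical helper, not a PyTorch API. Namedtuples keep their type
    instead of being converted to plain tuples.
    """
    if torch.is_tensor(obj):
        return fn(obj)
    if isinstance(obj, dict):
        # Dicts (including subclasses) come back as plain dicts.
        return {k: map_tensors(v, fn) for k, v in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple: rebuild it positionally so the original type survives
        return type(obj)(*(map_tensors(v, fn) for v in obj))
    if isinstance(obj, tuple):
        return tuple(map_tensors(v, fn) for v in obj)
    if isinstance(obj, list):
        return [map_tensors(v, fn) for v in obj]
    # Non-container, non-tensor values are returned unchanged.
    return obj


def to(obj, device):
    """Move every tensor in `obj` to `device`."""
    return map_tensors(obj, lambda t: t.to(device))


# Illustrative nested result: namedtuple -> dict -> list of tensors.
Output = namedtuple("Output", ["logits", "extras"])
result = Output(torch.ones(2), {"mask": [torch.zeros(3)], "step": 7})
moved = to(result, "cpu")
print(type(moved).__name__, moved.logits.device, moved.extras["step"])
```

Factoring out `fn` means the same traversal can be reused for other per-tensor operations (for example `lambda t: t.detach()`), which may be a more general shape for something in core than a device-only helper.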