[feature request] `torch.to(obj, device, dtype)` supporting recursive lists/dicts/tuples of tensors probably by uplifting/promoting `torch.distributed.utils._recursive_to` · Issue #69431 · pytorch/pytorch · GitHub
Open
@vadimkantorov

Description

It is often necessary to move model outputs to the CPU (or inputs to the GPU). Once the data structures get somewhat complex, dicts and lists frequently appear in model outputs, so we often end up writing a small utility function like the one below. If other people have had to write this sort of utility too, it may make sense to include something like it in core.

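```python
import torch

def to(obj, device):
    # Move a single tensor to the target device
    if torch.is_tensor(obj):
        return obj.to(device)
    # Recurse into containers, rebuilding each one with moved contents
    if isinstance(obj, dict):
        return {k: to(v, device) for k, v in obj.items()}
    if isinstance(obj, tuple):
        return tuple(to(v, device) for v in obj)
    if isinstance(obj, list):
        return [to(v, device) for v in obj]
    # Leave non-tensor leaves (ints, strings, None, ...) unchanged
    return obj
```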
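
The utility can be generalized along the lines the title suggests: a `dtype` argument, plus container types such as namedtuples, which a plain `tuple(...)` rebuild would turn into ordinary tuples. What follows is a minimal sketch assuming only plain Python containers (dict, list, tuple, namedtuple) need handling. `map_structure` and `recursive_to` are hypothetical names, not existing PyTorch APIs, and casting only floating-point tensors when `dtype` is given is an assumption about the desired behavior.

```python
from typing import Any, Callable


def map_structure(obj: Any, is_leaf: Callable[[Any], bool], fn: Callable[[Any], Any]) -> Any:
    # Apply fn to every value for which is_leaf is True; everything else is returned unchanged
    if is_leaf(obj):
        return fn(obj)
    if isinstance(obj, dict):
        # Like the original `to` helper, dict subclasses come back as plain dicts
        return {k: map_structure(v, is_leaf, fn) for k, v in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple: rebuild with its own type so attribute access keeps working
        return type(obj)(*(map_structure(v, is_leaf, fn) for v in obj))
    if isinstance(obj, tuple):
        return tuple(map_structure(v, is_leaf, fn) for v in obj)
    if isinstance(obj, list):
        return [map_structure(v, is_leaf, fn) for v in obj]
    return obj


def recursive_to(obj: Any, device=None, dtype=None) -> Any:
    # Hypothetical torch.to(obj, device, dtype)-style helper built on map_structure.
    # torch is imported lazily so map_structure itself runs without PyTorch installed.
    import torch

    def move(t: "torch.Tensor") -> "torch.Tensor":
        if device is not None:
            t = t.to(device)
        # Assumption: only floating-point tensors are cast, so integer indices
        # and boolean masks keep their dtype
        if dtype is not None and t.is_floating_point():
            t = t.to(dtype)
        return t

    return map_structure(obj, torch.is_tensor, move)
```

For example, `recursive_to(batch, "cuda", torch.float16)` would move every tensor in `batch` to the GPU and cast only its floating-point tensors to half precision. Separating the traversal (`map_structure`) from the per-tensor operation keeps the same helper usable for other tasks, such as calling `.detach()` on every tensor in a model output.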

cc @albanD @mruberry @jbschlosser @walterddr

Metadata

Assignees

No one assigned

    Labels

    function request: A request for a new function or the addition of new arguments/modes to an existing function.
    needs research: We need to decide whether or not this merits inclusion, based on research world.
    triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
