10000 MPS Backend Error: ComplexDouble (complex128) Conversion Fails When Diffusers Transformer Creates 64‐bit Complex Tensors · Issue #148670 · pytorch/pytorch · GitHub
Open

@mozzipa

Description

🐛 Describe the bug

When running a diffusers-based transformer pipeline (e.g., the WanPipeline from diffusers) on Apple’s MPS device, an error is raised because a tensor is being converted to a ComplexDouble (i.e. torch.complex128) type. The error message is:

TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype.

Since the MPS backend does not support double-precision (64-bit) dtypes, real or complex (complex values must use torch.cfloat, i.e. complex64), this error prevents the pipeline from running. The root cause appears to be that part of the transformer's rotary embedding computation runs in float64, which produces a ComplexDouble output when executed on MPS.
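The dtype mapping behind the failure can be reproduced on any device: torch.view_as_complex pairs up the trailing dimension of size 2, and the complex dtype it returns follows the real input dtype, so a float64 input necessarily yields the complex128 that MPS rejects.

```python
import torch

# view_as_complex maps float64 -> complex128 ("ComplexDouble"),
# the dtype MPS cannot represent, and float32 -> complex64,
# the dtype MPS supports.
pairs64 = torch.randn(4, 2, dtype=torch.float64)
pairs32 = pairs64.to(torch.float32)

c128 = torch.view_as_complex(pairs64)  # would raise on an MPS tensor
c64 = torch.view_as_complex(pairs32)   # the supported alternative
```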

Steps to Reproduce:

1. Use diffusers (e.g., version 0.33.0.dev0) along with a recent nightly build of PyTorch (e.g., 2.7.0.dev20250305) on an Apple Silicon machine with MPS enabled.
2. Load a pipeline such as:

import torch
from diffusers import AutoencoderKLWan, WanPipeline

vae = AutoencoderKLWan.from_pretrained("<model_path>", subfolder="vae", torch_dtype=torch.float32).to("mps")
pipe = WanPipeline.from_pretrained("<model_path>", vae=vae, torch_dtype=torch.float32).to("mps")

3. Run inference (e.g., call the pipeline with prompt embeddings), which triggers the transformer's rotary embedding function.

The error occurs when torch.view_as_complex is called on a tensor that was computed as float64, resulting in an unsupported complex128 tensor.

Expected Behavior:
All operations on the MPS device should use supported dtypes. In particular, any complex-valued computation should use torch.cfloat (complex64) rather than torch.cdouble (complex128). An ideal solution would either (a) automatically downcast any double-precision inputs when on MPS or (b) warn and allow developers to control the dtype.
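Option (a) above amounts to keeping the rotary-embedding math in float32 from the start. The helper below is a hypothetical sketch of that idea, not the actual diffusers implementation; the function name and signature are illustrative only. Because every intermediate is float32, torch.polar returns complex64, which MPS supports.

```python
import torch

def rotary_freqs_float32(seq_len: int, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Hypothetical sketch: compute rotary-embedding frequencies
    # entirely in float32 so the complex result is complex64
    # (cfloat) rather than complex128 (cdouble).
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)                   # (seq_len, dim // 2)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64 rotations
```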

Workaround:
A temporary workaround is to monkey-patch torch.view_as_complex so that on MPS, if the input is float64 it is first cast to float32 before conversion. For example:

_orig_view_as_complex = torch.view_as_complex

def patched_view_as_complex(tensor):
    # MPS cannot represent complex128, so downcast float64 inputs to
    # float32 before the complex view is created; other devices and
    # dtypes pass through unchanged.
    if tensor.device.type == "mps" and tensor.dtype == torch.float64:
        tensor = tensor.to(torch.float32)
    return _orig_view_as_complex(tensor)

torch.view_as_complex = patched_view_as_complex
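As a self-contained sanity check (repeating the patch above so the snippet runs standalone), the downcast is gated on tensor.device.type == "mps", so behavior on CPU and CUDA is unchanged: a float64 input off MPS still produces complex128.

```python
import torch

_orig_view_as_complex = torch.view_as_complex

def patched_view_as_complex(tensor):
    # Downcast only on MPS; leave every other device untouched.
    if tensor.device.type == "mps" and tensor.dtype == torch.float64:
        tensor = tensor.to(torch.float32)
    return _orig_view_as_complex(tensor)

torch.view_as_complex = patched_view_as_complex

# Off MPS the patch is a no-op: float64 still maps to complex128.
cpu_out = torch.view_as_complex(torch.randn(3, 2, dtype=torch.float64))
```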

Environment Details:

PyTorch: 2.7.0.dev20250305 (nightly)
OS: macOS (Apple Silicon, MPS enabled)
diffusers: 0.33.0.dev0
Other libraries: torchaudio 2.6.0.dev20250305, torchvision 0.22.0.dev20250305
Device: MPS


cc @ezyang @anjali411 @dylanbespalko @mruberry @nikitaved @amjames @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Metadata
Assignees

No one assigned

    Labels

    - feature (A request for a proper, new feature)
    - module: complex (Related to complex number support in PyTorch)
    - module: mps (Related to Apple Metal Performance Shaders framework)
    - triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
