Inductor C++ Compile Error · Issue #154127 · pytorch/pytorch

Open
@RudeyPunk

Description


🐛 Describe the bug

I started getting a C++ compile error when compiling a simple function:
"InductorError: CppCompileError: C++ compile error"

I am working in an Anaconda environment:

  • Python 3.12.9
  • PyTorch installed via pip3 install torch torchvision torchaudio
  • pandas, hvplot, and jupyterlab

The code I am testing is:

import torch

TORCHDYNAMO_VERBOSE = 1  # note: a plain Python variable; the TORCHDYNAMO_VERBOSE environment variable has to be set in the OS environment to take effect

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
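
As a way to narrow this down, compiling the same function with the "eager" backend skips Inductor's C++ codegen entirely, so if that call succeeds the failure is isolated to the C++ compile step. A minimal sketch of that check (opt_foo_eager is just an illustrative name):

import torch

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

# backend="eager" runs Dynamo tracing but no Inductor code generation,
# so it should not invoke the MSVC compiler at all.
opt_foo_eager = torch.compile(foo, backend="eager")
print(opt_foo_eager(torch.randn(10, 10), torch.randn(10, 10)))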

My error stack trace is:

---------------------------------------------------------------------------
InductorError                             Traceback (most recent call last)
Cell In[1], line 10
      8     return a + b
      9 opt_foo1 = torch.compile(foo)
---> 10 print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_dynamo\eval_frame.py:663, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    659     raise e.with_traceback(None) from None
    660 except ShortenTraceback as e:
    661     # Failures in the backend likely don't have useful
    662     # data in the TorchDynamo frames, so we strip them out.
--> 663     raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    664 finally:
    665     # Restore the dynamic layer stack depth if necessary.
    666     set_eval_frame(None)

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\compile_fx.py:760, in _compile_fx_inner(gm, example_inputs, **graph_kwargs)
    758     raise
    759 except Exception as e:
--> 760     raise InductorError(e, currentframe()).with_traceback(
    761         e.__traceback__
    762     ) from None
    763 finally:
    764     TritonBundler.end_compile()

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\compile_fx.py:745, in _compile_fx_inner(gm, example_inputs, **graph_kwargs)
    743 TritonBundler.begin_compile()
    744 try:
--> 745     mb_compiled_graph = fx_codegen_and_compile(
    746         gm, example_inputs, inputs_to_check, **graph_kwargs
    747     )
    748     assert mb_compiled_graph is not None
    749     mb_compiled_graph._time_taken_ns = time.time_ns() - start_time

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\compile_fx.py:1295, in fx_codegen_and_compile(gm, example_inputs, inputs_to_check, **graph_kwargs)
   1291     from .compile_fx_subproc import _SubprocessFxCompile
   1293     scheme = _SubprocessFxCompile()
-> 1295 return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\compile_fx.py:1197, in _InProcessFxCompile.codegen_and_compile(self, gm, example_inputs, inputs_to_check, graph_kwargs)
   1184             compiled_fn = AotCodeCompiler.compile(
   1185                 graph,
   1186                 wrapper_code.value,
   (...)   1194                 ],
   1195             )
   1196     else:
-> 1197         compiled_fn = graph.compile_to_module().call
   1199 num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
   1200 metrics.num_bytes_accessed += num_bytes

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\graph.py:2083, in GraphLowering.compile_to_module(self)
   2076 def compile_to_module(self) -> ModuleType:
   2077     with dynamo_timed(
   2078         "GraphLowering.compile_to_module",
   2079         phase_name="code_gen",
   2080         log_pt2_compile_event=True,
   2081         dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
   2082     ):
-> 2083         return self._compile_to_module()

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\graph.py:2130, in GraphLowering._compile_to_module(self)
   2124     trace_structured(
   2125         "inductor_output_code",
   2126         lambda: {"filename": path},
   2127         payload_fn=lambda: wrapper_code.value,
   2128     )
   2129 with dynamo_timed("PyCodeCache.load_by_key_path", log_pt2_compile_event=True):
-> 2130     mod = PyCodeCache.load_by_key_path(
   2131         key,
   2132         path,
   2133         linemap=linemap,  # type: ignore[arg-type]
   2134         attrs={**self.constants, **self.torchbind_constants},
   2135     )
   2136 self.cache_key = key
   2137 self.cache_path = path

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\codecache.py:2747, in PyCodeCache.load_by_key_path(cls, key, path, linemap, attrs)
   2744 if linemap is None:
   2745     linemap = []
-> 2747 mod = _reload_python_module(key, path)
   2749 # unzip into separate lines/nodes lists
   2750 cls.linemaps[path] = list(zip(*linemap))

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\runtime\compile_tasks.py:36, in _reload_python_module(key, path)
     34 mod.__file__ = path
     35 mod.key = key  # type: ignore[attr-defined]
---> 36 exec(code, mod.__dict__, mod.__dict__)
     37 sys.modules[mod.__name__] = mod
     38 return mod

File ~\AppData\Local\Temp\torchinductor_croda\wo\cwoory2aqk53pzwvty2orkc435qpc5tzio4sq6nvxm6vyfjcbl3w.py:31
     27 async_compile = AsyncCompile()
     28 empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
---> 31 cpp_fused_add_cos_sin_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'float*'], '''
     32 #include "C:/Users/croda/AppData/Local/Temp/torchinductor_croda/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h"
     33 extern "C" __declspec(dllexport) void kernel(const float* in_ptr0,
     34                        const float* in_ptr1,
     35                        float* out_ptr0)
     36 {
     37     {
     38         for(int64_t x0=static_cast<int64_t>(0LL); x0<static_cast<int64_t>(100LL); x0+=static_cast<int64_t>(1LL))
     39         {
     40             {
     41                 {
     42                     auto tmp0 = in_ptr0[static_cast<int64_t>(x0)];
     43                     auto tmp2 = in_ptr1[static_cast<int64_t>(x0)];
     44                     auto tmp1 = std::sin(tmp0);
     45                     auto tmp3 = std::cos(tmp2);
     46                     auto tmp4 = decltype(tmp1)(tmp1 + tmp3);
     47                     out_ptr0[static_cast<int64_t>(x0)] = tmp4;
     48                 }
     49             }
     50         }
     51     }
     52 }
     53 ''')
     56 async_compile.wait(globals())
     57 del async_compile

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\async_compile.py:370, in AsyncCompile.cpp_pybinding(self, argtypes, source_code)
    368 kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
    369 if get_compile_threads() <= 1:
--> 370     return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
    371 else:
    372     get_result = CppPythonBindingsCodeCache.load_pybinding_async(
    373         argtypes, source_code, submit_fn=self.submit
    374     )

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\codecache.py:2250, in CppPythonBindingsCodeCache.load_pybinding(cls, *args, **kwargs)
   2248 @classmethod
   2249 def load_pybinding(cls, *args: Any, **kwargs: Any) -> Any:
-> 2250     return cls.load_pybinding_async(*args, **kwargs)()

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\codecache.py:2242, in CppPythonBindingsCodeCache.load_pybinding_async.<locals>.future()
   2240 nonlocal result
   2241 if result is None:
-> 2242     result = get_result()
   2243     assert isinstance(result, ModuleType)
   2244 return getattr(result, cls.entry_function)

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\codecache.py:2051, in CppCodeCache.load_async.<locals>.load_fn()
   2049 if future is not None:
   2050     future.result()
-> 2051 result = worker_fn()
   2052 assert result is None
   2053 lib = cls._load_library(binary_path, key)

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\codecache.py:2079, in _worker_compile_cpp(lock_path, cpp_builder)
   2077 with FileLock(lock_path, timeout=LOCK_TIMEOUT):
   2078     if not os.path.exists(cpp_builder.get_target_file_path()):
-> 2079         cpp_builder.build()

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\cpp_builder.py:1601, in CppBuilder.build(self)
   1598 _create_if_dir_not_exist(_build_tmp_dir)
   1600 build_cmd = self.get_command_line()
-> 1601 run_compile_cmd(build_cmd, cwd=_build_tmp_dir)
   1602 _remove_dir(_build_tmp_dir)

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\cpp_builder.py:355, in run_compile_cmd(cmd_line, cwd)
    353 def run_compile_cmd(cmd_line: str, cwd: str) -> None:
    354     with dynamo_timed("compile_file"):
--> 355         _run_compile_cmd(cmd_line, cwd)

File ~\AppData\Local\anaconda3\envs\FU_TAML\Lib\site-packages\torch\_inductor\cpp_builder.py:350, in _run_compile_cmd(cmd_line, cwd)
    340     instruction = (
    341         "\n\nOpenMP support not found. Please try one of the following solutions:\n"
    342         "(1) Set the `CXX` environment variable to a compiler other than Apple clang++/g++ "
   (...)    347         " with `include/omp.h` under it."
    348     )
    349     output += instruction
--> 350 raise exc.CppCompileError(cmd, output) from e

InductorError: CppCompileError: C++ compile error

Command:
cl /I C:/Users/croda/AppData/Local/anaconda3/envs/FU_TAML/Include /I C:/Users/croda/AppData/Local/anaconda3/envs/FU_TAML/Lib/site-packages/torch/include /I C:/Users/croda/AppData/Local/anaconda3/envs/FU_TAML/Lib/site-packages/torch/include/torch/csrc/api/include /D TORCH_INDUCTOR_CPP_WRAPPER /D STANDALONE_TORCH_HEADER /D C10_USING_CUSTOM_GENERATED_MACROS /DLL /MD /O2 /std:c++20 /wd4819 /wd4251 /wd4244 /wd4267 /wd4275 /wd4018 /wd4190 /wd4624 /wd4067 /wd4068 /EHsc /openmp /openmp:experimental C:/Users/croda/AppData/Local/Temp/torchinductor_croda/iz/cizh6kvfsp3ix3iflyvqbgejk2ahpmxnzyv3arfp32jh4546eah6.cpp /LD /FeC:/Users/croda/AppData/Local/Temp/torchinductor_croda/iz/cizh6kvfsp3ix3iflyvqbgejk2ahpmxnzyv3arfp32jh4546eah6.pyd /link /LIBPATH:C:/Users/croda/AppData/Local/anaconda3/envs/FU_TAML/libs /LIBPATH:C:/Users/croda/AppData/Local/anaconda3/envs/FU_TAML/Lib/site-packages/torch/lib torch.lib torch_cpu.lib torch_python.lib sleef.lib

Output:
Microsoft (R) C/C++ Optimizing Compiler Version 19.44.35207.1 for x64
Copyright (C) Microsoft Corporation.  All rights reserved.

cl : Command line warning D9025 : overriding '/openmp' with '/openmp:experimental'
cizh6kvfsp3ix3iflyvqbgejk2ahpmxnzyv3arfp32jh4546eah6.cpp
C:/Users/croda/AppData/Local/Temp/torchinductor_croda/pi/cpicxudqmdsjh5cm4klbtbrvy2cxwr7whxl3md2zzdjdf3orvfdf.h(3): fatal error C1083: Cannot open include file: 'algorithm': No such file or directory


Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
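
Since the message points at TORCHDYNAMO_VERBOSE=1 and TORCH_LOGS, this is how I understand those are meant to be enabled from a notebook (a sketch, not yet verified to change the outcome; they are OS environment variables, so they need to be in place before torch is imported, e.g. after restarting the kernel):

import os

# Set before "import torch"; a bare TORCHDYNAMO_VERBOSE=1 inside the script
# only creates a Python variable and is not seen by TorchDynamo.
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCH_LOGS"] = "+dynamo"

import torch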

I'm having problems understanding where to go from here.
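
The fatal error C1083 on 'algorithm' (a standard C++ header) makes me suspect cl.exe is being run without the Visual Studio developer environment, i.e. the MSVC INCLUDE/LIB paths are not visible to the Jupyter process. A quick check I can run from the same kernel (just a sketch, the prints are illustrative):

import os
import shutil

# cl.exe clearly runs (its banner is printed in the output above), so the
# question is whether the header/library search paths from vcvars64.bat
# are present in this process's environment.
print("cl.exe resolves to:", shutil.which("cl"))
print("INCLUDE set:", bool(os.environ.get("INCLUDE")))  # where <algorithm> would be found
print("LIB set:", bool(os.environ.get("LIB")))

If INCLUDE comes back empty, launching JupyterLab from an "x64 Native Tools Command Prompt for VS" (or running vcvars64.bat in the shell first) is, as far as I understand, the usual way to get those paths populated.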

Versions

Collecting environment information...
PyTorch version: 2.7.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Education (10.0.26100 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:49:16) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-11-10.0.26100-SP0
Is CUDA available: False
CUDA runtime version: 12.6.85
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090

Nvidia driver version: 561.17
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Name: Intel(R) Core(TM) i9-14900KF
Manufacturer: GenuineIntel
Family: 207
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 3200
MaxClockSpeed: 3200
L2CacheSize: 32768
L2CacheSpeed: None
Revision: None

Versions of relevant libraries:
[pip3] numpy==2.0.1
[pip3] torch==2.7.0
[pip3] torchaudio==2.7.0
[pip3] torchvision==0.22.0
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h6b88ed4_46358
[conda] mkl-service 2.4.0 py312h827c3e9_2
[conda] mkl_fft 1.3.11 py312h827c3e9_0
[conda] mkl_random 1.2.8 py312h0158946_0
[conda] numpy 2.0.1 py312hfd52020_1
[conda] numpy-base 2.0.1 py312h4dde369_1
[conda] torch 2.7.0 pypi_0 pypi
[conda] torchaudio 2.7.0 pypi_0 pypi
[conda] torchvision 0.22.0 pypi_0 pypi

cc @peterjc123 @mszhanyi @skyline75489 @nbcsm @iremyux @Blackhex @chauhang @penguinwu
