Possible data race in PyOperation and ~PyOperation on cached Module · Issue #28551 · jax-ml/jax · GitHub

Possible data race in PyOperation and ~PyOperation on cached Module #28551

Open
vfdev-5 opened this issue May 6, 2025 · 2 comments · May be fixed by llvm/llvm-project#139721
Labels
bug Something isn't working free threading Issues found in free threading builds

Comments

@vfdev-5
Collaborator
vfdev-5 commented May 6, 2025

Description

Reproducer code with JAX:

```python
import concurrent.futures
import threading

import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import absltest
from jax import export
from jax._src import test_util as jtu


class CustomTest(jtu.JaxTestCase):

  def test_1(self):
    num_workers = 40
    num_runs = 20

    barrier = threading.Barrier(num_workers)

    def closure():
      # Release all workers at once to maximize contention.
      barrier.wait()

      func = jnp.sin
      data_inputs = (np.array(0.0, dtype=np.float32),)
      polymorphic_shapes = None

      args_specs = export.symbolic_args_specs(data_inputs, polymorphic_shapes)

      for _ in range(num_runs):
        exported = export.export(
            jax.jit(func),
            platforms=("cpu",),
            disabled_checks=(),
        )(*args_specs)

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
      futures = [executor.submit(closure) for _ in range(num_workers)]
      assert len([f.result() for f in futures]) == num_workers


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
```

I started working on a fix in the MLIR Python bindings (llvm/llvm-project#130612), but we may need to reconsider that fix.

Data race report
WARNING: ThreadSanitizer: data race (pid=35893)
  Read of size 1 at 0x7fffca320cb9 by thread T57 (mutexes: read M0):
    #0 mlir::python::PyOperation::checkValid() const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1300:8 (libjax_common.so+0x41e8b1d) (BuildId: 55242ad732cdae54)
    #1 mlir::python::populateIRCore(nanobind::module_&)::$_57::operator()(mlir::python::PyOperationBase&) const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:3221:40 (libjax_common.so+0x41e8b1d)
    #2 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::operator()(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) const /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:275:24 (libjax_common.so+0x41e8b1d)
    #3 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:219:14 (libjax_common.so+0x41e8b1d)
    #4 nanobind::detail::nb_func_vectorcall_simple(_object*, _object* const*, unsigned long, _object*) /proc/self/cwd/external/nanobind/src/nb_func.cpp:915:26 (libjax_common.so+0x3274482) (BuildId: 55242ad732cdae54)
    #5 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x1f18b2) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #6 PyObject_CallOneArg /project/cpython/Objects/call.c:395:12 (python3.14+0x1f18b2)
    #7 property_descr_get /project/cpython/Objects/descrobject.c:1695:12 (python3.14+0x20a89f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #8 _PyObject_GenericGetAttrWithDict /project/cpython/Objects/object.c (python3.14+0x2a8da2) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #9 PyObject_GenericGetAttr /project/cpython/Objects/object.c:1792:12 (python3.14+0x2a8732) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #10 PyObject_GetAttr /project/cpython/Objects/object.c:1296:18 (python3.14+0x2a70e7) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #11 _PyEval_EvalFrameDefault /project/cpython/Python/generated_cases.c.h:7712:30 (python3.14+0x410d88) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #12 _PyEval_EvalFrame /project/cpython/./Include/internal/pycore_ceval.h:119:16 (python3.14+0x3f8190) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #13 _PyEval_Vector /project/cpython/Python/ceval.c:1913:12 (python3.14+0x3f8190)
    #14 _PyFunction_Vectorcall /project/cpython/Objects/call.c (python3.14+0x1f1a8f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #15 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x1f6460) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #16 method_vectorcall /project/cpython/Objects/classobject.c:94:18 (python3.14+0x1f6460)
    #17 _PyVectorcall_Call /project/cpython/Objects/call.c:273:16 (python3.14+0x1f171f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #18 _PyObject_Call /project/cpython/Objects/call.c:348:16 (python3.14+0x1f171f)
    #19 PyObject_Call /project/cpython/Objects/call.c:373:12 (python3.14+0x1f1785) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #20 _PyEval_EvalFrameDefault /project/cpython/Python/generated_cases.c.h:2449:32 (python3.14+0x400652) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #21 _PyEval_EvalFrame /project/cpython/./Include/internal/pycore_ceval.h:119:16 (python3.14+0x3f8190) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #22 _PyEval_Vector /project/cpython/Python/ceval.c:1913:12 (python3.14+0x3f8190)
    #23 _PyFunction_Vectorcall /project/cpython/Objects/call.c (python3.14+0x1f1a8f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #24 _PyObject_VectorcallDictTstate /project/cpython/Objects/call.c:135:15 (python3.14+0x1f065d) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #25 _PyObject_Call_Prepend /project/cpython/Objects/call.c:504:24 (python3.14+0x1f20d7) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #26 call_method /project/cpython/Objects/typeobject.c:2927:19 (python3.14+0x30e7f6) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #27 slot_tp_call /project/cpython/Objects/typeobject.c:10150:12 (python3.14+0x30e641) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #28 _PyObject_MakeTpCall /project/cpython/Objects/call.c:242:18 (python3.14+0x1f08c8) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #29 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:167:16 (python3.14+0x1f14e8) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #30 PyObject_Vectorcall /project/cpython/Objects/call.c:327:12 (python3.14+0x1f14e8)
    #31 _PyEval_EvalFrameDefault /project/cpython/Python/generated_cases.c.h:3838:35 (python3.14+0x40595e) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #32 _PyEval_EvalFrame /project/cpython/./Include/internal/pycore_ceval.h:119:16 (python3.14+0x3f8190) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #33 _PyEval_Vector /project/cpython/Python/ceval.c:1913:12 (python3.14+0x3f8190)
    #34 _PyFunction_Vectorcall /project/cpython/Objects/call.c (python3.14+0x1f1a8f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #35 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x1f63af) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #36 method_vectorcall /project/cpython/Objects/classobject.c:72:20 (python3.14+0x1f63af)
    #37 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x4500f7) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #38 context_run /project/cpython/Python/context.c:728:29 (python3.14+0x4500f7)
    #39 method_vectorcall_FASTCALL_KEYWORDS /project/cpython/Objects/descrobject.c:421:24 (python3.14+0x209149) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #40 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x1f142a) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #41 PyObject_Vectorcall /project/cpython/Objects/call.c:327:12 (python3.14+0x1f142a)
    #42 _PyEval_EvalFrameDefault /project/cpython/Python/generated_cases.c.h:1443:35 (python3.14+0x3fc9bd) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #43 _PyEval_EvalFrame /project/cpython/./Include/internal/pycore_ceval.h:119:16 (python3.14+0x3f8190) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #44 _PyEval_Vector /project/cpython/Python/ceval.c:1913:12 (python3.14+0x3f8190)
    #45 _PyFunction_Vectorcall /project/cpython/Objects/call.c (python3.14+0x1f1a8f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #46 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x1f63af) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #47 method_vectorcall /project/cpython/Objects/classobject.c:72:20 (python3.14+0x1f63af)
    #48 _PyVectorcall_Call /project/cpython/Objects/call.c:273:16 (python3.14+0x1f171f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #49 _PyObject_Call /project/cpython/Objects/call.c:348:16 (python3.14+0x1f171f)
    #50 PyObject_Call /project/cpython/Objects/call.c:373:12 (python3.14+0x1f1785) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #51 thread_run /project/cpython/./Modules/_threadmodule.c:353:21 (python3.14+0x59ee32) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #52 pythread_wrapper /project/cpython/Python/thread_pthread.h:242:5 (python3.14+0x4f87a7) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)

  Previous write of size 1 at 0x7fffca320cb9 by thread T56 (mutexes: read M0):
    #0 mlir::python::PyOperation::setInvalid() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRModule.h:729:29 (libjax_common.so+0x419f012) (BuildId: 55242ad732cdae54)
    #1 mlir::python::PyMlirContext::clearOperation(MlirOperation) /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:741:10 (libjax_common.so+0x419f012)
    #2 mlir::python::PyOperation::~PyOperation() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1213:19 (libjax_common.so+0x41a414b) (BuildId: 55242ad732cdae54)
    #3 void nanobind::detail::wrap_destruct<mlir::python::PyOperation>(void*) /proc/self/cwd/external/nanobind/include/nanobind/nb_class.h:245:21 (libjax_common.so+0x41ecf21) (BuildId: 55242ad732cdae54)
    #4 nanobind::detail::inst_dealloc(_object*) /proc/self/cwd/external/nanobind/src/nb_type.cpp:255:13 (libjax_common.so+0x3284136) (BuildId: 55242ad732cdae54)
    #5 _Py_Dealloc /project/cpython/Objects/object.c:3025:5 (python3.14+0x2a2422) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #6 _Py_MergeZeroLocalRefcount /project/cpython/Objects/object.c (python3.14+0x2a2422)
    #7 Py_DECREF(_object*) /proc/self/cwd/external/python_x86_64-unknown-linux-gnu-freethreaded/include/python3.14t/refcount.h:387:13 (libjax_common.so+0x41aaadc) (BuildId: 55242ad732cdae54)
    #8 Py_XDECREF(_object*) /proc/self/cwd/external/python_x86_64-unknown-linux-gnu-freethreaded/include/python3.14t/refcount.h:526:9 (libjax_common.so+0x41aaadc)
    #9 nanobind::handle::dec_ref() const & /proc/self/cwd/external/nanobind/include/nanobind/nb_types.h:196:9 (libjax_common.so+0x41aaadc)
    #10 nanobind::object::~object() /proc/self/cwd/external/nanobind/include/nanobind/nb_types.h:218:17 (libjax_common.so+0x41aaadc)
    #11 mlir::python::PyOpView::~PyOpView() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRModule.h:762:7 (libjax_common.so+0x41aaadc)
    #12 void nanobind::detail::wrap_destruct<mlir::python::PyOpView>(void*) /proc/self/cwd/external/nanobind/include/nanobind/nb_class.h:245:21 (libjax_common.so+0x41ef9a1) (BuildId: 55242ad732cdae54)
    #13 nanobind::detail::inst_dealloc(_object*) /proc/self/cwd/external/nanobind/src/nb_type.cpp:255:13 (libjax_common.so+0x3284136) (BuildId: 55242ad732cdae54)
    #14 subtype_dealloc /project/cpython/Objects/typeobject.c:2668:5 (python3.14+0x2fc5a5) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #15 _Py_Dealloc /project/cpython/Objects/object.c:3025:5 (python3.14+0x2a2422) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #16 _Py_MergeZeroLocalRefcount /project/cpython/Objects/object.c (python3.14+0x2a2422)
    #17 Py_DECREF /project/cpython/./Include/refcount.h:387:13 (python3.14+0x46d57e) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #18 _PyFrame_ClearLocals /project/cpython/Python/frame.c:99:9 (python3.14+0x46d57e)
    #19 _PyFrame_ClearExceptCode /project/cpython/Python/frame.c:124:5 (python3.14+0x46d7d2) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #20 clear_thread_frame /project/cpython/Python/ceval.c:1738:5 (python3.14+0x41e027) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #21 _PyEval_FrameClearAndPop /project/cpython/Python/ceval.c:1762:9 (python3.14+0x41e027)
    #22 _PyEval_EvalFrameDefault /project/cpython/Python/generated_cases.c.h:10307:13 (python3.14+0x418786) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #23 _PyEval_EvalFrame /project/cpython/./Include/internal/pycore_ceval.h:119:16 (python3.14+0x3f8190) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #24 _PyEval_Vector /project/cpython/Python/ceval.c:1913:12 (python3.14+0x3f8190)
    #25 _PyFunction_Vectorcall /project/cpython/Objects/call.c (python3.14+0x1f1a8f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #26 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x1f6460) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #27 method_vectorcall /project/cpython/Objects/classobject.c:94:18 (python3.14+0x1f6460)
    #28 _PyVectorcall_Call /project/cpython/Objects/call.c:273:16 (python3.14+0x1f171f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #29 _PyObject_Call /project/cpython/Objects/call.c:348:16 (python3.14+0x1f171f)
    #30 PyObject_Call /project/cpython/Objects/call.c:373:12 (python3.14+0x1f1785) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #31 _PyEval_EvalFrameDefault /project/cpython/Python/generated_cases.c.h:2449:32 (python3.14+0x400652) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #32 _PyEval_EvalFrame /project/cpython/./Include/internal/pycore_ceval.h:119:16 (python3.14+0x3f8190) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #33 _PyEval_Vector /project/cpython/Python/ceval.c:1913:12 (python3.14+0x3f8190)
    #34 _PyFunction_Vectorcall /project/cpython/Objects/call.c (python3.14+0x1f1a8f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #35 _PyObject_VectorcallDictTstate /project/cpython/Objects/call.c:135:15 (python3.14+0x1f065d) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #36 _PyObject_Call_Prepend /project/cpython/Objects/call.c:504:24 (python3.14+0x1f20d7) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #37 call_method /project/cpython/Objects/typeobject.c:2927:19 (python3.14+0x30e7f6) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #38 slot_tp_call /project/cpython/Objects/typeobject.c:10150:12 (python3.14+0x30e641) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #39 _PyObject_MakeTpCall /project/cpython/Objects/call.c:242:18 (python3.14+0x1f08c8) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #40 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:167:16 (python3.14+0x1f14e8) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #41 PyObject_Vectorcall /project/cpython/Objects/call.c:327:12 (python3.14+0x1f14e8)
    #42 _PyEval_EvalFrameDefault /project/cpython/Python/generated_cases.c.h:3838:35 (python3.14+0x40595e) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #43 _PyEval_EvalFrame /project/cpython/./Include/internal/pycore_ceval.h:119:16 (python3.14+0x3f8190) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #44 _PyEval_Vector /project/cpython/Python/ceval.c:1913:12 (python3.14+0x3f8190)
    #45 _PyFunction_Vectorcall /project/cpython/Objects/call.c (python3.14+0x1f1a8f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #46 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x1f63af) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #47 method_vectorcall /project/cpython/Objects/classobject.c:72:20 (python3.14+0x1f63af)
    #48 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x4500f7) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #49 context_run /project/cpython/Python/context.c:728:29 (python3.14+0x4500f7)
    #50 method_vectorcall_FASTCALL_KEYWORDS /project/cpython/Objects/descrobject.c:421:24 (python3.14+0x209149) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #51 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x1f142a) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #52 PyObject_Vectorcall /project/cpython/Objects/call.c:327:12 (python3.14+0x1f142a)
    #53 _PyEval_EvalFrameDefault /project/cpython/Python/generated_cases.c.h:1443:35 (python3.14+0x3fc9bd) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #54 _PyEval_EvalFrame /project/cpython/./Include/internal/pycore_ceval.h:119:16 (python3.14+0x3f8190) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #55 _PyEval_Vector /project/cpython/Python/ceval.c:1913:12 (python3.14+0x3f8190)
    #56 _PyFunction_Vectorcall /project/cpython/Objects/call.c (python3.14+0x1f1a8f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #57 _PyObject_VectorcallTstate /project/cpython/./Include/internal/pycore_call.h:169:11 (python3.14+0x1f63af) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #58 method_vectorcall /project/cpython/Objects/classobject.c:72:20 (python3.14+0x1f63af)
    #59 _PyVectorcall_Call /project/cpython/Objects/call.c:273:16 (python3.14+0x1f171f) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #60 _PyObject_Call /project/cpython/Objects/call.c:348:16 (python3.14+0x1f171f)
    #61 PyObject_Call /project/cpython/Objects/call.c:373:12 (python3.14+0x1f1785) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #62 thread_run /project/cpython/./Modules/_threadmodule.c:353:21 (python3.14+0x59ee32) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #63 pythread_wrapper /project/cpython/Python/thread_pthread.h:242:5 (python3.14+0x4f87a7) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)

System info (python version, jaxlib version, accelerator, etc.)

Linux
CPU
3.14-FT

@hawkinsp
Collaborator

Simpler repro that does not involve JAX (except as a way to get MLIR):

```python
import concurrent.futures
import io
import threading

from jax._src.lib.mlir import ir

num_workers = 40
num_runs = 20

barrier = threading.Barrier(num_workers)

# Deliberately shared by all workers, so that threads keep replacing
# (and thereby releasing) each other's modules.
aglobal = None


def module_to_bytecode(module: ir.Module) -> bytes:
    output = io.BytesIO()
    module.operation.write_bytecode(file=output)
    return output.getvalue()


def closure():
    global aglobal
    barrier.wait()

    for _ in range(num_runs):
        with ir.Context():
            loc = ir.Location.unknown()
            m = ir.Module.create(loc=loc)

        aglobal = m
        module_to_bytecode(aglobal)


with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
    for _ in range(100):
        futures = [executor.submit(closure) for _ in range(num_workers)]
        assert len([f.result() for f in futures]) == num_workers
```

@vfdev-5
Collaborator Author
vfdev-5 commented May 12, 2025

This gist contains another reproducer that uses only the MLIR Python bindings: https://gist.github.com/vfdev-5/02bb822a0475d782da60815604ef30da#file-repro_2-py

hawkinsp added a commit to hawkinsp/llvm-project that referenced this issue May 13, 2025
Joint work with @vfdev-5

We found the following TSAN race report in JAX's CI:
jax-ml/jax#28551

```
WARNING: ThreadSanitizer: data race (pid=35893)
  Read of size 1 at 0x7fffca320cb9 by thread T57 (mutexes: read M0):
    #0 mlir::python::PyOperation::checkValid() const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1300:8 (libjax_common.so+0x41e8b1d) (BuildId: 55242ad732cdae54)
    #1 mlir::python::populateIRCore(nanobind::module_&)::$_57::operator()(mlir::python::PyOperationBase&) const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:3221:40 (libjax_common.so+0x41e8b1d)
    #2 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::operator()(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) const /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:275:24 (libjax_common.so+0x41e8b1d)
    #3 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:219:14 (libjax_common.so+0x41e8b1d)
...

  Previous write of size 1 at 0x7fffca320cb9 by thread T56 (mutexes: read M0):
    #0 mlir::python::PyOperation::setInvalid() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRModule.h:729:29 (libjax_common.so+0x419f012) (BuildId: 55242ad732cdae54)
    #1 mlir::python::PyMlirContext::clearOperation(MlirOperation) /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:741:10 (libjax_common.so+0x419f012)
    #2 mlir::python::PyOperation::~PyOperation() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1213:19 (libjax_common.so+0x41a414b) (BuildId: 55242ad732cdae54)
    #3 void nanobind::detail::wrap_destruct<mlir::python::PyOperation>(void*) /proc/self/cwd/external/nanobind/include/nanobind/nb_class.h:245:21 (libjax_common.so+0x41ecf21) (BuildId: 55242ad732cdae54)
    #4 nanobind::detail::inst_dealloc(_object*) /proc/self/cwd/external/nanobind/src/nb_type.cpp:255:13 (libjax_common.so+0x3284136) (BuildId: 55242ad732cdae54)
    #5 _Py_Dealloc /project/cpython/Objects/object.c:3025:5 (python3.14+0x2a2422) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #6 _Py_MergeZeroLocalRefcount /project/cpython/Objects/object.c (python3.14+0x2a2422)
    #7 Py_DECREF(_object*) /proc/self/cwd/external/python_x86_64-unknown-linux-gnu-freethreaded/include/python3.14t/refcount.h:387:13 (libjax_common.so+0x41aaadc) (BuildId: 55242ad732cdae54)
...
```

At the simplest level, the `valid` field of a PyOperation must be
protected by a lock, because it may be accessed concurrently from
multiple threads. Much more interesting, however, is how we get into the
situation described by the two stack traces above in the first place.
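A minimal Python sketch of that first point: the reader (cf. `checkValid` in the first trace) and the writer (cf. `setInvalid`, reached from `~PyOperation`) take the same lock, so their accesses to the flag can no longer race. `ValidFlag`, `check_valid` and `set_invalid` are hypothetical names chosen for illustration; the real field lives in the C++ `PyOperation`, and the error message here is illustrative only.

```python
import threading


class ValidFlag:
    """Hypothetical stand-in for the `valid` field of a PyOperation."""

    def __init__(self):
        self._lock = threading.Lock()
        self._valid = True

    def check_valid(self):
        # Reader side, cf. PyOperation::checkValid() in the first stack trace.
        with self._lock:
            if not self._valid:
                raise RuntimeError("operation has been invalidated")

    def set_invalid(self):
        # Writer side, cf. PyOperation::setInvalid() in the second stack trace.
        with self._lock:
            self._valid = False


if __name__ == "__main__":
    flag = ValidFlag()
    writer = threading.Thread(target=flag.set_invalid)
    writer.start()
    writer.join()
    try:
        flag.check_valid()
    except RuntimeError as e:
        print("invalidated:", e)
```

This removes the data race on the flag itself, but, as the rest of this message explains, it does not stop a second thread from handing out a reference to an object that is already being destroyed.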

The scenario that triggers this is the following:
* Thread T56 holds the last Python reference to a PyOperation and
  decides to release it.
* After T56 starts to release its reference, but before T56 removes the
  PyOperation from the liveOperations map, a second thread T57 comes
  along and looks up the same MlirOperation in the liveOperations map.
* Finding the operation to be present, thread T57 increments the
  reference count of that PyOperation and returns it to the caller.
  This is illegal! Python is in the process of calling the destructor of
  that object, and once an object is in that state it cannot be safely
  revived.

To fix this, when incrementing the reference count, we must use the
Python 3.14+ API `PyUnstable_TryIncRef`, which exists precisely for this
scenario (python/cpython#128844). That API does not exist in Python
3.13, so in that case we need a backport of it; we use the same
backport that nanobind and pybind11 use.
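The interleaving above, and the effect of the fix, can be replayed deterministically on a single thread with a toy model. All names here (`ToyOperation`, `ToyLiveOperations`, `replay`) are hypothetical: the explicit `refcount` stands in for the Python reference count, `try_incref` mimics the documented behavior of `PyUnstable_TryIncRef` (increment only while the count is nonzero), and `remove` plus `valid = False` stand in for `clearOperation`/`setInvalid`. This is a sketch under those assumptions, not the bindings' actual C++ code.

```python
import threading


class ToyOperation:
    """Hypothetical stand-in for a PyOperation with an explicit refcount."""

    def __init__(self, key):
        self.key = key
        self.refcount = 1      # the single owning Python reference
        self.valid = True
        self._lock = threading.Lock()

    def incref(self):
        """Unconditional increment: the buggy pattern."""
        with self._lock:
            self.refcount += 1

    def try_incref(self):
        """Analogue of PyUnstable_TryIncRef: refuse once the count has hit zero."""
        with self._lock:
            if self.refcount == 0:
                return False
            self.refcount += 1
            return True

    def decref(self):
        """Drop one reference; True means the caller must now destroy the object."""
        with self._lock:
            self.refcount -= 1
            return self.refcount == 0


class ToyLiveOperations:
    """Hypothetical analogue of liveOperations: key -> non-owning ToyOperation."""

    def __init__(self):
        self._lock = threading.Lock()
        self._live = {}

    def add(self, op):
        with self._lock:
            self._live[op.key] = op

    def remove(self, key):
        with self._lock:
            self._live.pop(key, None)

    def lookup(self, key, use_try_incref):
        with self._lock:
            op = self._live.get(key)
            if op is None:
                return None
            if use_try_incref:
                # A dying object is reported as a miss instead of being revived.
                return op if op.try_incref() else None
            op.incref()
            return op


def replay(use_try_incref):
    """Replay the T56/T57 interleaving from the report on one thread."""
    live = ToyLiveOperations()
    op = ToyOperation("op")
    live.add(op)

    # T56 drops the last reference; its "destructor" has now begun ...
    must_destroy = op.decref()
    # ... T57 looks the key up before T56 has removed the map entry ...
    found = live.lookup("op", use_try_incref)
    # ... and T56 finishes destruction (cf. clearOperation / setInvalid).
    if must_destroy:
        live.remove("op")
        op.valid = False
    return found
```

With the unconditional increment, `replay(use_try_incref=False)` hands back an object that has already been invalidated yet now has a reference count of 1, which is the illegal revival described above. With `try_incref`, the lookup reports a miss, which the caller can treat like any other cache miss.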
hawkinsp added a commit to hawkinsp/llvm-project that referenced this issue May 13, 2025
To fix this, whenever we increment the reference count of a PyOperation
that we found via the liveOperations map and to which we only hold a
non-owning reference, we must use the Python 3.14+ API
`PyUnstable_TryIncRef`, which exists precisely for this
scenario (python/cpython#128844). That API does not exist in Python
3.13, so in that case we need a backport of it; we use the same
backport that nanobind and pybind11 use.
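The same rule has a pure-Python analogue, offered here only as an analogy and not as the bindings' code: a cache that holds non-owning references should hand out an object only if it can still obtain a strong reference to it. The sketch below uses the standard library's `weakref.WeakValueDictionary` for that; `Operation`, `LiveOperations` and `get_or_create` are hypothetical names. As we understand it, CPython's free-threaded build dereferences weak references with the same try-incref check internally, so `get()` should never return an object whose destruction has already begun.

```python
import threading
import weakref


class Operation:
    """Hypothetical Python wrapper object for some native handle."""

    def __init__(self, handle):
        self.handle = handle


class LiveOperations:
    """Cache of non-owning references keyed by handle (hypothetical)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._live = weakref.WeakValueDictionary()

    def get_or_create(self, handle):
        with self._lock:
            # get() yields a strong reference if the object is still alive,
            # or None if it has been (or is being) destroyed.
            op = self._live.get(handle)
            if op is None:
                op = Operation(handle)
                self._live[handle] = op
            return op
```

The lock only makes the look-up-then-insert sequence atomic, so that two threads asking for the same handle receive the same wrapper; it is the weak reference, not the lock, that prevents a dying object from being revived.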

Fixes jax-ml/jax#28551
hawkinsp added a commit to hawkinsp/llvm-project that referenced this issue May 13, 2025
Joint work with @vfdev-5

We found the following TSAN race report in JAX's CI:
jax-ml/jax#28551

```
WARNING: ThreadSanitizer: data race (pid=35893)
  Read of size 1 at 0x7fffca320cb9 by thread T57 (mutexes: read M0):
    #0 mlir::python::PyOperation::checkValid() const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1300:8 (libjax_common.so+0x41e8b1d) (BuildId: 55242ad732cdae54)
    #1 mlir::python::populateIRCore(nanobind::module_&)::$_57::operator()(mlir::python::PyOperationBase&) const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:3221:40 (libjax_common.so+0x41e8b1d)
    #2 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::operator()(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) const /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:275:24 (libjax_common.so+0x41e8b1d)
    #3 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:219:14 (libjax_common.so+0x41e8b1d)
...

  Previous write of size 1 at 0x7fffca320cb9 by thread T56 (mutexes: read M0):
    #0 mlir::python::PyOperation::setInvalid() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRModule.h:729:29 (libjax_common.so+0x419f012) (BuildId: 55242ad732cdae54)
    #1 mlir::python::PyMlirContext::clearOperation(MlirOperation) /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:741:10 (libjax_common.so+0x419f012)
    #2 mlir::python::PyOperation::~PyOperation() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1213:19 (libjax_common.so+0x41a414b) (BuildId: 55242ad732cdae54)
    #3 void nanobind::detail::wrap_destruct<mlir::python::PyOperation>(void*) /proc/self/cwd/external/nanobind/include/nanobind/nb_class.h:245:21 (libjax_common.so+0x41ecf21) (BuildId: 55242ad732cdae54)
    #4 nanobind::detail::inst_dealloc(_object*) /proc/self/cwd/external/nanobind/src/nb_type.cpp:255:13 (libjax_common.so+0x3284136) (BuildId: 55242ad732cdae54)
    #5 _Py_Dealloc /project/cpython/Objects/object.c:3025:5 (python3.14+0x2a2422) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
    #6 _Py_MergeZeroLocalRefcount /project/cpython/Objects/object.c (python3.14+0x2a2422)
    #7 Py_DECREF(_object*) /proc/self/cwd/external/python_x86_64-unknown-linux-gnu-freethreaded/include/python3.14t/refcount.h:387:13 (libjax_common.so+0x41aaadc) (BuildId: 55242ad732cdae54)
...
```

At the simplest level, the `valid` field of a PyOperation must be
protected by a lock, because it may be accessed concurrently from
multiple threads. Much more interesting, however, is how we get into the
situation described by the two stack traces above in the first place.
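As a minimal sketch of the first point, the toy Python class below guards a shared `valid` flag with a lock so that the writer (compare `setInvalid()` in the second stack trace) and the reader (compare `checkValid()` in the first) never touch it unsynchronized. `ToyOperation`, its methods, and the error message are hypothetical and are not the MLIR bindings' API. Note that a lock removes the data race on the flag itself but does not explain or prevent two threads from reaching the same object in the first place.

```python
import threading


class ToyOperation:
    """Hypothetical model of a wrapper whose `valid` flag is shared across threads."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._valid = True

    def set_invalid(self) -> None:
        # Writer side: every write happens under the lock.
        with self._lock:
            self._valid = False

    def is_valid(self) -> bool:
        # Reader side: every read happens under the same lock.
        with self._lock:
            return self._valid

    def check_valid(self) -> None:
        if not self.is_valid():
            raise RuntimeError("operation has been invalidated")


op = ToyOperation()
op.check_valid()  # still valid, so no exception is raised
op.set_invalid()
assert not op.is_valid()
```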

The scenario that triggers this is the following:
* thread T56 holds the last Python reference on a PyOperation, and
  decides to release it.
* After T56 starts to release its reference, but before T56 removes the
  PyOperation from the liveOperations map, a second thread T57 comes
  along and looks up the same MlirOperation in the liveOperations map.
* Finding the operation to be present, thread T57 increments the
  reference count of that PyOperation and returns it to the caller.
  This is illegal! Python is in the process of calling the destructor of
  that object, and once an object is in that state it cannot be safely
  revived.
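The interleaving above can be replayed deterministically in a single thread. This Python sketch is a toy model: `NaiveObject`, `live_operations`, and the `"op0"` key are hypothetical stand-ins for a PyOperation, the liveOperations map, and an MlirOperation handle. The steps run in the order given in the list rather than on real threads.

```python
class NaiveObject:
    """Hypothetical refcounted wrapper with an unconditional incref."""

    def __init__(self) -> None:
        self.refcnt = 1
        self.destructor_started = False

    def incref(self) -> None:
        self.refcnt += 1  # no liveness check: this is the bug being modeled


live_operations: dict[str, NaiveObject] = {}  # non-owning cache keyed by handle
op = NaiveObject()
live_operations["op0"] = op

# Step 1 (T56): drop the last reference; the destructor begins...
op.refcnt -= 1
op.destructor_started = True
# ...but the map entry has not been removed yet.

# Step 2 (T57): the lookup still finds the entry and blindly increments.
found = live_operations["op0"]
found.incref()

# T57 now holds a "reference" to an object that T56 is destroying.
assert found is op and found.destructor_started and found.refcnt == 1
```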

To fix this, whenever we increment the reference count of a PyOperation
that we found via the liveOperations map and to which we only hold a
non-owning reference, we must use the Python 3.14+ API
`PyUnstable_TryIncRef`, which exists precisely for this
scenario (python/cpython#128844). That API
does not exist in Python 3.13, so we need a backport of it in that
case, for which we use the backport that both nanobind and pybind11
also use.

Fixes jax-ml/jax#28551