Insights: jax-ml/jax
Overview
94 Pull requests merged by 14 people
-
[pallas:mosaic_gpu] Removed debug prints from `emit_pipeline_warp_specialized`
#28707 merged
May 13, 2025 -
Update tridiagonal solve kernels on GPU to properly use the FFI.
#28696 merged
May 13, 2025 -
[platform_dependent] Ensure that platform_dependent only lowers for intended platforms
#28607 merged
May 13, 2025 -
Add `block_until_ready()` to FAQ code snippet.
#28700 merged
May 12, 2025 -
Support pltpu.roll on sublanes when not all lanes are used.
#28228 merged
May 12, 2025 -
Move the existing mask handling code to the relayout fn, invoke it from the existing tpu relayout rule.
#28682 merged
May 12, 2025 -
[Mosaic] Support squeezing tiled memrefs to 1d shapes.
#28444 merged
May 12, 2025 -
[array API] update test suite to most recent version
#28695 merged
May 12, 2025 -
Speed up `scipy.signal.stft` by using `lax.dynamic_slice_in_dim` for windowing
#28662 merged
May 12, 2025 -
Fix debug rule in .bazelrc
#28683 merged
May 12, 2025 -
[pallas:mosaic_gpu] Slightly generalized `MosaicGridMapping`
#28644 merged
May 12, 2025 -
[Mosaic GPU] Add support for TMEM loads/stores with the 32x32b shape
#28641 merged
May 12, 2025 -
[Mosaic GPU] Use f8e4m3fn in place of f8e4m3
#28685 merged
May 12, 2025 -
Call block_until_ready for testAutodiffCache
#28674 merged
May 12, 2025 -
Add collective_axes to run_scoped
#28546 merged
May 12, 2025 -
[Mosaic GPU] Add an additional WG barrier before copy_gmem_to_smem
#28681 merged
May 12, 2025 -
[Pallas:MGPU] Update one more lowering rule to the load/store rename in TMEMRef
#28680 merged
May 12, 2025 -
[Mosaic GPU] Add support for 8-bit MMA on Blackwell
#28647 merged
May 12, 2025 -
[si_vjp] fix bugs around symbolic zeros
#28664 merged
May 10, 2025 -
Automated Code Change
#28668 merged
May 10, 2025 -
[Pallas] Allow more int casting tests.
#28494 merged
May 10, 2025 -
Fix typo in FusedAttentionTest
#28660 merged
May 9, 2025 -
Remove __div__ and __rdiv__ from jax.Array
#28626 merged
May 9, 2025 -
Disable profiler tests under Python 3.14 if multithreaded.
#28656 merged
May 9, 2025 -
pytest: use importlib mode by default
#28650 merged
May 9, 2025 -
[Mosaic] Fix typo: FPToSI > SIToFP.
#28655 merged
May 9, 2025 -
Reverts 8137c37e324c9cb5c8f991a16d78310b6e37bd05
#28654 merged
May 9, 2025 -
Use cython from pypi in tsan CI build.
#28642 merged
May 9, 2025 -
[Mosaic] Move sitofp lowering to Mosaic.
#28625 merged
May 9, 2025 -
Add `.update` to ShapeDtypeStruct
#28634 merged
May 9, 2025 -
Split custom_* tests out of api_test into new target.
#28648 merged
May 9, 2025 -
Add a pretty printing rule for custom_jvp
#28608 merged
May 9, 2025 -
[Pallas/TPU]
#28628 merged
May 9, 2025 -
[Mosaic GPU] Use explicit load/store methods instead of __getitem__/__setitem__
#28580 merged
May 9, 2025 -
Add supports_pinned_allocator to allow debugging pinning issues.
#28633 merged
May 9, 2025 -
Add host offloading docs to public website
#28632 merged
May 9, 2025 -
Fix empty string handling for cloud_tpu_cluster
#28554 merged
May 9, 2025 -
Simplify add's unreduced rule. Only propagate unreduced if both lhs and rhs are unreduced.
#28629 merged
May 9, 2025 -
Remove type annotation of get_gpu_client
#28620 merged
May 8, 2025 -
Fix the PyTest TPU jobs on the Continuous Wheel Tests workflow.
#28622 merged
May 8, 2025 -
Allow unreduced propagation only for `add` right now.
#28624 merged
May 8, 2025 -
[Pallas][Mosaic GPU] Add transpose support to tcgen05_mma
#28528 merged
May 8, 2025 -
Shut down `PreemptionSyncManager` when `jax.distributed.shutdown()` is called.
#28342 merged
May 8, 2025 -
[jaxlib] Add compile_and_load, compile_and_load_ifrt_program to xla_client stub.
#28613 merged
May 8, 2025 -
Make the type checker match the runtime behavior of PartitionSpec not inheriting from a tuple.
#28619 merged
May 8, 2025 -
jnp.linalg.matrix_power: support non-float inputs
#28612 merged
May 8, 2025 -
jnp.put: check inplace before other conditions
#28610 merged
May 8, 2025 -
[Pallas][Mosaic GPU] Add collective support to Blackwell/tcgen05 MMA.
#28524 merged
May 8, 2025 -
Add `out_sharding` to `jnp.repeat`. Drop into auto mode if out_sharding is provided.
#28600 merged
May 8, 2025 -
Add argnums param to fwd_and_bwd
#28279 merged
May 8, 2025 -
Fix handling of final style primitives in pallas cost estimate.
#28605 merged
May 8, 2025 -
[Pallas][Mosaic GPU] Expand TMEM support.
#28495 merged
May 8, 2025 -
Upgrade Mac CI builds to run on the Sequoia pool and Apple Clang 17
#28611 merged
May 8, 2025 -
Record event start time in `dispatch.LogElapsedTimeContextManager`.
#28590 merged
May 8, 2025 -
[xla::PyClient] Update PyClient to use xla::ifrt::CompileAndLoad.
#28522 merged
May 8, 2025 -
Fix docs build by constraining snowballstemmer version
#28606 merged
May 8, 2025 -
Reverts 6d1b5271a115007162e9f98561d6b118aa66382c
#28601 merged
May 8, 2025 -
jax.scipy.signal.istft: support array input for window
#28584 merged
May 8, 2025 -
jnp.packbits: fix handling of negative entries
#28593 merged
May 8, 2025 -
Fix `with_sharding_constraint` with a scalar input
#28599 merged
May 8, 2025 -
[Mosaic] Add explicit control over core parallelization strategy
#28393 merged
May 8, 2025 -
Clean up `PyLoadedExecutable::Delete`
#28592 merged
May 7, 2025 -
[pallas:mosaic] Use `cf.assert` directly in the lowering rule for `checkify.check_p`
#28577 merged
May 7, 2025 -
Replace `std::shared_ptr<xla::ifrt::LoadedExecutable>` with `xla::ifrt::LoadedExecutableRef`
#28586 merged
May 7, 2025 -
expose mutable_array in experimental
#28532 merged
May 7, 2025 -
Use `ArrayRef` instead of `tsl::RCReference<Array>`
#28585 merged
May 7, 2025 -
Declare `tpu.vector_load`, mirroring `tpu.vector_store`.
#28289 merged
May 7, 2025 -
Clean up `LoadedExecutable::Delete` and `LoadedExecutable::IsDeleted`
#28541 merged
May 7, 2025 -
Do constant folding and forwarding while tracing instead of as a separate pass
#28396 merged
May 7, 2025 -
[Mosaic GPU] Run `canonicalize` instead of `cse` before the lowering.
#28557 merged
May 7, 2025 -
[Mosaic GPU] Do not shortcut the transform computation for `memref.cast`
#28576 merged
May 7, 2025 -
[pallas:mosaic] Fixed the type of `dimension_semantics`
#28575 merged
May 7, 2025 -
Make PartitionSpec not inherit from a tuple at runtime. For type checkers, it's still a tuple.
#28567 merged
May 7, 2025 -
Block on svd result to fix race condition in svd_test.
#28560 merged
May 7, 2025 -
Fix an import path to properly detect the CUDA plugin in bazel tests
#28545 merged
May 7, 2025 -
[pallas:mosaic] Do not use *Op classes for creating MLIR ops unless necessary
#28572 merged
May 7, 2025 -
[Mosaic GPU] Implement a trivial pass-through transform inference for `memref.cast`
#28556 merged
May 7, 2025 -
[Mosaic GPU] Extract duplicated code into a `_transforms_from_uses` function.
#28555 merged
May 7, 2025 -
Temporarily roll back changes for new LLVM version
#28566 merged
May 7, 2025 -
Use `LoadedExecutableRef` instead of `std::unique_ptr<LoadedExecutable>`
#28558 merged
May 6, 2025 -
Fixed deadlock in MakeShardFn on static var assignment under free-threading
#28387 merged
May 6, 2025 -
Update sharded-computation.md
#28507 merged
May 6, 2025 -
Fix defaults overriding each other in tests/mutliprocess_gpu_test
#28533 merged
May 6, 2025 -
Avoid an unlucky seed for some random categorical tests.
#28552 merged
May 6, 2025 -
[Mosaic GPU] Print instead of warning when skipping flash_attention test
#28485 merged
May 6, 2025 -
[pallas:mosaic] Handle more types in `ir_constant`
#28548 merged
May 6, 2025 -
Enable command buffer support for buffer callbacks.
#28510 merged
May 6, 2025 -
Fix a x64 error in fused_attention_stablehlo
#28540 merged
May 6, 2025 -
[Mosaic-GPU] [3/3] Add support for communication primitives in MGPU lowering
#27684 merged
May 6, 2025
41 Pull requests opened by 10 people
-
Fix typo in persistent_compilation_cache.md
#28549 opened
May 6, 2025 -
Fixing broken fused_attention_stablehlo_test_gpu.
#28559 opened
May 6, 2025 -
fix box bug in scan transpose
#28561 opened
May 6, 2025 -
Trying to fix broken nvshmem build.
#28562 opened
May 6, 2025 -
Fixed OOM on pytest cuda.
#28563 opened
May 6, 2025 -
Expose profiler_data submodule from XLA to Jaxlib.
#28564 opened
May 6, 2025 -
[Mosaic] Fork transpose operation to TPU dialect.
#28569 opened
May 7, 2025 -
[pallas:mosaic_gpu] Removed the `GPU*` prefix from Mosaic GPU-specific types
#28578 opened
May 7, 2025 -
[xla:ifrt] Remove references to xla::ifrt::Compile that return LoadedExecutables.
#28581 opened
May 7, 2025 -
Consolidate initial/final style custom_vjp primitives into one
#28589 opened
May 7, 2025 -
Use the C4A machine type for Linux Arm64 builds
#28591 opened
May 7, 2025 -
[Mosaic GPU] Use PTX ISA version = min(ptxas, LLVM)
#28595 opened
May 7, 2025 -
Propagate use_shardy_partitioner to XlaCallModule op.
#28616 opened
May 8, 2025 -
[ifrt] Refactor away from deprecated constructors
#28627 opened
May 8, 2025 -
Ensure __jax_array__ support in binary ops
#28630 opened
May 8, 2025 -
Annotate lower triangular matrix
#28631 opened
May 8, 2025 -
Skip CSR matmat and matvec float tests on ROCm <6.4 (NaN issue with beta==0)
#28635 opened
May 9, 2025 -
Bumping oldest libtpu version
#28652 opened
May 9, 2025 -
Fixed @@pypi//nvidia_nvshmem_cu12 errors.
#28658 opened
May 9, 2025 -
JEP 28661: the __jax_array__ protocol
#28661 opened
May 9, 2025 -
[shard-map] start adding systematic smap tests
#28665 opened
May 10, 2025 -
[mosaic-gpu] add multicast ptr support to TMA with overlapped gemm and all reduce examples
#28679 opened
May 12, 2025 -
[Mosaic GPU] Add layout inference and lowering for `scf.WhileOp` and enable tests.
#28684 opened
May 12, 2025 -
[pallas:mosaic_gpu] arrive_expect_tx primitive.
#28686 opened
May 12, 2025 -
[CI] Add additional hardware to continuous non-rbe testing
#28688 opened
May 12, 2025 -
[pallas:mgpu] Arrive only once to the pipeline input barriers.
#28689 opened
May 12, 2025 -
Revert pytest: use importlib mode by default
#28690 opened
May 12, 2025 -
Deprecate the no-op custom_jvp_call_jaxpr_p import stub.
#28691 opened
May 12, 2025 -
Bump hypothesis from 6.102.4 to 6.131.15
#28693 opened
May 12, 2025 -
Bump fonttools from 4.51.0 to 4.58.0
#28694 opened
May 12, 2025 -
[Pallas] Allow f8 casting tests on TPUv5-.
#28698 opened
May 12, 2025 -
[pallas] Pulled `runtime_assert_enabled` from `pltpu` to `pl`
#28699 opened
May 12, 2025 -
Add use_raw_buffers which allows switching the implementation to
#28701 opened
May 12, 2025 -
Use DmaCopyChunk::Make because directly assigning the struct
#28703 opened
May 13, 2025 -
[Mosaic TPU][NFC] Consolidate `getIntConst`.
#28704 opened
May 13, 2025 -
Remove support for dimension level type.
#28705 opened
May 13, 2025 -
[pallas:mosaic_gpu] Added support for unrolling to `lax.fori_loop` lowering
#28708 opened
May 13, 2025
24 Issues closed by 8 people
-
Use batched library routines when available.
#28544 closed
May 13, 2025 -
jax.lax.platform_dependent doesn't stop Pallas from trying to lower for other backends?
#28594 closed
May 13, 2025 -
"Slow operation alarm" in `jax.scipy.signal.stft`
#28614 closed
May 13, 2025 -
*RuntimeError: bad conversion* when trying to register custom jax plugin
#27743 closed
May 12, 2025 -
Failure of cuDNN initialization at jax 0.4.12
#28550 closed
May 12, 2025 -
`jax.numpy.cumsum` brings different results with `numpy.cumsum`
#28646 closed
May 10, 2025 -
`jax.numpy.nancumsum` brings different results with `numpy.nancumsum`
#28669 closed
May 10, 2025 -
`jax.numpy.nansum` brings different results with `numpy.nansum`
#28670 closed
May 10, 2025 -
`lax` collective axis names not recognized with JIT sharding
#28666 closed
May 10, 2025 -
Full coverage for `__jax_array__` protocol
#24460 closed
May 9, 2025 -
inconsistent named shape representations
#8182 closed
May 9, 2025 -
`jax.numpy.linalg.tensorinv` brings different results with `numpy.linalg.tensorinv`
#28637 closed
May 9, 2025 -
`jax.numpy.unwrap` brings different results with `numpy.unwrap`
#28645 closed
May 9, 2025 -
`jax.numpy.fix` brings different results with `numpy.fix`
#28638 closed
May 9, 2025 -
AttributeError: module 'jaxlib.xla_client' has no attribute 'register_custom_type_id_handler'.
#26750 closed
May 8, 2025 -
`jax.numpy.linalg.matrix_power` brings different results with `numpy.linalg.matrix_power`
#28603 closed
May 8, 2025 -
`jax.numpy.put` brings different results with `numpy.put`
#28602 closed
May 8, 2025 -
[sharding-in-types] Bug in jnp.repeat
#28538 closed
May 8, 2025 -
`jax.numpy.linalg.eigvalsh` brings different results with `numpy.linalg.eigvalsh`
#28617 closed
May 8, 2025 -
`jax.numpy.linalg.matrix_rank` brings different results with `numpy.linalg.matrix_rank`
#28618 closed
May 8, 2025 -
`scipy.signal.istft` raises error with numpy window
#28571 closed
May 8, 2025 -
`jax.numpy.packbits` brings different results with `numpy.packbits`
#28583 closed
May 8, 2025 -
Deadlock in MakeShardFn function under free-threading
#28385 closed
May 6, 2025 -
[sharding-in-types] gather of non-shared dimensions is not supported
#28542 closed
May 6, 2025
11 Issues opened by 11 people
-
CUDA_ERROR_ILLEGAL_ADDRESS thrown by XlaRuntimeError
#28710 opened
May 13, 2025 -
Provide wheel for Windows ARM64
#28709 opened
May 13, 2025 -
segmentation fault (core dumped) while using gpu
#28692 opened
May 12, 2025 -
Failed build: CI - with Numpy/Scipy nightly wheels (nightly)
#28671 opened
May 10, 2025 -
Allow for passing slices of indices to `*_argnum` style arguments
#28667 opened
May 10, 2025 -
Nondeterministic behavior for pytree extensions if aux data does not have equality
#28659 opened
May 9, 2025 -
0.6 Installation issues and 0.5 segfaults
#28623 opened
May 8, 2025 -
jaxlib doesn't have a pyproject.toml
#28582 opened
May 7, 2025 -
Possible data race in PyOperation and ~PyOperation on cached Module
#28551 opened
May 6, 2025 -
`special.betainc`, `stdtr`, and `stdtrit` are very slow
#28547 opened
May 6, 2025
25 Unresolved conversations
Sometimes conversations happen on old items that aren’t yet closed. Here is a list of all the Issues and Pull Requests with unresolved conversations.
-
[Mosaic GPU] Check in WIP grouped GEMM
#26997 commented on
May 12, 2025 • 11 new comments -
Consolidate material on debugging NaNs.
#24989 commented on
May 7, 2025 • 6 new comments -
RFC: specify jit static args via Static annotation
#24705 commented on
May 7, 2025 • 2 new comments -
[Mosaic] Use native bf16 ops for tanh, exp and log on TPUv6+.
#28531 commented on
May 6, 2025 • 0 new comments -
Pass RULES_PYTHON_REPO_DEBUG value to the docker container env
#28530 commented on
May 6, 2025 • 0 new comments -
Major deps update:
#28497 commented on
May 11, 2025 • 0 new comments -
[JAX] Make fully replicated sharding to avoid materializing the same host buffers
#28493 commented on
May 13, 2025 • 0 new comments -
Rename backend.compile to backend.compile_and_load.
#28451 commented on
May 8, 2025 • 0 new comments -
[JAX] Add python3.14 wheel test configs
#28440 commented on
May 6, 2025 • 0 new comments -
Deprecate parsing of __jax_array__ during abstractification.
#28355 commented on
May 7, 2025 • 0 new comments -
Add docs for multi-process run in Kubernetes
#28317 commented on
May 12, 2025 • 0 new comments -
Fix overloaded type signature for jax.numpy.where.
#28314 commented on
May 7, 2025 • 0 new comments -
Add support for output and input memory space colors in tpu custom calls via CustomCallConfig.
#28290 commented on
May 12, 2025 • 0 new comments -
Clarify that jnp.clip gives preference to maximum values where bounds are incongruent
#28275 commented on
May 7, 2025 • 0 new comments -
Add cudnn paged attention support in JAX cuDNN SDPA API
#28102 commented on
May 13, 2025 • 0 new comments -
#sdy Properly handle token types in JAX and `ManualComputationOp`.
#27897 commented on
May 7, 2025 • 0 new comments -
Move the section on jitting methods from FAQ to Sharp Bits.
#25273 commented on
May 8, 2025 • 0 new comments -
Different checkpoint behaviors in different version of JAX and CUDA
#27748 commented on
May 13, 2025 • 0 new comments -
jax.random.binomial returns float
#28457 commented on
May 11, 2025 • 0 new comments -
Add masked arrays
#8979 commented on
May 10, 2025 • 0 new comments -
Notice a different between jax.image.resize and F.interpolate when using "bicubic"
#15768 commented on
May 10, 2025 • 0 new comments -
Nested `jax.jit` hangs on second call when `jax.Array` stored in static pytree metadata
#28311 commented on
May 9, 2025 • 0 new comments -
Results do not match the reference. This is likely a bug/unexpected loss of precision
#24909 commented on
May 7, 2025 • 0 new comments -
Add support nogil mode in JAX for Python 3.13
#23073 commented on
May 6, 2025 • 0 new comments -
State of jax.scipy.special functions: tested by evaluation or autograd, incorrectness and missing functionality
#27088 commented on
May 6, 2025 • 0 new comments