[proposal] GSoC Project 8: JAX Runtime for V2 by mahdikhashan · Pull Request #2643 · kubeflow/trainer · GitHub

[proposal] GSoC Project 8: JAX Runtime for V2 #2643

Open
wants to merge 4 commits into
base: master

mahdikhashan
Member

What this PR does / why we need it:

Which issue(s) this PR fixes (optional, in Fixes #<issue number>, #<issue number>, ... format, will close the issue(s) when PR gets merged):
Fixes #

Checklist:

  • Docs included if any changes are user facing

@coveralls
coveralls commented May 10, 2025

Pull Request Test Coverage Report for Build 15828890426

Details

  • 0 of 0 changed or added relevant lines in 0 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage remained the same at 29.19%

Totals Coverage Status
Change from base Build 15579727901: 0.0%
Covered Lines: 897
Relevant Lines: 3073

💛 - Coveralls

@google-oss-prow google-oss-prow bot added size/L and removed size/XS labels May 10, 2025
@mahdikhashan mahdikhashan changed the title [proposal] GSoC Project 8: Jax Runtime for V2 [proposal] GSoC Project 8: JAX Runtime for V2 May 10, 2025
@mahdikhashan mahdikhashan changed the title [proposal] GSoC Project 8: JAX Runtime for V2 (draft)[proposal] GSoC Project 8: JAX Runtime for V2 May 28, 2025
@mahdikhashan mahdikhashan force-pushed the gsoc-2442-jax-runtime-proposal branch from c39fc67 to 271a5d1 Compare May 28, 2025 16:53
@mahdikhashan mahdikhashan marked this pull request as draft May 28, 2025 17:01

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign gaocegege for approval. For more information see the Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@google-oss-prow google-oss-prow bot added size/L and removed size/M labels Jun 20, 2025
@mahdikhashan mahdikhashan changed the title (draft)[proposal] GSoC Project 8: JAX Runtime for V2 [proposal] GSoC Project 8: JAX Runtime for V2 Jun 20, 2025
@mahdikhashan mahdikhashan marked this pull request as ready for review June 20, 2025 17:51
add title

remove title comment

add summary, goal and non-goal

Signed-off-by: mahdikhashan <mahdikhashan1@gmail.com>

clean up

update motivation and story 1

update motivation

add table of content, clean up and initial draft date

remove extra comment

update story 1

add second user story

add 3rd story

wip

rename proposal folder

explicitly mention non-goal of implementation for TPUs

add `JAX_CPU_COLLECTIVES_IMPLEMENTATION` to the env table

update story 3

update description for JAX_CPU_COLLECTIVES_IMPLEMENTATION

update goal and non-goal

mention multi-controller

improve the motivation

wip: design details

update title

remove redundant text

improve kep

improve

improve

update

update resource

update

update
@mahdikhashan mahdikhashan force-pushed the gsoc-2442-jax-runtime-proposal branch from b137639 to cadaf06 Compare June 20, 2025 17:53
@mahdikhashan
Member Author

@Electronic-Waste @andreyvelich please take a look when you have time.

@mahdikhashan
Member Author
mahdikhashan commented Jun 20, 2025

Regarding the data loading question I raised in our sync meetings: after going over the JAX documentation, I agree with you that it will be handled by the user inside the objective function. It comes down to sharding data in memory, so no further mechanism is needed for persistent data (PVC, PV). Users can still add persistent storage at any time with a patch.

cc @Electronic-Waste
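
The in-memory sharding described in this comment could be sketched roughly as follows. This is a minimal illustration, not code from the proposal: `local_shard` is a hypothetical helper, and in a real job the process index and count would presumably come from `jax.process_index()` and `jax.process_count()` inside the user's objective function.

```python
def local_shard(dataset, process_index, process_count):
    """Return the contiguous slice of `dataset` owned by one process.

    Hypothetical helper: earlier processes receive one extra item when the
    dataset size is not divisible by the number of processes.
    """
    if not 0 <= process_index < process_count:
        raise ValueError("process_index must be in [0, process_count)")
    per_process, remainder = divmod(len(dataset), process_count)
    start = process_index * per_process + min(process_index, remainder)
    stop = start + per_process + (1 if process_index < remainder else 0)
    return dataset[start:stop]


# Example: 10 samples split across 3 processes.
data = list(range(10))
shards = [local_shard(data, i, 3) for i in range(3)]
print(shards)
```

Because every process computes its own slice from its index alone, no shared volume or coordination step is needed for the split itself.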

Member
@Electronic-Waste Electronic-Waste left a comment

@mahdikhashan Thanks for this great work. I've left my initial comments for you.

PTAL when you have time.

/cc @kubeflow/wg-training-leads @astefanutti @franciscojavierarceo

## Summary

This proposal implements key components of KEP-2170, introducing the Kubeflow Training V2 API. Specifically, it focuses on creating the TrainingRuntime and ClusterTrainingRuntime for the JAX framework, built upon the Kubernetes JobSet API. These runtimes will serve as blueprints for model training (including LLMs) within cloud-native ML pipelines. This abstraction allows Data Scientists and MLOps Engineers to easily reuse standardized runtimes and launch training jobs, particularly via the SDK, without needing deep knowledge of underlying Kubernetes complexities.
Member

You can write the summary section in these formats:

  1. "This document outlines a proposal to support Jax Runtime in Kubeflow Trainer V2"
  2. Then describe its necessity and benefits in a few sentences, like "Built upon Kubernetes JobSet API, the Jax runtime ...."
  3. Other things you want to write
  4. Finally, mention the main framework you use, like "Thanks to the Kubeflow Trainer Pipeline Framework, we can seamlessly support Jax Runtime in Kubeflow Trainer as a runtime plugin"

## Motivation

JAX is a powerful computation library created by Google. It is widely used in machine learning research and ranks as the third most widely used deep learning framework. Beyond deep learning, JAX also shows potential in differentiable programming, large-scale physics simulations, and more. Combined with the new Runtime API for distributed training or computation of objectives, these use cases open Kubeflow Trainer to new users, for example for distributed simulation or for training LLM prototypes developed with JAX, such as the large models from Google DeepMind.
Member

You can list the benefits with:

1.
2.
3.

That might be clearer.


- No TPU support (due to the lack of available TPU testing infrastructure)
- No GPU testing; tests will use CPUs
- No custom MLPolicy; because the runtime uses OpenMPI, it can handle the required parameters
Member

Could you please share more reference material with us? For example, alternative designs to OpenMPI.

Or you could list these different design plans and show their pros/cons, so that we can compare them and choose the best one.

Member Author

In the design details section, I have compared the two existing backend options, OpenMPI and Gloo.

#### Choosing OpenMPI over Gloo

- OpenMPI is **10–20× faster than Gloo** for distributed JAX training on CPUs and GPUs, thanks to optimized communication and bandwidth usage.  
- Better multi-node scaling with lower latency and support for high-speed interconnects like InfiniBand.  
- Compatible with existing **MPI runtime in Kubeflow Trainer v2**, making deployment easier.

#### Defining JAX Processes with MLPolicy

- The number of JAX processes can be defined using the `numNodes` field within the `mlPolicy` section of the ClusterTrainingRuntime configuration. This specifies how many JAX processes (controllers) run in the distributed setup.
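
To make the `numNodes` setting concrete, here is a minimal sketch of such a runtime expressed as a Python dictionary. It is an illustration under assumptions, not the manifest from the proposal: the runtime name `jax-distributed` and the helper `build_jax_runtime` are hypothetical, the `mpi` fields are assumed to mirror the existing MPI runtime in Kubeflow Trainer v2, and the JobSet template is left out.

```python
def build_jax_runtime(name: str, num_nodes: int) -> dict:
    """Build a hypothetical ClusterTrainingRuntime manifest for JAX."""
    if num_nodes < 1:
        raise ValueError("num_nodes must be at least 1")
    return {
        "apiVersion": "trainer.kubeflow.org/v1alpha1",
        "kind": "ClusterTrainingRuntime",
        "metadata": {"name": name},
        "spec": {
            "mlPolicy": {
                # One JAX process (controller) per node in this sketch.
                "numNodes": num_nodes,
                # Assumed to match the fields of the existing MPI policy.
                "mpi": {"mpiImplementation": "OpenMPI", "numProcPerNode": 1},
            },
            # The JobSet template (containers, images) is omitted here.
        },
    }


runtime = build_jax_runtime("jax-distributed", 2)
print(runtime["spec"]["mlPolicy"]["numNodes"])
```

Dumping the dictionary with a YAML or JSON serializer would give a file that could be compared against the existing runtime manifests.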

Member

I mean, maybe you could provide a detailed comparison of these two backends, OpenMPI and Gloo.

Like:

### Communication Backend

#### OpenMPI

Pros & Cons

#### Gloo

Pros & Cons

#### <Other backend>

Member Author

Oh, now I see. I'll do so.

Comment on lines +45 to +57
### User Stories

#### Story 1

As an MLOps Engineer or Platform Engineer, I want to manage JAX distributed training jobs using Kubeflow Trainer V2, so that I can provide engineering teams with blueprints for training machine learning models on a Kubernetes cluster.

#### Story 2

As a Data Scientist, I want to use the Trainer V2 SDK to run a distributed training job from a notebook, so that I can use multiple devices for my training task.

#### Story 3

As a Research Scientist, I want to train a prototype of my new LLM, written in JAX, in a distributed training setup on my company's Kubernetes cluster. Kubeflow Trainer V2 with a JAX ClusterTrainingRuntime will enable this for me.
Member

We should provide detailed YAML files or Python code for these user stories.

Like: https://github.com/kubeflow/trainer/tree/master/docs/proposals/2170-kubeflow-trainer-v2#example-of-trainjob
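
As one illustration of the kind of example requested here, the TrainJob for Story 3 could be sketched as a Python dictionary. This is only a sketch: the job name `jax-llm-prototype`, the runtime name `jax-distributed`, and the helper `build_trainjob` are hypothetical, and the field layout is assumed to follow the TrainJob example in KEP-2170.

```python
import json


def build_trainjob(name: str, runtime_name: str, num_nodes: int) -> dict:
    """Build a hypothetical TrainJob manifest that references a JAX runtime."""
    return {
        "apiVersion": "trainer.kubeflow.org/v1alpha1",
        "kind": "TrainJob",
        "metadata": {"name": name},
        "spec": {
            # Points at the blueprint provided by the platform team (Story 1).
            "runtimeRef": {"name": runtime_name},
            # Overrides the number of nodes for this particular job.
            "trainer": {"numNodes": num_nodes},
        },
    }


trainjob = build_trainjob("jax-llm-prototype", "jax-distributed", 4)
print(json.dumps(trainjob, indent=2))
```

The split mirrors the user stories: the platform engineer owns the runtime, while the researcher only supplies the job name, the runtime reference, and the node count.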

- No mechanism for finding processes, since with the OpenMPI backend JAX can find all processes automatically
- Complex end-to-end examples demonstrating the runtimes (focus is on the runtime implementation itself; examples may require specific infrastructure)
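
The automatic process discovery mentioned in the first item could look roughly like the sketch below. It assumes the training script is started by Open MPI's `mpirun`, which sets `OMPI_COMM_WORLD_RANK` and `OMPI_COMM_WORLD_SIZE`. The helper names are hypothetical, and the JAX call is kept in a separate function so the environment check can be run without JAX installed.

```python
import os


def ompi_process_info(env=None):
    """Return (rank, world_size) from Open MPI's variables, or None if unset."""
    env = os.environ if env is None else env
    rank = env.get("OMPI_COMM_WORLD_RANK")
    size = env.get("OMPI_COMM_WORLD_SIZE")
    if rank is None or size is None:
        return None
    return int(rank), int(size)


def init_jax_distributed():
    """Initialize JAX without explicit coordinator or process arguments.

    With no arguments, jax.distributed.initialize() tries to detect the
    cluster environment itself; this sketch relies on that for Open MPI.
    """
    import jax  # imported lazily so the helper above works without JAX

    jax.distributed.initialize()
    return jax.process_index(), jax.process_count()


# Outside an mpirun launch, the variables are absent and this prints None.
print(ompi_process_info({}))
```

Inside the training container, the user code would call `init_jax_distributed()` once at startup and use the returned index and count for any per-process work.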

## Proposal
Member

FYR, we can describe the high-level design ideas in the Proposal section.

@google-oss-prow google-oss-prow bot requested review from franciscojavierarceo and a team June 23, 2025 10:14
mahdikhashan and others added 2 commits June 23, 2025 17:45
Co-authored-by: Shao Wang <2690692950@qq.com>
Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com>
Co-authored-by: Shao Wang <2690692950@qq.com>
Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com>
Co-authored-by: Shao Wang <2690692950@qq.com>
Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com>
3 participants