[proposal] GSoC Project 8: JAX Runtime for V2 #2643
Conversation
@Electronic-Waste @andreyvelich please take a look when you have time.
Regarding the data loading I mentioned in our sync meetings: after going over the JAX documentation, as you said, it will be handled within the objective function by the user. The data is sharded in memory, so no further mechanism is needed for persistent data (PVC, PV); however, the user can take care of it at any time with a patch.
@mahdikhashan Thanks for this great work. I've left my initial comments for you.
PTAL when you have time.
/cc @kubeflow/wg-training-leads @astefanutti @franciscojavierarceo
## Summary
This proposal implements key components of KEP-2170, introducing the Kubeflow Training V2 API. Specifically, it focuses on creating the TrainingRuntime and ClusterTrainingRuntime for the JAX framework, built upon the Kubernetes JobSet API. These runtimes will serve as blueprints for model training (including LLMs) within cloud-native ML pipelines. This abstraction allows Data Scientists and MLOps Engineers to easily reuse standardized runtimes and launch training jobs, particularly via the SDK, without needing deep knowledge of underlying Kubernetes complexities.
You can write the summary section in these formats:
- "This document outlines a proposal to support Jax Runtime in Kubeflow Trainer V2"
- Then describe its necessity and benefits in a few sentences, like "Built upon Kubernetes JobSet API, the Jax runtime ...."
- Other things you want to write
- Finally, mention the main framework you use, like "Thanks to the Kubeflow Trainer Pipeline Framework, we can seamlessly support Jax Runtime in Kubeflow Trainer as a runtime plugin"
## Motivation
JAX is a powerful numerical computation library created by Google. It is widely used in machine learning research and ranks as the third most widely used deep learning framework. JAX is not only a deep learning framework: it also shows potential in differentiable programming, large-scale physics simulations, and more. Combined with the new Runtime API for distributed training or objective computation, these use cases enable new workloads on top of Kubeflow Trainer, such as distributed simulation or training LLM prototypes developed with JAX, like large models from Google DeepMind.
You can list the benefits with:
1.
2.
3.
That might be clearer.
- No TPU support (due to lack of available TPU testing infrastructure)
- No GPU testing; tests will use CPUs
- No custom MLPolicy, since OpenMPI can handle the required parameters
Could you please provide more reference material for us? For example, alternative design plans besides OpenMPI.
Or you could list the different design plans with their pros and cons, so that we can compare them and choose the best one.
In the design details section, I have compared two existing backend options: OpenMPI and Gloo.
#### Choosing OpenMPI over Gloo
- OpenMPI is **10–20× faster than Gloo** for distributed JAX training on CPUs and GPUs, thanks to optimized communication and bandwidth usage.
- Better multi-node scaling with lower latency and support for high-speed interconnects like InfiniBand.
- Compatible with existing **MPI runtime in Kubeflow Trainer v2**, making deployment easier.
#### Defining JAX Processes with MLPolicy
- The number of JAX processes can be defined using the `numNodes` field within the `mlPolicy` section of the ClusterTrainingRuntime configuration. This specifies how many JAX processes/controllers run in the distributed setup.
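As an illustration, a ClusterTrainingRuntime carrying this `mlPolicy` might look roughly like the sketch below. This is a hypothetical example, not the final design: the runtime name, image, and values are assumptions, and field names follow the Trainer V2 API direction from KEP-2170.

```yaml
# Sketch of a JAX ClusterTrainingRuntime (hypothetical names and values)
apiVersion: trainer.kubeflow.org/v1alpha1
kind: ClusterTrainingRuntime
metadata:
  name: jax-distributed          # hypothetical runtime name
spec:
  mlPolicy:
    numNodes: 4                  # number of JAX processes/controllers
    mpi:
      numProcPerNode: 1          # one JAX process per node
      mpiImplementation: OpenMPI # the backend chosen in this proposal
  template:                      # JobSet template backing the runtime
    spec:
      replicatedJobs:
        - name: node
          template:
            spec:
              template:
                spec:
                  containers:
                    - name: node
                      image: example.com/jax-mpi:latest  # hypothetical image
                      env:
                        - name: JAX_CPU_COLLECTIVES_IMPLEMENTATION
                          value: mpi
```

With OpenMPI supplying process discovery, the per-process coordinator configuration would not need a separate mechanism, which matches the non-goals above.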
I mean, maybe you can provide a detailed comparison of these two backends, OpenMPI and Gloo.
Like:
### Communication Backend
#### OpenMPI
Pros & Cons
#### Gloo
Pros & Cons
#### <Other backend>
Oh, now I see. I'll do so.
### User Stories
#### Story 1
As an MLOps Engineer or Platform Engineer, I want to manage JAX distributed training jobs using Kubeflow Trainer V2, so that I can provide engineering teams with blueprints for training machine learning models on a Kubernetes cluster.
#### Story 2
As a Data Scientist, I want to use the Trainer V2 SDK to run a distributed training job from a notebook, so that I can incorporate multiple devices into my training task.
#### Story 3
As a Research Scientist, I want to train a prototype of my new LLM, written in JAX, in a distributed setup on my company's Kubernetes cluster. Kubeflow Trainer V2 with the JAX ClusterTrainingRuntime will enable this for me.
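To make these stories concrete, submitting such a training job declaratively could look roughly like the sketch below. This is an assumption-laden example: the job and runtime names are hypothetical, and field names follow the TrainJob API direction from KEP-2170.

```yaml
# Sketch of a TrainJob referencing a JAX runtime (hypothetical names)
apiVersion: trainer.kubeflow.org/v1alpha1
kind: TrainJob
metadata:
  name: jax-llm-prototype    # hypothetical job name
spec:
  runtimeRef:
    name: jax-distributed    # hypothetical ClusterTrainingRuntime name
  trainer:
    numNodes: 4              # override the runtime's default process count
```

For Story 2, the same job would be created from a notebook via the Trainer V2 SDK rather than by applying YAML directly.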
We'd better provide detailed YAML files or Python code for these user stories.
- No mechanism for process discovery, since with the OpenMPI backend JAX can find all processes automatically
- Complex end-to-end examples demonstrating the runtimes (the focus is on the runtime implementation itself; examples may require specific infrastructure)
## Proposal
FYR, we can describe the high-level design ideas in the Proposal section.
Co-authored-by: Shao Wang <2690692950@qq.com> Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com>