create test script for lion optimizer #216
base: main
Conversation
Parallel Training of LLMs Mid Term Project for Okechukwu Okeke and Joshua Udobang
@@ -15,9 +15,8 @@
#pragma once

#include <nntile/kernel/accumulate_maxsumexp.hh>
#include <nntile/kernel/add_slice_inplace.hh>
Why were certain includes deleted?
* NNTile is software framework for fast training of big neural networks on
* distributed-memory heterogeneous systems based on StarPU runtime system.
*
* @file include/nntile/kernel/adam_step.hh
This should say lion, not adam_step.
#include <nntile/kernel/lion_step/cuda.hh>
#endif // NNTILE_USE_CUDA

//! @namespace nntile::kernel::adam_step |
Should be lion here as well, not adam_step.
@@ -31,7 +30,6 @@
#include <nntile/kernel/hypot.hh>
#include <nntile/kernel/normalize.hh>
#include <nntile/kernel/prod.hh>
#include <nntile/kernel/prod_inplace.hh>
I doubt that NNTile compiles without these deleted includes.
template<typename T>
void cpu(Index num_iter, Index num_elems, Scalar beta_1, Scalar beta_2,
        Scalar eps, Scalar lr, Scalar weight_decay, const T *grad,
        T *first_moment, T *second_moment, T *p)
There is no second moment in Lion.
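For reference, a minimal NumPy sketch of the published Lion update rule (the function name and signature are illustrative, not the NNTile kernel API); it keeps only a single momentum buffer, with no second moment:

```python
import numpy as np

def lion_step(p, grad, m, lr=1e-4, beta_1=0.9, beta_2=0.99, weight_decay=0.0):
    """One Lion update (reference sketch; not the NNTile kernel signature)."""
    # Interpolate momentum and gradient, keep only the sign of the result.
    update = np.sign(beta_1 * m + (1.0 - beta_1) * grad)
    # Decoupled weight decay: the decay term shrinks the parameter directly.
    p = p - lr * (update + weight_decay * p)
    # Lion keeps a single momentum buffer; there is no second moment.
    m = beta_2 * m + (1.0 - beta_2) * grad
    return p, m

# Tiny usage example with made-up values.
p, grad, m = np.ones(4), np.full(4, 0.5), np.zeros(4)
p, m = lion_step(p, grad, m, lr=1e-4, weight_decay=0.01)
```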
// Apply weight decay if needed
if (wd != Y(0))
{
    grad_val += wd * p_val;
Weight decay is applied incorrectly here: in Lion it is decoupled and enters the parameter update, it is not added to the gradient.
const Y beta_2 = static_cast<Y>(beta_2_);
const Y lambda_val = static_cast<Y>(lambda_);
const Y lr = static_cast<Y>(lr_);
const Y wd = static_cast<Y>(weight_decay_);
Parameter lambda is the weight decay; there is no need to pass it twice. It looks like the weight-decay logic was not understood correctly, which is why there are two of them.
""" | ||
# Optionally apply weight decay (L2 regularization) by modifying the gradient. | ||
if weight_decay != 0: | ||
grad = grad + weight_decay * p |
This weight decay is incorrect; Lion uses decoupled weight decay.
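A hedged sketch of the decoupled form in this Python wrapper (the variable values below are made up for illustration; this is not the actual NNTile frontend code):

```python
import numpy as np

# Example values standing in for the wrapper's tensors (illustrative only).
p, grad, m = np.ones(3), np.full(3, 0.5), np.zeros(3)
lr, beta_1, beta_2, weight_decay = 1e-4, 0.9, 0.99, 0.01

# Incorrect (coupled L2): folding the decay into the gradient changes the
# sign-based update, since sign(grad + weight_decay * p) != sign(grad) in general.
# grad = grad + weight_decay * p

# Decoupled decay, as in the published Lion rule: grad stays untouched and the
# decay term enters the parameter update directly.
update = np.sign(beta_1 * m + (1.0 - beta_1) * grad)
p = p - lr * (update + weight_decay * p)
m = beta_2 * m + (1.0 - beta_2) * grad
```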
@@ -76,14 +76,6 @@ def gemm_async(
    core_tensor.gemm_async_fp32_fast_tf32(
        alpha, trans_A, A, trans_B, B, beta, C, ndim, batch_ndim, redux
    )
elif type(A) is core_tensor.Tensor_fp32_fast_fp16:
Lots of deleted lines in this file mean only one thing: you messed up the repository versions.
Lots of deleted lines mean that you have mixed up the git repository revisions.