Add Llama model #111
Conversation
This file seems to be still in progress. What are the next steps to be done?
redux=0)
self.x.wont_use()
self.y.grad.wont_use()
self.w.grad.wont_use()

def to_torch(self):
Added #112 to add tests for new functions
wrappers/python/nntile/layer/prod.py (outdated)
self.y.grad.wont_use()
self.res.grad.wont_use()

def unregister(self):
This unregister is unnecessary. Tensor self.res is neither a parameter of the layer nor a temporary tensor, but an activation. Activations are cleared by the base model, not by the layer itself.
Since this unregister was just copied from the add layer, it should be removed from add as well.
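A minimal sketch of what removing it would leave behind, assuming the layer owns no parameters or temporary tensors of its own (the body below is illustrative, not the actual NNTile code):

def unregister(self):
    # self.res is an activation, so its lifetime is managed by the base
    # model; the layer must not unregister it here. If the layer owns no
    # parameters or temporaries either, nothing is left to free.
    pass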
@@ -171,3 +175,30 @@ def backward_async(self):
self.tmp_y_value.invalidate_submit()
# dX can be offloaded from GPU
self.x.grad.wont_use()

@staticmethod
def from_torch(torch_rmsnorm, x: TensorMoments,
Add #113 to test new functionality
If this file tests LlamaMLP, then it shall be deleted, because test_llama_mlp.py is already implemented.
return torch_model, nntile_model, x_torch, pos_ids, y_grad_torch


@pytest.mark.parametrize("params", TEST_PARAMS)
Added #114 to simplify testing parameters
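As a hedged illustration of what simplified parameters might look like (the case names and values below are invented for this sketch, not taken from the PR), one option is to group related values in a small dataclass and give each case an id:

from dataclasses import dataclass

import pytest


@dataclass
class LlamaTestCase:
    # Illustrative fields only; the real tests define their own parameters.
    hidden_size: int
    n_head: int
    seq_len: int


TEST_PARAMS = [
    pytest.param(LlamaTestCase(hidden_size=128, n_head=4, seq_len=32), id="small"),
    pytest.param(LlamaTestCase(hidden_size=256, n_head=8, seq_len=64), id="medium"),
]


@pytest.mark.parametrize("params", TEST_PARAMS)
def test_case_is_well_formed(params):
    # The real test would build the torch and nntile models from `params`
    # and compare their outputs; this sketch only checks the case itself.
    assert params.hidden_size % params.n_head == 0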
import nntile
from nntile.model.llama_config import LlamaConfigNNTile
from nntile.model.llama_decoder import LlamaDecoder as LlamaDecoder_nntile
# from nntile.model.llama import LlamaConfigNNTile
Remove this commented-out import.
return torch_layer, nntile_layer, x_torch, y_grad_torch, pos_ids, mask


@pytest.mark.parametrize("params", TEST_PARAMS)
#115 is to simplify parameters of this test
return torch_layer, nntile_layer, x_torch, y_grad_torch


@pytest.mark.parametrize("params", TEST_PARAMS)
#116 is to simplify parameters of this test
from typing import Dict


class LlamaConfigNNTile(Dict):
Using dict in favor of dataclass does not make much sense if all attributes are known a priori:

from dataclasses import asdict, dataclass

@dataclass
class Config: ...

config = Config(...)
value = asdict(config)['key']['subkey']['subsubkey']
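As a hedged sketch of how that could look for the config here (the field names below are invented for illustration and are not the actual LlamaConfigNNTile attributes):

from dataclasses import asdict, dataclass


@dataclass
class LlamaConfigNNTile:
    # Illustrative fields only; the real config defines its own attributes.
    hidden_size: int
    num_attention_heads: int
    rms_norm_eps: float = 1e-6


config = LlamaConfigNNTile(hidden_size=4096, num_attention_heads=32)
# Dict-style access stays available via asdict when it is really needed:
assert asdict(config)["hidden_size"] == 4096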
): # -> Self: does not work with Python 3.10
layer, _ = __class__.generate_simple(
layer, next_tag = __class__.generate_simple(
This is exactly the use case for classmethod:

class Attention:
    @classmethod
    def from_torch(cls, ...):
        layer, next_tag = cls.generate_simple(...)
        ...
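Beyond readability, cls is bound to whatever class from_torch is actually invoked on, so a subclass that overrides generate_simple would be picked up automatically, whereas __class__ always refers to the defining class.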
): # -> Self: does not work with Python 3.10
layer, _ = __class__.generate_simple(
layer, next_tag = __class__.generate_simple(
This is exactly the use case for classmethod:

class Attention:
    @classmethod
    def from_torch(cls, ...):
        layer, next_tag = cls.generate_simple(...)
        ...
): # -> Self: does not work with Python 3.10
layer, _ = __class__.generate_simple(
layer, next_tag = __class__.generate_simple(
This is exactly the use case for classmethod:

class Attention:
    @classmethod
    def from_torch(cls, ...):
        layer, next_tag = cls.generate_simple(...)
        ...