8000 Add Llama model by amkatrutsa · Pull Request #111 · nntile/nntile · GitHub

Merged: 26 commits from amkatrutsa/llama into main on Jul 22, 2024.
Conversation

amkatrutsa (Contributor)
No description provided.

amkatrutsa requested a review from Muxas on July 20, 2024 at 19:47.
amkatrutsa marked this pull request as ready for review on July 21, 2024 at 08:36.
Muxas (Member):

This file still seems to be a work in progress. What are the next steps?

redux=0)
self.x.wont_use()
self.y.grad.wont_use()
self.w.grad.wont_use()

def to_torch(self):
Muxas (Member):

Added #112 to add tests for new functions

self.y.grad.wont_use()
self.res.grad.wont_use()

def unregister(self):
Muxas (Member):

This unregister is unnecessary. The tensor self.res is neither a parameter of the layer nor a temporary tensor, but an activation. Activations are cleared out by the base model, not by the layer itself.
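
To illustrate the convention being described (a rough sketch only; the class and attribute names below are illustrative, not the actual NNTile API): a layer's unregister handles its own parameters and temporaries, while activations such as self.res are walked and cleared by the base model.

# Illustrative sketch of the ownership convention, not the real NNTile classes.
class SomeLayer:
    def unregister(self):
        for p in self.parameters:      # parameters belong to the layer
            p.unregister()
        for t in self.temporaries:     # temporary tensors belong to the layer
            t.unregister()
        # activations (e.g. self.res) are deliberately not touched here


class SomeModel:
    def unregister(self):
        for layer in self.layers:
            layer.unregister()
        for act in self.activations:   # activations are cleared by the model
            act.unregister()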

Muxas (Member):

Since this unregister was simply copied from the add layer, it should be removed from add as well.

@@ -171,3 +175,30 @@ def backward_async(self):
self.tmp_y_value.invalidate_submit()
# dX can be offloaded from GPU
self.x.grad.wont_use()

@staticmethod
def from_torch(torch_rmsnorm, x: TensorMoments,
Muxas (Member):

Added #113 to test the new functionality.
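
As a rough sketch of what such a test could look like (the fixture make_nntile_rmsnorm and its wiring are assumptions, not the actual test added in #113): build a torch LlamaRMSNorm, convert it with from_torch, convert back with to_torch, and compare the weights.

# Sketch only: the helper fixture below is hypothetical.
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm


def test_rmsnorm_to_from_torch(make_nntile_rmsnorm):
    torch_layer = LlamaRMSNorm(hidden_size=64, eps=1e-6)
    # the fixture is assumed to wrap RmsNorm.from_torch and the tensor/tag setup
    nntile_layer = make_nntile_rmsnorm(torch_layer)
    torch_back = nntile_layer.to_torch()
    assert torch.allclose(torch_layer.weight, torch_back.weight)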

Muxas (Member):

If this file tests LlamaMLP, then it shall be deleted because test_llama_mlp.py is already implemented.

return torch_model, nntile_model, x_torch, pos_ids, y_grad_torch


@pytest.mark.parametrize("params", TEST_PARAMS)
Muxas (Member):

Added #114 to simplify testing parameters
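
One common way to simplify such parameters (just a sketch of the general idea, not the change tracked in #114; the field names are made up) is to group them in a small dataclass and parametrize the independent axes separately instead of enumerating one large TEST_PARAMS list:

# Sketch of one possible simplification; names are illustrative.
from dataclasses import dataclass

import pytest


@dataclass
class LlamaTestParams:
    hidden_size: int
    n_head: int
    seq_len: int = 32


@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("n_head", [4, 8])
def test_llama_param_combinations(hidden_size, n_head):
    params = LlamaTestParams(hidden_size=hidden_size, n_head=n_head)
    assert params.hidden_size % params.n_head == 0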

import nntile
from nntile.model.llama_config import LlamaConfigNNTile
from nntile.model.llama_decoder import LlamaDecoder as LlamaDecoder_nntile
# from nntile.model.llama import LlamaConfigNNTile
Muxas (Member):

rm this line

return torch_layer, nntile_layer, x_torch, y_grad_torch, pos_ids, mask


@pytest.mark.parametrize("params", TEST_PARAMS)
Muxas (Member):

#115 is to simplify parameters of this test

return torch_layer, nntile_layer, x_torch, y_grad_torch


@pytest.mark.parametrize("params", TEST_PARAMS)
Muxas (Member):

#116 is to simplify parameters of this test

from typing import Dict


class LlamaConfigNNTile(Dict):
Muxas (Member):

Using a dict instead of a dataclass does not make a lot of sense if all attributes are known a priori.

from dataclasses import asdict, dataclass

@dataclass
class Config: ...

config = Config(...)
value = asdict(config)['key']['subkey']['subsubkey']
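
For example, a dataclass version of the config could look roughly like this (the field set below is guessed from typical Llama configs and is not the final LlamaConfigNNTile):

# Illustrative only: the fields are assumptions, not the final config.
from dataclasses import asdict, dataclass


@dataclass
class LlamaConfigNNTile:
    vocab_size: int = 32000
    hidden_size: int = 4096
    intermediate_size: int = 11008
    num_hidden_layers: int = 32
    rms_norm_eps: float = 1e-6


config = LlamaConfigNNTile(hidden_size=2048)
assert asdict(config)["hidden_size"] == 2048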

): # -> Self: does not work with Python 3.10
layer, _ = __class__.generate_simple(
layer, next_tag = __class__.generate_simple(
Muxas (Member):

This is exactly the use case for classmethod.

class Attention:
    @classmethod
    def from_torch(cls, ...):
        layer, next_tag = cls.generate_simple(...)
        ...
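
A short self-contained sketch of why cls is preferable to __class__ here (class names are illustrative): cls resolves to the subclass the constructor is invoked on, whereas __class__ inside the method body is fixed to the defining class, so a subclass would silently build the wrong type.

# Illustrative sketch: cls follows subclassing, __class__ would not.
class Attention:
    @classmethod
    def generate_simple(cls):
        return cls(), 0  # stand-in for the real (layer, next_tag) pair

    @classmethod
    def from_torch(cls):
        layer, next_tag = cls.generate_simple()
        return layer


class LlamaAttention(Attention):
    pass


# from_torch on the subclass builds a LlamaAttention, not a bare Attention
assert isinstance(LlamaAttention.from_torch(), LlamaAttention)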

): # -> Self: does not work with Python 3.10
layer, _ = __class__.generate_simple(
layer, next_tag = __class__.generate_simple(
Muxas (Member):

This is exactly the use case for classmethod.

class Attention:
    @classmethod
    def from_torch(cls, ...):
        layer, next_tag = cls.generate_simple(...)
        ...

): # -> Self: does not work with Python 3.10
layer, _ = __class__.generate_simple(
layer, next_tag = __class__.generate_simple(
Muxas (Member):

This is exactly the use case for classmethod.

class Attention:
    @classmethod
    def from_torch(cls, ...):
        layer, next_tag = cls.generate_simple(...)
        ...

amkatrutsa merged commit 44209c2 into main on Jul 22, 2024. 5 checks passed.
amkatrutsa deleted the amkatrutsa/llama branch on July 22, 2024 at 12:30.
Muxas mentioned this pull request on Jul 22, 2024.