Additional bindings logic by shabani1 · Pull Request #12 · lexy-ai/lexy · GitHub

Additional bindings logic #12

Merged: 2 commits, Oct 18, 2023

12 changes: 9 additions & 3 deletions lexy/api/endpoints/bindings.py
@@ -8,6 +8,7 @@
     TransformerIndexBindingCreate,
     TransformerIndexBindingUpdate
 )
+from lexy.core.events import process_new_binding


 router = APIRouter()
@@ -28,13 +29,18 @@ async def get_bindings(session: AsyncSession = Depends(get_session)) -> list[Tra
              status_code=status.HTTP_201_CREATED,
              name="add_binding",
              description="Create a new binding")
-async def add_binding(binding: TransformerIndexBindingCreate,
-                      session: AsyncSession = Depends(get_session)) -> TransformerIndexBinding:
+async def add_binding(binding: TransformerIndexBindingCreate, session: AsyncSession = Depends(get_session)) -> dict:
     binding = TransformerIndexBinding(**binding.dict())
     session.add(binding)
     await session.commit()
     await session.refresh(binding)
-    return binding
+    processed_binding, tasks = await process_new_binding(binding)
+    # now commit the binding again and refresh it - status should be updated
+    session.add(processed_binding)
+    await session.commit()
+    await session.refresh(processed_binding)
+    response = {"binding": processed_binding, "tasks": tasks}
+    return response


 @router.get("/bindings/{binding_id}",
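
For reference, a minimal sketch (not part of this PR) of how a client sees the new response shape; the base URL, route, and payload field names below are assumptions for illustration:

    import httpx  # any HTTP client would do

    payload = {
        "collection_id": "default",                     # hypothetical field names
        "transformer_id": "text.counter.word_counter",
        "index_id": "default_text_embeddings",
    }
    resp = httpx.post("http://localhost:8000/bindings", json=payload)  # assumed URL
    data = resp.json()
    # add_binding now returns a dict rather than the bare binding:
    data["binding"]["status"]  # 'on' once process_new_binding has run
    data["tasks"]              # [{"task_id": "...", "document_id": "..."}, ...]
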
18 changes: 18 additions & 0 deletions lexy/core/celery_tasks.py
@@ -10,6 +10,9 @@
 - https://docs.celeryq.dev/en/stable/userguide/configuration.html#override-backends
 """

+from typing import Any
+from uuid import UUID
+
 from celery import current_app as celery, Task
 from celery.utils.log import get_logger, get_task_logger
 from sqlalchemy.orm import sessionmaker
@@ -68,6 +71,7 @@ def on_failure(self, exc, task_id, args, kwargs, einfo):

 @celery.task(base=DatabaseTask, bind=True, name="lexy.db.save_result_to_index")
 def save_result_to_index(self, res, document_id, text, index_id):
+    """ Save the result of a transformer to an index. """
     task_logger.debug(f"Starting DB task 'save_result_to_index' for index {index_id} "
                       f"with task ID {self.request.id} and parent task ID {self.request.parent_id}")
     # noinspection PyPep8Naming
@@ -77,3 +81,17 @@ def save_result_to_index(self, res, document_id, text, index_id):
         IndexClass(document_id=document_id, embedding=res.tolist(), text=text, task_id=self.request.parent_id)
     )
     self.db.commit()
+
+
+@celery.task(base=DatabaseTask, bind=True, name="lexy.db.save_records_to_index")
+def save_records_to_index(self, records: list[dict[str, Any]], document_id: UUID, text: str, index_id: str):
+    """ Save the output of a transformer to an index. """
+    task_logger.debug(f"Starting DB task 'save_records_to_index' for index {index_id} "
+                      f"with task ID {self.request.id} and parent task ID {self.request.parent_id}")
+    # noinspection PyPep8Naming
+    IndexClass = index_manager.index_models[index_id]
+    for record in records:
+        self.db.add(
+            IndexClass(document_id=document_id, text=text, task_id=self.request.parent_id, **record)
+        )
+    self.db.commit()
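
A sketch of what the new task consumes, assuming the word_counter transformer added in this PR; each record's keys must match columns on the dynamically created index model, since the record is unpacked as keyword arguments:

    # hypothetical input, as produced by a transformer like word_counter
    records = [{"word_count": 4, "longest_word": "quick"}]
    # for each record, save_records_to_index effectively does:
    #   IndexClass(document_id=document_id, text=text,
    #              task_id=self.request.parent_id,
    #              word_count=4, longest_word="quick")
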
36 changes: 23 additions & 13 deletions lexy/core/events.py
@@ -2,7 +2,7 @@
 import logging

 from lexy.models import Document, TransformerIndexBinding
-from lexy.core.celery_tasks import save_result_to_index
+from lexy.core.celery_tasks import save_result_to_index, save_records_to_index
 from lexy.indexes import index_manager


@@ -60,6 +60,7 @@ async def generate_tasks_for_document(doc: Document) -> list[dict]:
     return tasks


+# TODO: this should move to lexy.indexes.IndexManager
 def create_new_index_table(index_id: str):
     """ Create a new index model and its associated table (requires that the index row already exists in the database).

@@ -80,10 +81,10 @@ def create_new_index_table(index_id: str):
     index_model = index_manager.create_index_model(index)
     if index_manager.table_exists(index.index_table_name):
         logger.warning(f"Index table '{index.index_table_name}' already exists")
-    index_model.metadata.create_all(index_manager.sync_engine)
+    index_model.metadata.create_all(index_manager.db.bind.engine)


-def process_new_binding(binding: TransformerIndexBinding, create_index_table: bool = False) \
+async def process_new_binding(binding: TransformerIndexBinding, create_index_table: bool = False) \
         -> tuple[TransformerIndexBinding, list[dict]]:
     """ Process a new transformer index binding.

@@ -102,6 +103,19 @@ def process_new_binding(binding: TransformerIndexBinding, create_index_table: bo
     """
     logger.info(f"Processing new transformer index binding {binding}")

+    # check if binding has a valid transformer
+    if binding.transformer is None:
+        raise Exception(f"Binding {binding} does not have a transformer associated with it")
+    if binding.transformer.path is None:
+        raise Exception(f"Binding {binding} does not have a valid transformer path associated with it")
+
+    # import the transformer function
+    # TODO: just import the function from celery?
+    tfr_mod_name, tfr_func_name = binding.transformer.path.rsplit('.', 1)
+    tfr_module = importlib.import_module(tfr_mod_name)
+    transformer_func = getattr(tfr_module, tfr_func_name)
+
+    # check if binding has a valid index
     if binding.index is None:
         logger.info(f"Binding {binding} does not have an index associated with it")
         # create index table
@@ -118,12 +132,6 @@ def process_new_binding(binding: TransformerIndexBinding, create_index_table: bo
     else:
         documents = binding.collection.documents

-    # import the transformer function
-    # TODO: just import the function from celery?
-    tfr_mod_name, tfr_func_name = binding.transformer.path.rsplit('.', 1)
-    tfr_module = importlib.import_module(tfr_mod_name)
-    transformer_func = getattr(tfr_module, tfr_func_name)
-
     # initiate list of tasks
     tasks = []

@@ -133,14 +141,16 @@ def process_new_binding(binding: TransformerIndexBinding, create_index_table: bo
         task = transformer_func.apply_async(
             args=[doc.content],
             kwargs=binding.transformer_params,
-            link=save_result_to_index.s(document_id=doc.document_id,
-                                        text=doc.content,
-                                        index_id=binding.index_id)
+            link=save_records_to_index.s(document_id=doc.document_id,
+                                         text=doc.content,
+                                         index_id=binding.index_id)
         )
         tasks.append({"task_id": task.id, "document_id": doc.document_id})

     # switch binding status to 'on'
-    index_manager.switch_binding_status(binding, 'on')
+    prev_status = binding.status
+    binding.status = 'on'
+    logger.info(f"Set status for binding {binding}: from '{prev_status}' to 'on'")

     logger.info(f"Created {len(tasks)} tasks for binding {binding}: "
                 f"[{', '.join([t['task_id'] for t in tasks])}]")
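
The validation and dynamic import now run before any index handling, so a binding with a missing or bad transformer fails fast. The import follows the standard dotted-path pattern; as a standalone sketch:

    import importlib

    path = "lexy.transformers.counter.word_counter"  # binding.transformer.path
    mod_name, func_name = path.rsplit('.', 1)        # ("lexy.transformers.counter", "word_counter")
    module = importlib.import_module(mod_name)
    transformer_func = getattr(module, func_name)    # a @shared_task, so .apply_async() works
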
1 change: 1 addition & 0 deletions lexy/db/init_db.py
@@ -40,6 +40,7 @@ def add_sample_data_to_db(session=db):
         logger.info("Transformer data already exists")
     else:
         session.add(models.Transformer(**sample_data["transformer_1"]))
+        session.add(models.Transformer(**sample_data["transformer_2"]))
         session.commit()
     if session.query(models.Index).count() > 0:
         logger.info("Index data already exists")

5 changes: 5 additions & 0 deletions lexy/db/sample_data.py
@@ -58,6 +58,11 @@
         "path": "lexy.transformers.embeddings.text_embeddings",
         "description": "Text embeddings using Hugging Face model 'sentence-transformers/all-MiniLM-L6-v2'"
     },
+    "transformer_2": {
+        "transformer_id": "text.counter.word_counter",
+        "path": "lexy.transformers.counter.word_counter",
+        "description": "Returns count of words and the longest word"
+    },
     "index_1": {
         "index_id": "default_text_embeddings",
         "description": "Text embeddings for default collection",

1 change: 1 addition & 0 deletions lexy/db/seed.py
@@ -20,6 +20,7 @@ async def add_sample_data_to_db(session: AsyncSession):
     session.add(Document(**sample_data["document_6"]))
     await session.commit()
     session.add(Transformer(**sample_data["transformer_1"]))
+    session.add(Transformer(**sample_data["transformer_2"]))
     await session.commit()
     session.add(Index(**sample_data["index_1"]))
     await session.commit()

2 changes: 1 addition & 1 deletion lexy/indexes.py
@@ -188,7 +188,7 @@ def get_field_definitions(index_fields: dict) -> dict:
             fd_info = Field(sa_column=Column(ARRAY(REAL)))
         else:
             fd_type = LEXY_INDEX_FIELD_TYPES.get(fv['type'])
-            if fv['optional'] is True:
+            if 'optional' in fv and fv['optional'] is True:
                 fd_type = Optional[fd_type]

             if fv['type'] in ['dict', 'object', 'list', 'array']:
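
The added 'optional' in fv guard lets a field definition omit the key entirely instead of raising a KeyError. A sketch with hypothetical field definitions (the type names here are assumptions about LEXY_INDEX_FIELD_TYPES):

    index_fields = {
        "word_count": {"type": "int"},                         # no 'optional' key: stays required
        "longest_word": {"type": "string", "optional": True},  # wrapped as Optional[...]
    }
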
3 changes: 2 additions & 1 deletion lexy/models/binding.py
@@ -43,7 +43,8 @@ class TransformerIndexBinding(TransformerIndexBindingBase, table=True):
         sa_column=Column(DateTime(timezone=True), server_default=func.now(), >
     )
     status: str = Field(default=BindingStatus.PENDING, nullable=False)
-    collection: Collection = Relationship(back_populates="transformer_index_bindings")
+    collection: Collection = Relationship(back_populates="transformer_index_bindings",
+                                          sa_relationship_kwargs={'lazy': 'selectin'})
     transformer: Transformer = Relationship(back_populates="index_bindings",
                                             sa_relationship_kwargs={'lazy': 'selectin'})
     index: Index = Relationship(back_populates="transformer_bindings", sa_relationship_kwargs={'lazy': 'selectin'})

2 changes: 1 addition & 1 deletion lexy/models/collection.py
@@ -26,7 +26,7 @@ class Collection(CollectionBase, table=True):
         nullable=False,
         sa_column=Column(DateTime(timezone=True), server_default=func.now(), >
     )
-    documents: list["Document"] = Relationship(back_populates="collection")
+    documents: list["Document"] = Relationship(back_populates="collection", sa_relationship_kwargs={'lazy': 'subquery'})
     transformer_index_bindings: list["TransformerIndexBinding"] = \
         Relationship(back_populates="collection", sa_relationship_kwargs={'lazy': 'selectin'})
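
The eager-loading kwargs on both models are likely what keeps the new async path working: process_new_binding reads binding.collection.documents after the endpoint's await session.refresh(...), and SQLAlchemy's default lazy loading would trigger blocking IO at that attribute access inside an async session. With 'selectin'/'subquery', the related rows arrive with the parent:

    documents = binding.collection.documents  # already loaded; no lazy query fired
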
31 changes: 0 additions & 31 deletions lexy/models/document.py
@@ -7,7 +7,6 @@
 from sqlmodel import Field, SQLModel, Relationship

 from lexy.models.collection import Collection
-# from lexy.core.celery_tasks import save_result_to_index


 class DocumentBase(SQLModel):
@@ -36,36 +35,6 @@ class Document(DocumentBase, table=True):
     collection: Collection = Relationship(back_populates="documents", sa_relationship_kwargs={'lazy': 'selectin'})
     embeddings: list["Embedding"] = Relationship(back_populates="document")

-    # def apply_async_bindings(self):
-    #     # initiate list of tasks
-    #     tasks = []
-    #     # loop through transformer bindings for this document
-    #     for binding in self.collection.transformer_index_bindings:
-    #         # check if binding is enabled
-    #         if binding.status != 'on':
-    #             print(
-    #                 f"Skipping transformer index binding {binding} because it is not enabled (status: {binding.status})")
-    #             continue
-    #         # check if document matches binding filters
-    #         if binding.filters and not all(f(self) for f in binding.filters):
-    #             print(f"Skipping transformer index binding {binding} because document does not match filters")
-    #             continue
-    #         # import the transformer function
-    #         # TODO: just import the function from celery?
-    #         tfr_mod_name, tfr_func_name = binding.transformer.path.rsplit('.', 1)
-    #         tfr_module = importlib.import_module(tfr_mod_name)
-    #         transformer_func = getattr(tfr_module, tfr_func_name)
-    #         # generate the task
-    #         task = transformer_func.apply_async(
-    #             args=[self.content],
-    #             kwargs=binding.transformer_params,
-    #             link=save_result_to_index.s(document_id=self.document_id,
-    #                                         text=self.content,
-    #                                         index_id=binding.index_id)
-    #         )
-    #         tasks.append({"task_id": task.id, "document_id": self.document_id})
-    #     return tasks
-

 class DocumentCreate(DocumentBase):
     pass

12 changes: 12 additions & 0 deletions lexy/transformers/counter.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from celery import shared_task


@@ -24,3 +26,13 @@ def count_words(text: str) -> dict[str, int]:
         else:
             counts[word] = 1
     return counts
+
+
+@shared_task(name="lexy.transformers.counter.word_counter")
+def word_counter(text: str) -> list[dict[str, Any]]:
+    """ Count the words in a text and return the count and the longest word. """
+    words = text.split()
+    word_count = len(words)
+    longest_word = max(words, key=len)
+    return [{"word_count": word_count, "longest_word": longest_word}]
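
Since Celery tasks remain plain callables, the new transformer can be exercised directly (a quick sketch):

    result = word_counter("the quick brown fox")
    # -> [{"word_count": 4, "longest_word": "quick"}]
    # note: empty input would raise ValueError from max(); documents are
    # assumed to have non-empty content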

20 changes: 17 additions & 3 deletions lexy/transformers/embeddings.py
@@ -1,10 +1,9 @@
-from typing import Tuple
-
 from celery import shared_task

 import torch
+from numpy import ndarray
 from sentence_transformers import SentenceTransformer
-
+from torch import Tensor

 torch.set_num_threads(1)
 model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
@@ -24,6 +23,21 @@ def text_embeddings(sentences: list[str]) -> torch.Tensor:
     return model.encode(sentences, batch_size=len(sentences))


+@shared_task(name="lexy.transformers.embeddings.text_embeddings_transformer")
+def text_embeddings_transformer(sentences: list[str]) -> list[dict[str, list[Tensor] | ndarray | Tensor]]:
+    """ Embed sentences using SentenceTransformer.
+
+    Args:
+        sentences: list of sentences to embed
+
+    Returns:
+        list[dict]: single-element list containing the embeddings under the 'embedding' key
+
+    """
+    res = {'embedding': model.encode(sentences, batch_size=len(sentences))}
+    return [res]
+
+
 @shared_task(name="lexy.transformers.embeddings.get_chunks")
 def get_chunks(text, chunk_size=384) -> list[str]:
     """
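
A sketch of the record-style output, which is what makes this variant compatible with save_records_to_index (the 384 dimension is a property of all-MiniLM-L6-v2):

    out = text_embeddings_transformer(["hello world"])
    # out is a one-element list of dicts:
    #   [{"embedding": <ndarray of shape (1, 384)>}]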