Added the janky version of openai embeddings by shabani1 · Pull Request #45 · lexy-ai/lexy

Added the janky version of openai embeddings #45


Merged · 1 commit · Jan 15, 2024
16 changes: 16 additions & 0 deletions README.md
@@ -55,6 +55,22 @@ You'll also need to specify an S3 bucket for file storage (for which your AWS cr
You can do so by adding `S3_BUCKET=<name-of-your-S3-bucket>` to your `.env` file, or by updating the value of
`s3_bucket` in `lexy/core/config.py`.

### Using OpenAI transformers

To use OpenAI embeddings in Lexy, you'll need to set the `OPENAI_API_KEY` environment variable. You can do so by adding
the following to your `.env` file:

```shell
OPENAI_API_KEY=<your-openai-api-key>
```

Do this before building your Docker containers. If you've already run `docker-compose up`, rebuild the server and
worker containers with:

```shell
docker-compose up --build -d --no-deps lexyserver lexyworker
```
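As a quick sanity check before rebuilding, you can confirm the key is visible from Python (a minimal sketch, not part of this PR; assumes `python-dotenv` is installed and that you run it from the repo root where `.env` lives):

```python
# Sanity check only -- assumes python-dotenv is installed and .env is in the current directory.
import os

from dotenv import load_dotenv

load_dotenv()  # loads key=value pairs from ./.env into os.environ
if not os.environ.get("OPENAI_API_KEY"):
    raise SystemExit("OPENAI_API_KEY is not set; add it to your .env file")
print("OPENAI_API_KEY is set")
```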

### PyCharm issues

If your virtualenv keeps getting bjorked by PyCharm, make sure that you're following the instructions above verbatim,
5 changes: 4 additions & 1 deletion lexy/api/endpoints/index_records.py
@@ -99,7 +99,10 @@ async def query_records(query_text: str = Form(None),
detail=f"Transformer '{embedding_model}' not found")
task = celery.send_task(transformer.celery_task_name, args=[query], priority=10)
result = task.get()
query_embedding = result.tolist()
if isinstance(result, list):
query_embedding = result
else:
query_embedding = result.tolist()

# get index fields to return
if return_fields:
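For context on the branch added above: the pre-existing code called `.tolist()` unconditionally, which suggests the other embedding transformers return array-like results, while the new OpenAI transformer already returns plain Python lists, so the endpoint now normalizes both. A minimal sketch of the same normalization (the helper name is hypothetical, not part of this PR):

```python
# Illustration only -- the PR inlines this logic in query_records rather than adding a helper.
import numpy as np


def as_embedding_list(result):
    """Normalize a transformer result to plain Python lists of floats."""
    if isinstance(result, list):
        return result          # e.g. the OpenAI transformer already returns lists
    return result.tolist()     # e.g. numpy arrays / tensors from other embedding transformers


print(as_embedding_list([0.1, 0.2]))                # [0.1, 0.2]
print(as_embedding_list(np.array([1.0, 2.0])))      # [1.0, 2.0]
```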
2 changes: 2 additions & 0 deletions lexy/core/config.py
@@ -70,12 +70,14 @@ class GlobalConfig(BaseConfig):
'lexy.transformers.counter',
'lexy.transformers.embeddings',
'lexy.transformers.multimodal',
'lexy.transformers.openai',
}
lexy_worker_transformer_imports = {
# 'lexy.transformers.*'
'lexy.transformers.counter',
'lexy.transformers.embeddings',
'lexy.transformers.multimodal',
'lexy.transformers.openai',
}

@property
5 changes: 5 additions & 0 deletions lexy/db/sample_data.py
@@ -25,6 +25,11 @@
"path": "lexy.transformers.multimodal.image_embeddings_clip",
"description": "Image embeddings using \"openai/clip-vit-base-patch32\""
},
{
"transformer_id": "text.embeddings.openai-ada-002",
"path": "lexy.transformers.openai.text_embeddings_openai",
"description": "Text embeddings using OpenAI's \"text-embedding-ada-002\" model"
},
{
"transformer_id": "text.counter.word_counter",
"path": "lexy.transformers.counter.word_counter",
2 changes: 1 addition & 1 deletion lexy/models/transformer.py
@@ -19,7 +19,7 @@ class TransformerBase(SQLModel):
primary_key=True,
min_length=1,
max_length=255,
regex=r"^[a-zA-Z][a-zA-Z0-9_.]+$"
regex=r"^[a-zA-Z][a-zA-Z0-9_.-]+$"
)
path: Optional[str] = Field(
default=None,
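The pattern change matters because the new transformer id contains a hyphen (`text.embeddings.openai-ada-002`), which the old regex rejected. A quick check (sketch only, not part of the PR; the same change is mirrored in the Python SDK model further below):

```python
# Quick check that the new pattern accepts the hyphenated transformer id.
import re

OLD = r"^[a-zA-Z][a-zA-Z0-9_.]+$"
NEW = r"^[a-zA-Z][a-zA-Z0-9_.-]+$"

tid = "text.embeddings.openai-ada-002"
print(bool(re.match(OLD, tid)))  # False -- hyphens were not allowed
print(bool(re.match(NEW, tid)))  # True  -- hyphens are now allowed after the first character
```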
57 changes: 57 additions & 0 deletions lexy/transformers/openai.py
@@ -0,0 +1,57 @@
import logging
import os

from openai import OpenAI

from lexy.models.document import DocumentBase
from lexy.transformers import lexy_transformer


logger = logging.getLogger(__name__)
if os.environ.get("OPENAI_API_KEY"):
    openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
else:
    openai_client = None
    logger.warning("OPENAI_API_KEY not set; cannot use OpenAI API")


@lexy_transformer(name="text.embeddings.openai-ada-002")
def text_embeddings_openai(text: list[str | DocumentBase] | str | DocumentBase, **kwargs) \
        -> list[list[float]] | list[float]:
    """Embed text using OpenAI's API.

    Any additional keyword arguments are passed to the client's `OpenAI.embeddings.create` method.

    Args:
        text: A single string or DocumentBase instance, or a list of strings or DocumentBase instances to embed.

    Keyword Args:
        encoding_format: The format to return the embeddings in. Can be either float or
            [base64](https://pypi.org/project/pybase64/).
        user: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
        extra_headers: Send extra headers
        extra_query: Add additional query parameters to the request
        extra_body: Add additional JSON properties to the request
        timeout: Override the client-level default timeout for this request, in seconds

    Returns:
        list[list[float]] | list[float]: The embeddings of the provided text.
    """
    if not openai_client:
        raise Exception("OPENAI_API_KEY not set; cannot use OpenAI API")

    if isinstance(text, DocumentBase):
        text = text.content
    elif isinstance(text, list):
        text = [s.content if isinstance(s, DocumentBase) else s for s in text]

    api_response = openai_client.embeddings.create(
        model="text-embedding-ada-002",
        input=text,
        **kwargs
    )

    if isinstance(text, list):
        return [e.embedding for e in api_response.data]
    else:
        return api_response.data[0].embedding
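For reference, here is a standalone sketch of the call pattern the transformer wraps (assumes the `openai` v1 SDK is installed and `OPENAI_API_KEY` is set; not part of this PR):

```python
# Standalone illustration of the same API call; not part of the PR.
import os

from openai import OpenAI

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
response = client.embeddings.create(
    model="text-embedding-ada-002",
    input=["first document", "second document"],
)
embeddings = [d.embedding for d in response.data]
print(len(embeddings), len(embeddings[0]))  # 2 1536 -- ada-002 embeddings have 1536 dimensions
```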
38 changes: 36 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -24,9 +24,10 @@ Pillow = "^10.0.1"
# optional dependencies
transformers = { version = "^4.33.1", extras = ["torch"], optional = true}
sentence-transformers = { version = "^2.2.2", optional = true}
openai = { version = "^1.7.1", optional = true}

[tool.poetry.extras]
lexy_transformers = ["transformers", "sentence-transformers"]
lexy_transformers = ["transformers", "sentence-transformers", "openai"]

[tool.poetry.group.dev.dependencies]
alembic = "^1.13.1"
2 changes: 1 addition & 1 deletion sdk-python/lexy_py/transformer/models.py
@@ -11,7 +11,7 @@

class TransformerModel(BaseModel):
""" Transformer model """
transformer_id: str = Field(..., min_length=1, max_length=255, regex=r"^[a-zA-Z][a-zA-Z0-9_.]+$")
transformer_id: str = Field(..., min_length=1, max_length=255, regex=r"^[a-zA-Z][a-zA-Z0-9_.-]+$")
path: Optional[str] = Field(..., min_length=1, max_length=255, regex=r"^[a-zA-Z][a-zA-Z0-9_.]+$")
description: Optional[str] = None
created_at: Optional[datetime] = None