8000 Update imodel by ohdearquant · Pull Request #539 · khive-ai/lionagi · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Update imodel #539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion lionagi/operations/ReAct/ReAct.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ async def ReAct(
branch: "Branch",
instruct: Instruct | dict[str, Any],
interpret: bool = False,
interpret_domain: str | None = None,
interpret_style: str | None = None,
interpret_sample: str | None = None,
interpret_kwargs: dict | None = None,
tools: Any = None,
tool_schemas: Any = None,
response_format: type[BaseModel] | BaseModel = None,
Expand All @@ -43,7 +47,11 @@ async def ReAct(
instruct.to_dict()
if isinstance(instruct, Instruct)
else instruct
)
),
domain=interpret_domain,
style=interpret_style,
sample_writing=interpret_sample,
**(interpret_kwargs or {}),
)

# Convert Instruct to dict if necessary
Expand Down
11 changes: 11 additions & 0 deletions lionagi/service/endpoints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class EndpointConfig(BaseModel):
use_enum_values=True,
)

name: str | None = None
provider: str | None = None
base_url: str | None = None
endpoint: str
Expand All @@ -75,6 +76,7 @@ class EndpointConfig(BaseModel):
requires_tokens: bool = False
api_version: str | None = None
allowed_roles: list[str] | None = None
request_options: type | None = None


class EndPoint(ABC):
Expand All @@ -100,6 +102,11 @@ def update_config(self, **kwargs):
config.update(kwargs)
self.config = EndpointConfig(**config)

@property
def name(self) -> str | None:
    """str | None: The endpoint's display name.

    Returns the explicitly configured ``config.name`` when set; otherwise
    falls back to the endpoint path (``self.endpoint``). Note the ``or``
    also skips an empty-string name, not just ``None``.
    """
    return self.config.name or self.endpoint

@property
def is_streamable(self) -> bool:
"""bool: Whether this endpoint supports streaming responses."""
Expand Down Expand Up @@ -185,6 +192,10 @@ def roled(self) -> bool:
"""bool: Indicates if this endpoint uses role-based messages."""
return self.allowed_roles is not None

@property
def request_options(self) -> type | None:
    """type | None: Class used to describe/validate request payloads.

    Delegates to ``config.request_options``; providers in this PR set it
    to a pydantic ``BaseModel`` subclass (e.g. ``ExaSearchRequest``),
    and it is ``None`` when the endpoint declares no request schema.
    """
    return self.config.request_options

def create_payload(self, **kwargs) -> dict:
"""Generates a request payload (and headers) for this endpoint.

Expand Down
11 changes: 11 additions & 0 deletions lionagi/service/imodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
import warnings

from pydantic import BaseModel

from .endpoints.base import APICalling, EndPoint
from .endpoints.match_endpoint import match_endpoint
from .endpoints.rate_limited_processor import RateLimitedAPIExecutor
Expand Down Expand Up @@ -266,6 +268,15 @@ def model_name(self) -> str:
"""
return self.kwargs.get("model", "")

@property
def request_options(self) -> type[BaseModel] | None:
    """type[BaseModel] | None: The request options model for the endpoint.

    Thin delegate to ``self.endpoint.request_options`` so callers can ask
    the iModel directly instead of reaching into the endpoint.

    Returns:
        The request options model if available; otherwise, None.
    """
    return self.endpoint.request_options

def to_dict(self):
kwargs = self.kwargs
if "kwargs" in self.kwargs:
Expand Down
8 changes: 4 additions & 4 deletions lionagi/service/providers/exa_/search.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import TYPE_CHECKING, Literal
from typing import Literal

from lionagi.service.endpoints.base import EndPoint

if TYPE_CHECKING:
from .models import ExaSearchRequest

from .models import ExaSearchRequest

CATEGORY_OPTIONS = Literal[
"article",
Expand All @@ -21,6 +19,7 @@
]

SEARCH_CONFIG = {
"name": "search_exa",
"provider": "exa",
"base_url": "https://api.exa.ai",
"endpoint": "search",
Expand All @@ -47,6 +46,7 @@
"type", # keyword, neural, auto
"useAutoPrompt",
},
"request_options": ExaSearchRequest,
}


Expand Down
4 changes: 4 additions & 0 deletions lionagi/service/providers/perplexity_/chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from lionagi.service.endpoints.chat_completion import ChatCompletionEndPoint

from .models import PerplexityChatCompletionRequest

CHAT_COMPLETION_CONFIG = {
"name": "search_perplexity",
"provider": "perplexity",
"base_url": "https://api.perplexity.ai",
"endpoint": "chat/completions",
Expand All @@ -31,6 +34,7 @@
"frequency_penalty",
},
"allowed_roles": ["user", "assistant"],
"request_options": PerplexityChatCompletionRequest,
}


Expand Down
144 changes: 144 additions & 0 deletions lionagi/service/providers/perplexity_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field, model_validator


# Speaker roles accepted by Perplexity's messages, built with the
# functional Enum API and a ``str`` mixin so each member compares equal
# to its plain string value (e.g. ``PerplexityRole.user == "user"``).
PerplexityRole = Enum(
    "PerplexityRole",
    {"system": "system", "user": "user", "assistant": "assistant"},
    type=str,
)
PerplexityRole.__doc__ = """Roles allowed in Perplexity's messages."""


class PerplexityMessage(BaseModel):
    """
    A single message in the conversation.

    `role` can be 'system', 'user', or 'assistant' (see PerplexityRole).
    `content` is the text for that conversation turn. Both fields are
    required (pydantic's ``...`` sentinel).
    """

    # Who is speaking in this turn; validated against PerplexityRole.
    role: PerplexityRole = Field(
        ...,
        description="The role of the speaker. Must be system, user, or assistant.",
    )
    # The raw text of the turn.
    content: str = Field(..., description="The text content of this message.")


class PerplexityChatCompletionRequest(BaseModel):
    """
    Represents the request body for Perplexity's Chat Completions endpoint.
    Endpoint: POST https://api.perplexity.ai/chat/completions

    Only ``model`` and ``messages`` are required; every optional field
    defaults to ``None`` and is dropped from the serialized payload by
    ``to_dict`` (``exclude_none=True``).
    """

    model: str = Field(
        ...,
        description="The model name, e.g. 'llama-3.1-sonar-small-128k-online'.",
    )
    messages: list[PerplexityMessage] = Field(
        ..., description="A list of messages forming the conversation so far."
    )

    # Optional parameters
    frequency_penalty: float | None = Field(
        default=None,
        gt=0,
        description=(
            "Multiplicative penalty > 0. Values > 1.0 penalize repeated tokens more strongly. "
            "Value=1.0 means no penalty. Incompatible with presence_penalty."
        ),
    )
    presence_penalty: float | None = Field(
        default=None,
        ge=-2.0,
        le=2.0,
        description=(
            "Penalizes tokens that have appeared so far (range -2 to 2). "
            "Positive values encourage talking about new topics. Incompatible with frequency_penalty."
        ),
    )
    max_tokens: int | None = Field(
        default=None,
        description=(
            "Maximum number of completion tokens. If omitted, model generates tokens until it "
            "hits stop or context limit."
        ),
    )
    return_images: bool | None = Field(
        default=None,
        description="If True, attempt to return images (closed beta feature).",
    )
    return_related_questions: bool | None = Field(
        default=None,
        description="If True, attempt to return related questions (closed beta feature).",
    )
    search_domain_filter: list[Any] | None = Field(
        default=None,
        description=(
            "List of domains to limit or exclude in the online search. Example: ['example.com', '-twitter.com']. "
            "Supports up to 3 entries. (Closed beta feature.)"
        ),
    )
    search_recency_filter: str | None = Field(
        default=None,
        description=(
            "Returns search results within a specified time interval: 'month', 'week', 'day', or 'hour'."
        ),
    )
    stream: bool | None = Field(
        default=None,
        description=(
            "If True, response is returned incrementally via Server-Sent Events (SSE)."
        ),
    )
    temperature: float | None = Field(
        default=None,
        ge=0.0,
        lt=2.0,
        description=(
            "Controls randomness of sampling, range [0, 2). Higher => more random. "
            "Defaults to 0.2."
        ),
    )
    top_k: int | None = Field(
        default=None,
        ge=0,
        le=2048,
        description=(
            "Top-K filtering. 0 disables top-k filtering. If set, only the top K tokens are considered. "
            "We recommend altering either top_k or top_p, but not both."
        ),
    )
    top_p: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description=(
            "Nucleus sampling threshold. We recommend altering either top_k or top_p, but not both."
        ),
    )

    @model_validator(mode="before")
    def validate_penalties(cls, values):
        """
        Disallow using both frequency_penalty != 1.0 and presence_penalty != 0.0 at once,
        since the docs say they're incompatible.

        Fix vs. the previous version: the old code used
        ``values.get("frequency_penalty", 1.0)``, so an *explicit*
        ``frequency_penalty=None`` (the field's own default) compared as
        ``None != 1.0`` and spuriously raised when presence_penalty was set.
        ``None`` is now treated the same as the neutral value in both cases.
        It also no longer assumes the raw input is a dict — a
        ``mode="before"`` validator can receive other input forms.
        """
        if not isinstance(values, dict):
            # Not a plain mapping (e.g. already-constructed model / other
            # input); let pydantic's normal field validation handle it.
            return values

        freq_pen = values.get("frequency_penalty")
        pres_pen = values.get("presence_penalty")

        # The doc states frequency_penalty is incompatible with presence_penalty.
        # Neutral values: 1.0 for frequency_penalty, 0.0 for presence_penalty;
        # None means "unset". Only a non-neutral, non-None pair conflicts.
        if pres_pen not in (None, 0.0) and freq_pen not in (None, 1.0):
            raise ValueError(
                "presence_penalty is incompatible with frequency_penalty. "
                "Please use only one: either presence_penalty=0 with freq_pen !=1, "
                "or presence_penalty!=0 with freq_pen=1."
            )
        return values

    def to_dict(self) -> dict:
        """Return a dict suitable for JSON serialization and sending to Perplexity API."""
        # exclude_none keeps the wire payload minimal: unset optionals are omitted.
        return self.model_dump(exclude_none=True)
53 changes: 48 additions & 5 deletions lionagi/session/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
SenderRecipient,
System,
)
from lionagi.service.endpoints.base import EndPoint
from lionagi.service.types import iModel, iModelManager
from lionagi.settings import Settings
from lionagi.utils import UNDEFINED, alcall, bcall, copy
Expand Down Expand Up @@ -548,12 +549,38 @@ def receive_all(self) -> None:

def connect(
self,
name: str,
imodel: iModel,
request_options: type[BaseModel],
provider: str = None,
base_url: str = None,
endpoint: str | EndPoint = "chat",
endpoint_params: list[str] | None = None,
api_key: str = None,
queue_capacity: int = 100,
capacity_refresh_time: float = 60,
interval: float | None = None,
limit_requests: int = None,
limit_tokens: int = None,
invoke_with_endpoint: bool = False,
imodel: iModel = None,
name: str = None,
request_options: type[BaseModel] = None,
description: str = None,
update: bool = False,
):
if not imodel:
imodel = iModel(
provider=provider,
base_url=base_url,
endpoint=endpoint,
endpoint_params=endpoint_params,
api_key=api_key,
queue_capacity=queue_capacity,
capacity_refresh_time=capacity_refresh_time,
interval=interval,
limit_requests=limit_requests,
limit_tokens=limit_tokens,
invoke_with_endpoint=invoke_with_endpoint,
)

if not update and name in self.tools:
raise ValueError(f"Tool with name '{name}' already exists.")

Expand All @@ -563,13 +590,13 @@ async def _connect(**kwargs):
self._log_manager.log(Log.create(api_call))
return api_call.response

_connect.__name__ = name
_connect.__name__ = name or imodel.endpoint.name
if description:
_connect.__doc__ = description

tool = Tool(
func_callable=_connect,
request_options=request_options,
request_options=request_options or imodel.request_options,
)
self._action_manager.register_tools(tool, update=update)

Expand Down Expand Up @@ -1521,6 +1548,10 @@ async def ReAct(
self,
instruct: Instruct | dict[str, Any],
interpret: bool = False,
interpret_domain: str | None = None,
interpret_style: str | None = None,
interpret_sample: str | None = None,
interpret_kwargs: dict | None = None,
tools: Any = None,
tool_schemas: Any = None,
response_format: type[BaseModel] = None,
Expand All @@ -1547,6 +1578,14 @@ async def ReAct(
interpret (bool, optional):
If `True`, first interprets (`branch.interpret`) the instructions to refine them
before proceeding. Defaults to `False`.
interpret_domain (str | None, optional):
Optional domain hint for the interpretation step.
interpret_style (str | None, optional):
Optional style hint for the interpretation step.
interpret_sample (str | None, optional):
Optional sample hint for the interpretation step.
interpret_kwargs (dict | None, optional):
Additional arguments for the interpretation step.
tools (Any, optional):
Tools to be made available for the ReAct process. If omitted or `None`,
and if no `tool_schemas` are provided, it defaults to `True` (all tools).
Expand Down Expand Up @@ -1595,6 +1634,10 @@ async def ReAct(
self,
instruct,
interpret=interpret,
interpret_domain=interpret_domain,
interpret_style=interpret_style,
interpret_sample=interpret_sample,
interpret_kwargs=interpret_kwargs,
tools=tools,
tool_schemas=tool_schemas,
response_format=response_format,
Expand Down
2 changes: 1 addition & 1 deletion lionagi/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.8.4"
__version__ = "0.8.5"
Loading
Loading
0