From 501660dd736139b7ef539a510546e8add05efa9e Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Thu, 12 Dec 2024 22:25:19 -0500 Subject: [PATCH 1/7] prototype for modal serving --- sotopia/ui/fastapi_server.py | 680 +++++++++++++++++---------------- sotopia/ui/modal_api_server.py | 116 ++++++ 2 files changed, 465 insertions(+), 331 deletions(-) create mode 100644 sotopia/ui/modal_api_server.py diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index 611878c4c..fdc46ce02 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -51,15 +51,19 @@ logger = logging.getLogger(__name__) -app = FastAPI() +# app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) # TODO: Whether allowing CORS for all origins +# app.add_middleware( +# CORSMiddleware, +# allow_origins=["*"], +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) # TODO: Whether allowing CORS for all origins + +active_simulations: Dict[ + str, bool +] = {} # TODO check whether this is the correct way to store the active simulations class RelationshipWrapper(BaseModel): @@ -139,120 +143,114 @@ def validate_models(self) -> Self: return self -@app.get("/scenarios", response_model=list[EnvironmentProfile]) -async def get_scenarios_all() -> list[EnvironmentProfile]: - return EnvironmentProfile.all() - +class SimulationState: + _instance: Optional["SimulationState"] = None + _lock = asyncio.Lock() + _active_simulations: dict[str, bool] = {} -@app.get("/scenarios/{get_by}/{value}", response_model=list[EnvironmentProfile]) -async def get_scenarios( - get_by: Literal["id", "codename"], value: str -) -> list[EnvironmentProfile]: - # Implement logic to fetch scenarios based on the parameters - scenarios: list[EnvironmentProfile] = [] # Replace with actual fetching logic - if get_by == "id": - scenarios.append(EnvironmentProfile.get(pk=value)) - 
elif get_by == "codename": - json_models = EnvironmentProfile.find( - EnvironmentProfile.codename == value - ).all() - scenarios.extend(cast(list[EnvironmentProfile], json_models)) + def __new__(cls) -> "SimulationState": + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._active_simulations = {} + return cls._instance - if not scenarios: - raise HTTPException( - status_code=404, detail=f"No scenarios found with {get_by}={value}" - ) + async def try_acquire_token(self, token: str) -> tuple[bool, str]: + async with self._lock: + if not token: + return False, "Invalid token" - return scenarios + if self._active_simulations.get(token): + return False, "Token is active already" + self._active_simulations[token] = True + return True, "Token is valid" -@app.get("/agents", response_model=list[AgentProfile]) -async def get_agents_all() -> list[AgentProfile]: - return AgentProfile.all() + async def release_token(self, token: str) -> None: + async with self._lock: + self._active_simulations.pop(token, None) + @asynccontextmanager + async def start_simulation(self, token: str) -> AsyncIterator[bool]: + try: + yield True + finally: + await self.release_token(token) -@app.get("/agents/{get_by}/{value}", response_model=list[AgentProfile]) -async def get_agents( - get_by: Literal["id", "gender", "occupation"], value: str -) -> list[AgentProfile]: - agents_profiles: list[AgentProfile] = [] - if get_by == "id": - agents_profiles.append(AgentProfile.get(pk=value)) - elif get_by == "gender": - json_models = AgentProfile.find(AgentProfile.gender == value).all() - agents_profiles.extend(cast(list[AgentProfile], json_models)) - elif get_by == "occupation": - json_models = AgentProfile.find(AgentProfile.occupation == value).all() - agents_profiles.extend(cast(list[AgentProfile], json_models)) - if not agents_profiles: - raise HTTPException( - status_code=404, detail=f"No agents found with {get_by}={value}" - ) +class SimulationManager: + def __init__(self) 
-> None: + self.state = SimulationState() - return agents_profiles + async def verify_token(self, token: str) -> dict[str, Any]: + is_valid, msg = await self.state.try_acquire_token(token) + return {"is_valid": is_valid, "msg": msg} + async def create_simulator( + self, env_id: str, agent_ids: list[str] + ) -> WebSocketSotopiaSimulator: + try: + return WebSocketSotopiaSimulator(env_id=env_id, agent_ids=agent_ids) + except Exception as e: + error_msg = f"Failed to create simulator: {e}" + logger.error(error_msg) + raise Exception(error_msg) -@app.get("/relationship/{agent_1_id}/{agent_2_id}", response_model=str) -async def get_relationship(agent_1_id: str, agent_2_id: str) -> str: - relationship_profiles = RelationshipProfile.find( - (RelationshipProfile.agent_1_id == agent_1_id) - & (RelationshipProfile.agent_2_id == agent_2_id) - ).all() - assert ( - len(relationship_profiles) == 1 - ), f"{len(relationship_profiles)} relationship profiles found for agents {agent_1_id} and {agent_2_id}, expected 1" - relationship_profile = relationship_profiles[0] - assert isinstance(relationship_profile, RelationshipProfile) - return f"{str(relationship_profile.relationship)}: {RelationshipType(relationship_profile.relationship).name}" + async def handle_client_message( + self, + websocket: WebSocket, + simulator: WebSocketSotopiaSimulator, + message: dict[str, Any], + timeout: float = 0.1, + ) -> bool: + try: + msg_type = message.get("type") + if msg_type == WSMessageType.FINISH_SIM.value: + return True + # TODO handle other message types + return False + except Exception as e: + msg = f"Error handling client message: {e}" + logger.error(msg) + await self.send_error(websocket, ErrorType.INVALID_MESSAGE, msg) + return False + async def run_simulation( + self, websocket: WebSocket, simulator: WebSocketSotopiaSimulator + ) -> None: + try: + async for message in simulator.arun(): + await self.send_message(websocket, WSMessageType.SERVER_MSG, message) -@app.get("/episodes", 
response_model=list[EpisodeLog]) -async def get_episodes_all() -> list[EpisodeLog]: - return EpisodeLog.all() + try: + data = await asyncio.wait_for(websocket.receive_json(), timeout=0.1) + if await self.handle_client_message(websocket, simulator, data): + break + except asyncio.TimeoutError: + continue + except Exception as e: + msg = f"Error running simulation: {e}" + logger.error(msg) + await self.send_error(websocket, ErrorType.SIMULATION_ISSUE, msg) + finally: + await self.send_message(websocket, WSMessageType.END_SIM, {}) -@app.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog]) -async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]: - episodes: list[EpisodeLog] = [] - if get_by == "id": - episodes.append(EpisodeLog.get(pk=value)) - elif get_by == "tag": - json_models = EpisodeLog.find(EpisodeLog.tag == value).all() - episodes.extend(cast(list[EpisodeLog], json_models)) + @staticmethod + async def send_message( + websocket: WebSocket, msg_type: WSMessageType, data: dict[str, Any] + ) -> None: + await websocket.send_json({"type": msg_type.value, "data": data}) - if not episodes: - raise HTTPException( - status_code=404, detail=f"No episodes found with {get_by}={value}" + @staticmethod + async def send_error( + websocket: WebSocket, error_type: ErrorType, details: str = "" + ) -> None: + await websocket.send_json( + { + "type": WSMessageType.ERROR.value, + "data": {"type": error_type.value, "details": details}, + } ) - return episodes - - -@app.post("/scenarios/", response_model=str) -async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: - scenario_profile = EnvironmentProfile(**scenario.model_dump()) - scenario_profile.save() - pk = scenario_profile.pk - assert pk is not None - return pk - - -@app.post("/agents/", response_model=str) -async def create_agent(agent: AgentProfileWrapper) -> str: - agent_profile = AgentProfile(**agent.model_dump()) - agent_profile.save() - pk = agent_profile.pk - 
assert pk is not None - return pk - - -@app.post("/relationship/", response_model=str) -async def create_relationship(relationship: RelationshipWrapper) -> str: - relationship_profile = RelationshipProfile(**relationship.model_dump()) - relationship_profile.save() - pk = relationship_profile.pk - assert pk is not None - return pk async def run_simulation( @@ -325,263 +323,283 @@ async def run_simulation( ) -@app.post("/simulate/", response_model=str) -def simulate(simulation_request: SimulationRequest) -> Response: - try: - _: EnvironmentProfile = EnvironmentProfile.get(pk=simulation_request.env_id) - except Exception: # TODO Check the exception type - raise HTTPException( - status_code=404, - detail=f"Environment with id={simulation_request.env_id} not found", - ) - try: - __ = AgentProfile.get(pk=simulation_request.agent_ids[0]) - except Exception: # TODO Check the exception type - raise HTTPException( - status_code=404, - detail=f"Agent with id={simulation_request.agent_ids[0]} not found", - ) - try: - ___ = AgentProfile.get(pk=simulation_request.agent_ids[1]) - except Exception: # TODO Check the exception type +async def get_scenarios_all() -> list[EnvironmentProfile]: + return EnvironmentProfile.all() + + +async def get_scenarios( + get_by: Literal["id", "codename"], value: str +) -> list[EnvironmentProfile]: + # Implement logic to fetch scenarios based on the parameters + scenarios: list[EnvironmentProfile] = [] # Replace with actual fetching logic + if get_by == "id": + scenarios.append(EnvironmentProfile.get(pk=value)) + elif get_by == "codename": + json_models = EnvironmentProfile.find( + EnvironmentProfile.codename == value + ).all() + scenarios.extend(cast(list[EnvironmentProfile], json_models)) + + if not scenarios: raise HTTPException( - status_code=404, - detail=f"Agent with id={simulation_request.agent_ids[1]} not found", + status_code=404, detail=f"No scenarios found with {get_by}={value}" ) - episode_pk = EpisodeLog( - environment="", - agents=[], 
- models=[], - messages=[], - reasoning="", - rewards=[], # Pseudorewards - rewards_prompt="", - ).pk - try: - simulation_status = NonStreamingSimulationStatus( - episode_pk=episode_pk, - status="Started", - ) - simulation_status.save() - queue = rq.Queue("default", connection=get_redis_connection()) - queue.enqueue( - run_simulation, - episode_pk=episode_pk, - simulation_request=simulation_request, - simulation_status=simulation_status, - ) + return scenarios - except Exception as e: - logger.error(f"Error starting simulation: {e}") - simulation_status.status = "Error" - simulation_status.save() - return Response(content=episode_pk, status_code=202) +async def get_agents_all() -> list[AgentProfile]: + return AgentProfile.all() -@app.get("/simulation_status/{episode_pk}", response_model=str) -async def get_simulation_status(episode_pk: str) -> str: - status = NonStreamingSimulationStatus.find( - NonStreamingSimulationStatus.episode_pk == episode_pk - ).all()[0] - assert isinstance(status, NonStreamingSimulationStatus) - return status.status +async def get_agents( + get_by: Literal["id", "gender", "occupation"], value: str +) -> list[AgentProfile]: + agents_profiles: list[AgentProfile] = [] + if get_by == "id": + agents_profiles.append(AgentProfile.get(pk=value)) + elif get_by == "gender": + json_models = AgentProfile.find(AgentProfile.gender == value).all() + agents_profiles.extend(cast(list[AgentProfile], json_models)) + elif get_by == "occupation": + json_models = AgentProfile.find(AgentProfile.occupation == value).all() + agents_profiles.extend(cast(list[AgentProfile], json_models)) -@app.delete("/agents/{agent_id}", response_model=str) -async def delete_agent(agent_id: str) -> str: - try: - agent = AgentProfile.get(pk=agent_id) - except Exception: # TODO Check the exception type + if not agents_profiles: raise HTTPException( - status_code=404, detail=f"Agent with id={agent_id} not found" + status_code=404, detail=f"No agents found with {get_by}={value}" ) - 
AgentProfile.delete(agent.pk) - assert agent.pk is not None - return agent.pk + return agents_profiles -@app.delete("/scenarios/{scenario_id}", response_model=str) -async def delete_scenario(scenario_id: str) -> str: - try: - scenario = EnvironmentProfile.get(pk=scenario_id) - except Exception: # TODO Check the exception type - raise HTTPException( - status_code=404, detail=f"Scenario with id={scenario_id} not found" - ) - EnvironmentProfile.delete(scenario.pk) - assert scenario.pk is not None - return scenario.pk +async def get_relationship(agent_1_id: str, agent_2_id: str) -> str: + relationship_profiles = RelationshipProfile.find( + (RelationshipProfile.agent_1_id == agent_1_id) + & (RelationshipProfile.agent_2_id == agent_2_id) + ).all() + assert ( + len(relationship_profiles) == 1 + ), f"{len(relationship_profiles)} relationship profiles found for agents {agent_1_id} and {agent_2_id}, expected 1" + relationship_profile = relationship_profiles[0] + assert isinstance(relationship_profile, RelationshipProfile) + return f"{str(relationship_profile.relationship)}: {RelationshipType(relationship_profile.relationship).name}" -@app.delete("/relationship/{relationship_id}", response_model=str) -async def delete_relationship(relationship_id: str) -> str: - RelationshipProfile.delete(relationship_id) - return relationship_id +async def get_episodes_all() -> list[EpisodeLog]: + return EpisodeLog.all() -@app.delete("/episodes/{episode_id}", response_model=str) -async def delete_episode(episode_id: str) -> str: - EpisodeLog.delete(episode_id) - return episode_id +async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[EpisodeLog]: + episodes: list[EpisodeLog] = [] + if get_by == "id": + episodes.append(EpisodeLog.get(pk=value)) + elif get_by == "tag": + json_models = EpisodeLog.find(EpisodeLog.tag == value).all() + episodes.extend(cast(list[EpisodeLog], json_models)) -active_simulations: Dict[ - str, bool -] = {} # TODO check whether this is the correct way 
to store the active simulations + if not episodes: + raise HTTPException( + status_code=404, detail=f"No episodes found with {get_by}={value}" + ) + return episodes -@app.get("/models", response_model=list[str]) async def get_models() -> list[str]: # TODO figure out how to get the available models return ["gpt-4o-mini", "gpt-4o", "gpt-3.5-turbo"] -class SimulationState: - _instance: Optional["SimulationState"] = None - _lock = asyncio.Lock() - _active_simulations: dict[str, bool] = {} - - def __new__(cls) -> "SimulationState": - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._active_simulations = {} - return cls._instance - - async def try_acquire_token(self, token: str) -> tuple[bool, str]: - async with self._lock: - if not token: - return False, "Invalid token" - - if self._active_simulations.get(token): - return False, "Token is active already" - - self._active_simulations[token] = True - return True, "Token is valid" - - async def release_token(self, token: str) -> None: - async with self._lock: - self._active_simulations.pop(token, None) - - @asynccontextmanager - async def start_simulation(self, token: str) -> AsyncIterator[bool]: - try: - yield True - finally: - await self.release_token(token) - - -class SimulationManager: - def __init__(self) -> None: - self.state = SimulationState() - - async def verify_token(self, token: str) -> dict[str, Any]: - is_valid, msg = await self.state.try_acquire_token(token) - return {"is_valid": is_valid, "msg": msg} - - async def create_simulator( - self, env_id: str, agent_ids: list[str] - ) -> WebSocketSotopiaSimulator: - try: - return WebSocketSotopiaSimulator(env_id=env_id, agent_ids=agent_ids) - except Exception as e: - error_msg = f"Failed to create simulator: {e}" - logger.error(error_msg) - raise Exception(error_msg) - - async def handle_client_message( - self, - websocket: WebSocket, - simulator: WebSocketSotopiaSimulator, - message: dict[str, Any], - timeout: float = 0.1, - ) -> 
bool: - try: - msg_type = message.get("type") - if msg_type == WSMessageType.FINISH_SIM.value: - return True - # TODO handle other message types - return False - except Exception as e: - msg = f"Error handling client message: {e}" - logger.error(msg) - await self.send_error(websocket, ErrorType.INVALID_MESSAGE, msg) - return False - - async def run_simulation( - self, websocket: WebSocket, simulator: WebSocketSotopiaSimulator - ) -> None: - try: - async for message in simulator.arun(): - await self.send_message(websocket, WSMessageType.SERVER_MSG, message) - - try: - data = await asyncio.wait_for(websocket.receive_json(), timeout=0.1) - if await self.handle_client_message(websocket, simulator, data): - break - except asyncio.TimeoutError: - continue - - except Exception as e: - msg = f"Error running simulation: {e}" - logger.error(msg) - await self.send_error(websocket, ErrorType.SIMULATION_ISSUE, msg) - finally: - await self.send_message(websocket, WSMessageType.END_SIM, {}) - - @staticmethod - async def send_message( - websocket: WebSocket, msg_type: WSMessageType, data: dict[str, Any] - ) -> None: - await websocket.send_json({"type": msg_type.value, "data": data}) - - @staticmethod - async def send_error( - websocket: WebSocket, error_type: ErrorType, details: str = "" - ) -> None: - await websocket.send_json( - { - "type": WSMessageType.ERROR.value, - "data": {"type": error_type.value, "details": details}, - } +class SotopiaFastAPI(FastAPI): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) + self.setup_routes() + def setup_routes(self): + self.get("/scenarios", response_model=list[EnvironmentProfile])( + get_scenarios_all + ) + self.get( + "/scenarios/{get_by}/{value}", response_model=list[EnvironmentProfile] + )(get_scenarios) + self.get("/agents", response_model=list[AgentProfile])(get_agents_all) 
+ self.get("/agents/{get_by}/{value}", response_model=list[AgentProfile])( + get_agents + ) + self.get("/relationship/{agent_1_id}/{agent_2_id}", response_model=str)( + get_relationship + ) + self.get("/episodes", response_model=list[EpisodeLog])(get_episodes_all) + self.get("/episodes/{get_by}/{value}", response_model=list[EpisodeLog])( + get_episodes + ) + self.get("/models", response_model=list[str])(get_models) + + @self.post("/scenarios/", response_model=str) + async def create_scenario(scenario: EnvironmentProfileWrapper) -> str: + scenario_profile = EnvironmentProfile(**scenario.model_dump()) + scenario_profile.save() + pk = scenario_profile.pk + assert pk is not None + return pk + + @self.post("/agents/", response_model=str) + async def create_agent(agent: AgentProfileWrapper) -> str: + agent_profile = AgentProfile(**agent.model_dump()) + agent_profile.save() + pk = agent_profile.pk + assert pk is not None + return pk + + @self.post("/relationship/", response_model=str) + async def create_relationship(relationship: RelationshipWrapper) -> str: + relationship_profile = RelationshipProfile(**relationship.model_dump()) + relationship_profile.save() + pk = relationship_profile.pk + assert pk is not None + return pk + + @self.post("/simulate/", response_model=str) + def simulate(simulation_request: SimulationRequest) -> Response: + try: + _: EnvironmentProfile = EnvironmentProfile.get( + pk=simulation_request.env_id + ) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Environment with id={simulation_request.env_id} not found", + ) + try: + __ = AgentProfile.get(pk=simulation_request.agent_ids[0]) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[0]} not found", + ) + try: + ___ = AgentProfile.get(pk=simulation_request.agent_ids[1]) + except Exception: # TODO Check the exception type + raise HTTPException( + 
status_code=404, + detail=f"Agent with id={simulation_request.agent_ids[1]} not found", + ) -@app.websocket("/ws/simulation") -async def websocket_endpoint(websocket: WebSocket, token: str) -> None: - manager = SimulationManager() - - token_status = await manager.verify_token(token) - if not token_status["is_valid"]: - await websocket.close(code=1008, reason=token_status["msg"]) - return - - try: - await websocket.accept() - - while True: - start_msg = await websocket.receive_json() - if start_msg.get("type") != WSMessageType.START_SIM.value: - continue + episode_pk = EpisodeLog( + environment="", + agents=[], + models=[], + messages=[], + reasoning="", + rewards=[], # Pseudorewards + rewards_prompt="", + ).pk + try: + simulation_status = NonStreamingSimulationStatus( + episode_pk=episode_pk, + status="Started", + ) + simulation_status.save() + queue = rq.Queue("default", connection=get_redis_connection()) + queue.enqueue( + run_simulation, + episode_pk=episode_pk, + simulation_request=simulation_request, + simulation_status=simulation_status, + ) - async with manager.state.start_simulation(token): - simulator = await manager.create_simulator( - env_id=start_msg["data"]["env_id"], - agent_ids=start_msg["data"]["agent_ids"], + except Exception as e: + logger.error(f"Error starting simulation: {e}") + simulation_status.status = "Error" + simulation_status.save() + return Response(content=episode_pk, status_code=202) + + @self.get("/simulation_status/{episode_pk}", response_model=str) + async def get_simulation_status(episode_pk: str) -> str: + status = NonStreamingSimulationStatus.find( + NonStreamingSimulationStatus.episode_pk == episode_pk + ).all()[0] + assert isinstance(status, NonStreamingSimulationStatus) + return status.status + + @self.delete("/agents/{agent_id}", response_model=str) + async def delete_agent(agent_id: str) -> str: + try: + agent = AgentProfile.get(pk=agent_id) + except Exception: # TODO Check the exception type + raise HTTPException( + 
status_code=404, detail=f"Agent with id={agent_id} not found" ) - await manager.run_simulation(websocket, simulator) - - except WebSocketDisconnect: - logger.info(f"Client disconnected: {token}") - except Exception as e: - logger.error(f"Unexpected error: {e}") - await manager.send_error(websocket, ErrorType.SIMULATION_ISSUE, str(e)) - finally: - try: - await websocket.close() - except Exception as e: - logger.error(f"Error closing websocket: {e}") + AgentProfile.delete(agent.pk) + assert agent.pk is not None + return agent.pk + + @self.delete("/scenarios/{scenario_id}", response_model=str) + async def delete_scenario(scenario_id: str) -> str: + try: + scenario = EnvironmentProfile.get(pk=scenario_id) + except Exception: # TODO Check the exception type + raise HTTPException( + status_code=404, detail=f"Scenario with id={scenario_id} not found" + ) + EnvironmentProfile.delete(scenario.pk) + assert scenario.pk is not None + return scenario.pk + + @self.delete("/relationship/{relationship_id}", response_model=str) + async def delete_relationship(relationship_id: str) -> str: + RelationshipProfile.delete(relationship_id) + return relationship_id + + @self.delete("/episodes/{episode_id}", response_model=str) + async def delete_episode(episode_id: str) -> str: + EpisodeLog.delete(episode_id) + return episode_id + + @self.websocket("/ws/simulation") + async def websocket_endpoint(websocket: WebSocket, token: str) -> None: + manager = SimulationManager() + + token_status = await manager.verify_token(token) + if not token_status["is_valid"]: + await websocket.close(code=1008, reason=token_status["msg"]) + return + + try: + await websocket.accept() + + while True: + start_msg = await websocket.receive_json() + if start_msg.get("type") != WSMessageType.START_SIM.value: + continue + + async with manager.state.start_simulation(token): + simulator = await manager.create_simulator( + env_id=start_msg["data"]["env_id"], + agent_ids=start_msg["data"]["agent_ids"], + ) + await 
manager.run_simulation(websocket, simulator) + + except WebSocketDisconnect: + logger.info(f"Client disconnected: {token}") + except Exception as e: + logger.error(f"Unexpected error: {e}") + await manager.send_error(websocket, ErrorType.SIMULATION_ISSUE, str(e)) + finally: + try: + await websocket.close() + except Exception as e: + logger.error(f"Error closing websocket: {e}") + +app = SotopiaFastAPI() if __name__ == "__main__": uvicorn.run(app, host="127.0.0.1", port=8800) diff --git a/sotopia/ui/modal_api_server.py b/sotopia/ui/modal_api_server.py new file mode 100644 index 000000000..203260cca --- /dev/null +++ b/sotopia/ui/modal_api_server.py @@ -0,0 +1,116 @@ +import modal +import subprocess +import time +import os + +import redis +from sotopia.ui.fastapi_server import SotopiaFastAPI + +# Create persistent volume for Redis data +redis_volume = modal.Volume.from_name("sotopia-api", create_if_missing=True) + + +def initialize_redis_data(): + """Download Redis data if it doesn't exist""" + if not os.path.exists("/vol/redis/dump.rdb"): + os.makedirs("/vol/redis", exist_ok=True) + print("Downloading initial Redis data...") + subprocess.run( + "curl -L https://cmu.box.com/shared/static/xiivc5z8rnmi1zr6vmk1ohxslylvynur --output /vol/redis/dump.rdb", + shell=True, + check=True, + ) + print("Redis data downloaded") + + +# Create image with all necessary dependencies +image = ( + modal.Image.debian_slim(python_version="3.11") + .apt_install( + "git", + "curl", + "gpg", + "lsb-release", + "wget", + "procps", # for ps command + "redis-tools", # for redis-cli + ) + .run_commands( + # Update and install basic dependencies + "apt-get update", + "apt-get install -y curl gpg lsb-release", + # Add Redis Stack repository + "curl -fsSL https://packages.redis.io/gpg | gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg", + "chmod 644 /usr/share/keyrings/redis-archive-keyring.gpg", + 'echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] 
https://packages.redis.io/deb $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/redis.list', + "apt-get update", + "apt-get install -y redis-stack-server", + ) + .pip_install( + "pydantic>=2.5.0,<3.0.0", + "aiohttp>=3.9.3,<4.0.0", + "rich>=13.8.1,<14.0.0", + "typer>=0.12.5", + "aiostream>=0.5.2", + "fastapi[all]", + "uvicorn", + "redis>=5.0.0", + "rq", + "lxml>=4.9.3,<6.0.0", + "openai>=1.11.0,<2.0.0", + "langchain>=0.2.5,<0.4.0", + "PettingZoo==1.24.3", + "redis-om>=0.3.0,<0.4.0", + "gin-config>=0.5.0,<0.6.0", + "absl-py>=2.0.0,<3.0.0", + "together>=0.2.4,<1.4.0", + "beartype>=0.14.0,<0.20.0", + "langchain-openai>=0.1.8,<0.2", + "hiredis>=3.0.0", + "aact", + "gin", + ) +) +redis_volume = modal.Volume.from_name("sotopia-api", create_if_missing=True) + +# Create stub for the application +app = modal.App("sotopia-fastapi", image=image, volumes={"/vol/redis": redis_volume}) + + +@app.cls(image=image) +class WebAPI: + def __init__(self): + self.web_app = SotopiaFastAPI() + + @modal.enter() + def setup(self): + # Start Redis server + subprocess.Popen( + ["redis-stack-server", "--dir", "/vol/redis", "--port", "6379"] + ) + + # Wait for Redis to be ready + max_retries = 30 + for _ in range(max_retries): + try: + initialize_redis_data() + # Attempt to create Redis client and ping the server + temp_client = redis.Redis(host="localhost", port=6379, db=0) + temp_client.ping() + self.redis_client = temp_client + print("Successfully connected to Redis") + return + except (redis.exceptions.ConnectionError, redis.exceptions.ResponseError): + print("Waiting for Redis to be ready...") + time.sleep(1) + + raise Exception("Could not connect to Redis after multiple attempts") + + @modal.exit() + def cleanup(self): + if hasattr(self, "redis_client"): + self.redis_client.close() + + @modal.asgi_app() + def serve(self): + return self.web_app From f3340062ab38cd7b4f56582e6633b2d9f5781ef0 Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Fri, 13 Dec 2024 15:42:46 -0500 
Subject: [PATCH 2/7] add openai secret --- sotopia/ui/fastapi_server.py | 4 ++-- sotopia/ui/modal_api_server.py | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index fdc46ce02..a5b6d9d2d 100644 --- a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -411,7 +411,7 @@ async def get_models() -> list[str]: class SotopiaFastAPI(FastAPI): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: # type: ignore super().__init__(*args, **kwargs) self.add_middleware( CORSMiddleware, @@ -422,7 +422,7 @@ def __init__(self, *args, **kwargs): ) self.setup_routes() - def setup_routes(self): + def setup_routes(self) -> None: self.get("/scenarios", response_model=list[EnvironmentProfile])( get_scenarios_all ) diff --git a/sotopia/ui/modal_api_server.py b/sotopia/ui/modal_api_server.py index 203260cca..f012ce408 100644 --- a/sotopia/ui/modal_api_server.py +++ b/sotopia/ui/modal_api_server.py @@ -10,7 +10,7 @@ redis_volume = modal.Volume.from_name("sotopia-api", create_if_missing=True) -def initialize_redis_data(): +def initialize_redis_data() -> None: """Download Redis data if it doesn't exist""" if not os.path.exists("/vol/redis/dump.rdb"): os.makedirs("/vol/redis", exist_ok=True) @@ -77,13 +77,18 @@ def initialize_redis_data(): app = modal.App("sotopia-fastapi", image=image, volumes={"/vol/redis": redis_volume}) -@app.cls(image=image) +@app.cls( + image=image, + concurrency_limit=1, + allow_concurrent_inputs=5, + secrets=[modal.Secret.from_name("openai-secret")], +) class WebAPI: - def __init__(self): + def __init__(self) -> None: self.web_app = SotopiaFastAPI() @modal.enter() - def setup(self): + def setup(self) -> None: # Start Redis server subprocess.Popen( ["redis-stack-server", "--dir", "/vol/redis", "--port", "6379"] @@ -107,10 +112,10 @@ def setup(self): raise Exception("Could not connect to Redis after multiple attempts") 
@modal.exit() - def cleanup(self): + def cleanup(self) -> None: if hasattr(self, "redis_client"): self.redis_client.close() @modal.asgi_app() - def serve(self): + def serve(self) -> modal.AsgiApp: return self.web_app From cef8a616263fbb9bfb553519efac68abf54c20fd Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Sat, 14 Dec 2024 02:17:26 -0500 Subject: [PATCH 3/7] fix type annotation --- sotopia/ui/modal_api_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sotopia/ui/modal_api_server.py b/sotopia/ui/modal_api_server.py index f012ce408..5918a452d 100644 --- a/sotopia/ui/modal_api_server.py +++ b/sotopia/ui/modal_api_server.py @@ -117,5 +117,5 @@ def cleanup(self) -> None: self.redis_client.close() @modal.asgi_app() - def serve(self) -> modal.AsgiApp: + def serve(self) -> SotopiaFastAPI: return self.web_app From 1957b304c5e9ebb8474545a2362229daf1599eb4 Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Sat, 14 Dec 2024 02:20:44 -0500 Subject: [PATCH 4/7] add doc --- docs/pages/examples/deployment.md | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 docs/pages/examples/deployment.md diff --git a/docs/pages/examples/deployment.md b/docs/pages/examples/deployment.md new file mode 100644 index 000000000..6a8185701 --- /dev/null +++ b/docs/pages/examples/deployment.md @@ -0,0 +1,6 @@ +# Deploy Sotopia Python API to Modal +We offer a script to deploy Sotopia Python API to [Modal](https://modal.com/). 
+To do so, simply go to the root directory of the `sotopia` repository and run the following command: +```bash +modal deploy sotopia/ui/modal_api_server.py +``` From 73d7f1db46dba7ba443d1142fdef2d329ccfe573 Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Sat, 14 Dec 2024 03:20:00 -0500 Subject: [PATCH 5/7] bug fix for simulation api --- sotopia/ui/README.md | 9 ++++++++- sotopia/ui/websocket_utils.py | 18 +++++++++--------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/sotopia/ui/README.md b/sotopia/ui/README.md index 4d8bb773a..f483f3ea2 100644 --- a/sotopia/ui/README.md +++ b/sotopia/ui/README.md @@ -155,7 +155,7 @@ returns: **WSMessageType** | Type | Direction | Description | |-----------|--------|-------------| -| SERVER_MSG | Server → Client | Standard message from server (payload: `messageForRendering` [here](https://github.com/sotopia-lab/sotopia-demo/blob/main/socialstream/rendering_utils.py) ) | +| SERVER_MSG | Server → Client | Standard message from server (payload: `EpisodeLog`) | | CLIENT_MSG | Client → Server | Standard message from client (payload: TBD) | | ERROR | Server → Client | Error notification (payload: TBD) | | START_SIM | Client → Server | Initialize simulation (payload: `SimulationEpisodeInitialization`) | @@ -179,6 +179,13 @@ returns: **Implementation plan**: Currently only support LLM-LLM simulation based on [this function](https://github.com/sotopia-lab/sotopia/blob/19d39e068c3bca9246fc366e5759414f62284f93/sotopia/server.py#L108). +**SERVER_MSG payload** +The server message is a dictionary that has the following keys: +- type: str, indicates the type of the message, typically it is "messages" +- messages: Any. Typically this is the dictionary of the `EpisodeLog` for the current simulation state. 
(Which means the reward part could be empty) + + + ## An example to run simulation with the API diff --git a/sotopia/ui/websocket_utils.py b/sotopia/ui/websocket_utils.py index 5b29da732..2e10c228d 100644 --- a/sotopia/ui/websocket_utils.py +++ b/sotopia/ui/websocket_utils.py @@ -151,17 +151,17 @@ async def arun(self) -> AsyncGenerator[dict[str, Any], None]: streaming=True, ) - assert isinstance( - generator, AsyncGenerator - ), "generator should be async generator" + # assert isinstance( + # generator, AsyncGenerator + # ), "generator should be async generator, but got {}".format( + # type(generator) + # ) async for messages in await generator: # type: ignore reasoning, rewards = "", [0.0, 0.0] - eval_available = False if messages[-1][0][0] == "Evaluation": reasoning = messages[-1][0][2].to_natural_language() rewards = eval(messages[-2][0][2].to_natural_language()) - eval_available = True epilog = EpisodeLog( environment=self.env.profile.pk, @@ -176,11 +176,11 @@ async def arun(self) -> AsyncGenerator[dict[str, Any], None]: rewards=rewards, rewards_prompt="", ) - agent_profiles, parsed_messages = epilog.render_for_humans() - if not eval_available: - parsed_messages = parsed_messages[:-2] + # agent_profiles, parsed_messages = epilog.render_for_humans() + # if not eval_available: + # parsed_messages = parsed_messages[:-2] yield { "type": "messages", - "messages": parsed_messages, + "messages": epilog.dict(), } From 19060e4b26f1be6458dcbb7a11efac0bb98c5c3f Mon Sep 17 00:00:00 2001 From: Zhe Su <360307598@qq.com> Date: Sun, 15 Dec 2024 00:29:22 -0500 Subject: [PATCH 6/7] add customize model, evaluator model and evaluation dimensions --- sotopia/ui/fastapi_server.py | 24 ++++++++++++++++++++++-- sotopia/ui/modal_api_server.py | 2 +- sotopia/ui/websocket_utils.py | 26 +++++++++++++++++++++----- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/sotopia/ui/fastapi_server.py b/sotopia/ui/fastapi_server.py index a5b6d9d2d..9f3483f15 100644 --- 
a/sotopia/ui/fastapi_server.py +++ b/sotopia/ui/fastapi_server.py @@ -186,10 +186,21 @@ async def verify_token(self, token: str) -> dict[str, Any]: return {"is_valid": is_valid, "msg": msg} async def create_simulator( - self, env_id: str, agent_ids: list[str] + self, + env_id: str, + agent_ids: list[str], + agent_models: list[str], + evaluator_model: str, + evaluation_dimension_list_name: str, ) -> WebSocketSotopiaSimulator: try: - return WebSocketSotopiaSimulator(env_id=env_id, agent_ids=agent_ids) + return WebSocketSotopiaSimulator( + env_id=env_id, + agent_ids=agent_ids, + agent_models=agent_models, + evaluator_model=evaluator_model, + evaluation_dimension_list_name=evaluation_dimension_list_name, + ) except Exception as e: error_msg = f"Failed to create simulator: {e}" logger.error(error_msg) @@ -584,6 +595,15 @@ async def websocket_endpoint(websocket: WebSocket, token: str) -> None: simulator = await manager.create_simulator( env_id=start_msg["data"]["env_id"], agent_ids=start_msg["data"]["agent_ids"], + agent_models=start_msg["data"].get( + "agent_models", ["gpt-4o-mini", "gpt-4o-mini"] + ), + evaluator_model=start_msg["data"].get( + "evaluator_model", "gpt-4o" + ), + evaluation_dimension_list_name=start_msg["data"].get( + "evaluation_dimension_list_name", "sotopia" + ), ) await manager.run_simulation(websocket, simulator) diff --git a/sotopia/ui/modal_api_server.py b/sotopia/ui/modal_api_server.py index 5918a452d..81c1c044b 100644 --- a/sotopia/ui/modal_api_server.py +++ b/sotopia/ui/modal_api_server.py @@ -16,7 +16,7 @@ def initialize_redis_data() -> None: os.makedirs("/vol/redis", exist_ok=True) print("Downloading initial Redis data...") subprocess.run( - "curl -L https://cmu.box.com/shared/static/xiivc5z8rnmi1zr6vmk1ohxslylvynur --output /vol/redis/dump.rdb", + "curl -L https://cmu.box.com/shared/static/e3vd31r7916jb70j9cgtcq9etryrxml0.rdb --output /vol/redis/dump.rdb", shell=True, check=True, ) diff --git a/sotopia/ui/websocket_utils.py 
b/sotopia/ui/websocket_utils.py index 2e10c228d..999a9981f 100644 --- a/sotopia/ui/websocket_utils.py +++ b/sotopia/ui/websocket_utils.py @@ -2,16 +2,20 @@ EvaluationForTwoAgents, ReachGoalLLMEvaluator, RuleBasedTerminatedEvaluator, - SotopiaDimensions, ) from sotopia.agents import Agents, LLMAgent from sotopia.messages import Observation from sotopia.envs import ParallelSotopiaEnv -from sotopia.database import EnvironmentProfile, AgentProfile, EpisodeLog +from sotopia.database import ( + EnvironmentProfile, + AgentProfile, + EpisodeLog, + EvaluationDimensionBuilder, +) from sotopia.server import arun_one_episode from enum import Enum -from typing import TypedDict, Any, AsyncGenerator +from typing import Type, TypedDict, Any, AsyncGenerator from pydantic import BaseModel @@ -57,6 +61,7 @@ def get_env_agents( agent_ids: list[str], agent_models: list[str], evaluator_model: str, + evaluation_dimension_list_name: str, ) -> tuple[ParallelSotopiaEnv, Agents, dict[str, Observation]]: # environment_profile = EnvironmentProfile.find().all()[0] # agent_profiles = AgentProfile.find().all()[:2] @@ -79,6 +84,12 @@ def get_env_agents( for idx, goal in enumerate(environment_profile.agent_goals): agent_list[idx].goal = goal + evaluation_dimensions: Type[BaseModel] = ( + EvaluationDimensionBuilder.select_existing_dimension_model_by_list_name( + list_name=evaluation_dimension_list_name + ) + ) + agents = Agents({agent.agent_name: agent for agent in agent_list}) env = ParallelSotopiaEnv( action_order="round-robin", @@ -89,7 +100,7 @@ def get_env_agents( terminal_evaluators=[ ReachGoalLLMEvaluator( evaluator_model, - EvaluationForTwoAgents[SotopiaDimensions], + EvaluationForTwoAgents[evaluation_dimensions], # type: ignore ), ], env_profile=environment_profile, @@ -124,9 +135,14 @@ def __init__( agent_ids: list[str], agent_models: list[str] = ["gpt-4o-mini", "gpt-4o-mini"], evaluator_model: str = "gpt-4o", + evaluation_dimension_list_name: str = "sotopia", ) -> None: self.env, 
self.agents, self.environment_messages = get_env_agents( - env_id, agent_ids, agent_models, evaluator_model + env_id, + agent_ids, + agent_models, + evaluator_model, + evaluation_dimension_list_name, ) self.messages: list[list[tuple[str, str, str]]] = [] self.messages.append( From e930d207db786d3901bbd78fee5390d73433fc65 Mon Sep 17 00:00:00 2001 From: XuhuiZhou Date: Sat, 28 Dec 2024 11:59:05 -0500 Subject: [PATCH 7/7] Implement modal API server with Redis integration and FastAPI setup - Added a new script for the modal API server that initializes a Redis instance. - Created a persistent volume for Redis data and included a function to download initial data if not present. - Configured a Docker image with necessary dependencies including Redis Stack and FastAPI. - Implemented a web API class that sets up and cleans up the Redis connection, ensuring readiness before serving requests. - Integrated the SotopiaFastAPI application within the modal framework. --- .../ui => scripts/modal}/modal_api_server.py | 25 +++---------------- 1 file changed, 3 insertions(+), 22 deletions(-) rename {sotopia/ui => scripts/modal}/modal_api_server.py (83%) diff --git a/sotopia/ui/modal_api_server.py b/scripts/modal/modal_api_server.py similarity index 83% rename from sotopia/ui/modal_api_server.py rename to scripts/modal/modal_api_server.py index 81c1c044b..30c8c8da8 100644 --- a/sotopia/ui/modal_api_server.py +++ b/scripts/modal/modal_api_server.py @@ -47,28 +47,9 @@ def initialize_redis_data() -> None: "apt-get install -y redis-stack-server", ) .pip_install( - "pydantic>=2.5.0,<3.0.0", - "aiohttp>=3.9.3,<4.0.0", - "rich>=13.8.1,<14.0.0", - "typer>=0.12.5", - "aiostream>=0.5.2", - "fastapi[all]", - "uvicorn", - "redis>=5.0.0", - "rq", - "lxml>=4.9.3,<6.0.0", - "openai>=1.11.0,<2.0.0", - "langchain>=0.2.5,<0.4.0", - "PettingZoo==1.24.3", - "redis-om>=0.3.0,<0.4.0", - "gin-config>=0.5.0,<0.6.0", - "absl-py>=2.0.0,<3.0.0", - "together>=0.2.4,<1.4.0", - "beartype>=0.14.0,<0.20.0", - 
"langchain-openai>=0.1.8,<0.2", - "hiredis>=3.0.0", - "aact", - "gin", + "sotopia>=0.1.2", + "fastapi>0.100", # TODO: remove this dependency after pypi release + "uvicorn", # TODO: remove this dependency after pypi release ) ) redis_volume = modal.Volume.from_name("sotopia-api", create_if_missing=True)