|
import argparse |
|
import json |
|
import logging |
|
import time |
|
|
|
import uvicorn |
|
from fastapi import FastAPI, HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.requests import Request |
|
|
|
from lagent.schema import AgentMessage |
|
from lagent.utils import load_class_from_string |
|
|
|
|
|
class AgentAPIServer:
    """Serve a configured agent over HTTP with FastAPI.

    Constructing an instance builds the agent from ``config``, registers the
    HTTP routes and then immediately starts serving via :meth:`run`, which
    does not return until uvicorn shuts down.

    Routes:
        GET  /health_check         -- liveness probe with a server timestamp.
        POST /chat_completion      -- run the agent on the posted messages.
        GET  /memory/{session_id}  -- return ``agent.state_dict(session_id)``.
    """

    def __init__(self,
                 config: dict,
                 host: str = '127.0.0.1',
                 port: int = 8090):
        """Build the agent, register the routes and start serving.

        Args:
            config: Agent configuration. ``type`` is the agent class, or a
                string resolved with ``load_class_from_string``; the optional
                ``python_path`` entry is passed to that loader. Every other
                key is forwarded as a keyword argument to the agent class.
                The caller's dict is left unmodified.
            host: Interface to bind the server to.
            port: TCP port to bind the server to.

        Raises:
            KeyError: If ``config`` has no ``type`` entry.
        """
        self.app = FastAPI(docs_url='/')
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        # Shallow copy: popping the loader keys must not mutate the dict the
        # caller handed in (e.g. a config reused for several servers).
        config = dict(config)
        agent_cls = config.pop('type')
        python_path = config.pop('python_path', None)
        if isinstance(agent_cls, str):
            agent_cls = load_class_from_string(agent_cls, python_path)
        self.agent = agent_cls(**config)
        self.setup_routes()
        self.run(host, port)

    def setup_routes(self):
        """Register the HTTP endpoints on ``self.app``."""

        def heartbeat():
            """Report liveness together with the current time in epoch seconds."""
            return {'status': 'success', 'timestamp': time.time()}

        async def process_message(request: Request):
            """Run the agent on the JSON request body and return its result.

            The body must be a JSON object with a ``message`` entry: a list
            whose items are plain strings or ``AgentMessage``-shaped objects.
            A single string or object is accepted as a one-element list. All
            remaining top-level keys are forwarded to the agent call as
            keyword arguments.

            Raises:
                HTTPException: 400 for a malformed request body, 500 if the
                    agent call itself fails.
            """
            # Client-side problems are detected before the agent runs so they
            # are reported as 400s instead of being masked as server errors.
            try:
                body = await request.json()
            except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
                raise HTTPException(
                    status_code=400,
                    detail='Request body is not valid JSON') from e
            if not isinstance(body, dict) or 'message' not in body:
                raise HTTPException(
                    status_code=400,
                    detail="Request body must be a JSON object with a "
                    "'message' field")

            raw_messages = body.pop('message')
            # Without this, a bare string would be iterated character by
            # character (and a bare object key by key) below.
            if isinstance(raw_messages, (str, dict)):
                raw_messages = [raw_messages]
            try:
                messages = [
                    m if isinstance(m, str) else AgentMessage.model_validate(m)
                    for m in raw_messages
                ]
            except (TypeError, ValueError) as e:
                # TypeError: 'message' is not iterable; ValueError covers
                # pydantic's ValidationError for malformed message items.
                raise HTTPException(
                    status_code=400, detail=f'Invalid message: {e}') from e

            try:
                return await self.agent(*messages, **body)
            except Exception as e:
                # logging.exception keeps the traceback that str(e) drops.
                logging.exception('Error processing message: %s', e)
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        def get_memory(session_id: int = 0):
            """Return ``self.agent.state_dict(session_id)``.

            Raises:
                HTTPException: 404 if ``state_dict`` raises ``KeyError`` for
                    the session, 500 on any other failure.
            """
            try:
                return self.agent.state_dict(session_id)
            except KeyError as e:
                raise HTTPException(
                    status_code=404, detail="Session ID not found") from e
            except Exception as e:
                logging.exception('Error retrieving memory: %s', e)
                raise HTTPException(
                    status_code=500, detail='Internal Server Error') from e

        self.app.add_api_route('/health_check', heartbeat, methods=['GET'])
        self.app.add_api_route(
            '/chat_completion', process_message, methods=['POST'])
        self.app.add_api_route(
            '/memory/{session_id}', get_memory, methods=['GET'])

    def run(self, host='127.0.0.1', port=8090):
        """Serve ``self.app`` with uvicorn on ``host:port`` until shutdown."""
        logging.info(f'Starting server at {host}:{port}')
        uvicorn.run(self.app, host=host, port=port)
|
|
|
|
|
def parse_args():
    """Read the server's command-line options.

    Returns:
        argparse.Namespace with ``host`` (str), ``port`` (int) and ``config``
        (the value decoded by ``json.loads`` from the required ``--config``).
    """
    cli = argparse.ArgumentParser(description='Async Agent API Server')
    # Plain typed options that only carry a default.
    for flag, kind, fallback in (('--host', str, '127.0.0.1'),
                                 ('--port', int, 8090)):
        cli.add_argument(flag, type=kind, default=fallback)
    cli.add_argument('--config',
                     type=json.loads,
                     required=True,
                     help='JSON configuration for the agent')
    return cli.parse_args()
|
|
|
|
|
if __name__ == '__main__':
    # Configure root logging first so the server's startup message is shown.
    logging.basicConfig(level=logging.INFO)
    cli_args = parse_args()
    # The constructor blocks while the server runs.
    AgentAPIServer(cli_args.config, host=cli_args.host, port=cli_args.port)
|
|