import time
import uuid
from typing import Dict, List, Optional

from aiohttp import ClientSession
from fastapi import HTTPException

from api.config import EDITEA_API_ENDPOINT, EDITEA_HEADERS, MODEL_MAPPING
from api.helper import format_prompt
from api.logger import logger
from api.models import ChatRequest


class Editee:
    """Async provider for the Editee chat completion API."""

    label = "Editee"
    url = "https://editee.com"
    api_endpoint = EDITEA_API_ENDPOINT
    working = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True

    default_model = 'claude'
    models = ['claude', 'gpt4', 'gemini', 'mistrallarge']

    # Map public model names to Editee's internal identifiers.
    model_aliases = {
        "claude-3.5-sonnet": "claude",
        "gpt-4o": "gpt4",
        "gemini-pro": "gemini",
        "mistral-large": "mistrallarge",
    }

    @classmethod
    def get_model(cls, model: str) -> str:
        """Resolve a requested model name to one Editee accepts.

        Known names pass through, aliases are translated (e.g. "gpt-4o"
        becomes "gpt4"), and anything else falls back to the default model.
        """
        if model in cls.models:
            return model
        if model in cls.model_aliases:
            return cls.model_aliases[model]
        return cls.default_model

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: List[Dict[str, str]],
        proxy: Optional[str] = None,
        **kwargs
    ):
        """Yield the model's reply: chunk by chunk when the API streams
        server-sent events, otherwise as a single string."""
        model = cls.get_model(model)
        headers = EDITEA_HEADERS

        async with ClientSession(headers=headers) as session:
            # Editee takes a single flattened prompt rather than a message list.
            prompt = format_prompt(messages)
            data = {
                "user_input": prompt,
                "context": " ",
                "template_id": "",
                "selected_model": model
            }
            try:
                async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
                    response.raise_for_status()
                    if response.content_type == 'text/event-stream':
                        # Streaming response: forward each SSE line as it arrives.
                        async for line in response.content:
                            yield line.decode('utf-8')
                    else:
                        # Non-streaming response: the full reply is one JSON body.
                        response_data = await response.json()
                        yield response_data['text']
            except Exception as e:
                logger.error(f"Error in Editee API call: {e}")
                raise HTTPException(status_code=500, detail="Error in Editee API call") from e
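
# Example of consuming the generator directly (a minimal sketch; the message
# shape mirrors the OpenAI-style format used by `process_response` below):
#
#     async def demo() -> None:
#         async for chunk in Editee.create_async_generator(
#             model="claude-3.5-sonnet",
#             messages=[{"role": "user", "content": "Hello"}],
#         ):
#             print(chunk, end="")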


async def process_response(request: ChatRequest, stream: bool = False):
    """Run a chat request through Editee and return either an SSE generator
    (stream=True) or a completed OpenAI-style response dict."""
    try:
        model = MODEL_MAPPING.get(request.model, request.model)
        messages = [
            {"role": message.role, "content": message.content}
            for message in request.messages
        ]

        generator = Editee.create_async_generator(
            model=model,
            messages=messages,
            proxy=None
        )

        if stream:
            # Wrap each chunk in the "data:" framing SSE clients expect.
            async def event_generator():
                async for chunk in generator:
                    yield f"data: {chunk}\n\n"
            return event_generator()

        # Non-streaming: drain the generator and build one completion object.
        full_response = ""
        async for chunk in generator:
            full_response += chunk

        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "created": int(time.time()),  # Unix seconds, per the OpenAI schema
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": full_response},
                    "finish_reason": "stop",
                }
            ],
            "usage": None,
        }
    except HTTPException:
        # Preserve status codes raised by the provider instead of re-wrapping.
        raise
    except Exception as e:
        logger.error(f"Error processing response: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
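
# Example FastAPI wiring (a sketch; it assumes a router defined elsewhere in
# this package and that ChatRequest exposes a boolean `stream` field, neither
# of which is shown in this module):
#
#     from fastapi import APIRouter
#     from fastapi.responses import StreamingResponse
#
#     router = APIRouter()
#
#     @router.post("/v1/chat/completions")
#     async def chat_completions(request: ChatRequest):
#         if request.stream:
#             return StreamingResponse(
#                 await process_response(request, stream=True),
#                 media_type="text/event-stream",
#             )
#         return await process_response(request, stream=False)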