|
import asyncio
import json
import uuid
from datetime import datetime

from aiohttp import ClientSession, ClientTimeout, ClientResponseError
from fastapi import HTTPException

from api.logger import setup_logger
from api.config import MODEL_MAPPING
|
|
|
# Module-level logger, obtained from the project's logging helper and keyed by
# this module's name.
logger = setup_logger(__name__)
|
|
|
|
|
# AmigoChat web front-end; used to fill the Origin/Referer headers of upstream requests.
AMIGOCHAT_URL = "https://amigochat.io/chat/"

# Both upstream REST endpoints share the same versioned API root.
_API_BASE = "https://api.amigochat.io/v1"

# Chat completions (streaming and non-streaming).
CHAT_API_ENDPOINT = f"{_API_BASE}/chat/completions"

# Image generation.
IMAGE_API_ENDPOINT = f"{_API_BASE}/images/generations"
|
|
|
|
|
# Models routed to the chat-completions endpoint.
AMIGOCHAT_CHAT_MODELS = [
    "gpt-4o",
    "gpt-4o-mini",
    "o1-preview",
    "o1-mini",
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
    "claude-3-sonnet-20240229",
    "gemini-1.5-pro",
]

# Models routed to the image-generation endpoint.
AMIGOCHAT_IMAGE_MODELS = [
    "flux-pro/v1.1",
    "flux-realism",
    "flux-pro",
    "dalle-e-3",
]

# Every model identifier accepted as-is: chat models first, then image models.
AMIGOCHAT_MODELS = [*AMIGOCHAT_CHAT_MODELS, *AMIGOCHAT_IMAGE_MODELS]
|
|
|
# Short names accepted by get_amigochat_model(), each mapped to the full model
# identifier that is actually sent upstream.
AMIGOCHAT_MODEL_ALIASES = {
    # OpenAI
    "o1": "o1-preview",
    # Meta Llama
    "llama-3.1-405b": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
    "llama-3.2-90b": "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo",
    # Anthropic
    # NOTE(review): the alias says 3.5 but the target is the Claude 3 Sonnet
    # identifier - confirm this is intentional.
    "claude-3.5-sonnet": "claude-3-sonnet-20240229",
    # Google
    "gemini-pro": "gemini-1.5-pro",
    # Image generation
    # NOTE(review): "dalle-e-3" is spelled this way consistently in this file;
    # presumably it is the upstream identifier - verify before "correcting" it.
    "dalle-3": "dalle-e-3",
}
|
|
|
# Value sent as "personaId" in upstream request payloads, keyed by full model
# identifier. get_persona_id() falls back to "amigo" for models not listed here.
PERSONA_IDS = {
    # Chat models
    "gpt-4o": "gpt",
    "gpt-4o-mini": "amigo",
    "o1-preview": "openai-o-one",
    "o1-mini": "openai-o-one-mini",
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": "llama-three-point-one",
    "meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo": "llama-3-2",
    "claude-3-sonnet-20240229": "claude",
    "gemini-1.5-pro": "gemini-1-5-pro",
    # Image models
    "flux-pro/v1.1": "flux-1-1-pro",
    "flux-realism": "flux-realism",
    "flux-pro": "flux-pro",
    "dalle-e-3": "dalle-three",
}
|
|
|
def get_amigochat_model(model: str) -> str:
    """Resolve a requested model name to an identifier AmigoChat accepts.

    The name is first translated through the project-wide ``MODEL_MAPPING``
    (names absent from the mapping pass through unchanged). The result is
    returned as-is when it is a known AmigoChat model, expanded when it is a
    known alias, and replaced by ``'gpt-4o-mini'`` otherwise.

    Args:
        model: Model name taken from the incoming request.

    Returns:
        A full AmigoChat model identifier.
    """
    resolved = MODEL_MAPPING.get(model, model)

    # Full identifiers need no further translation.
    if resolved in AMIGOCHAT_MODELS:
        return resolved

    # Known alias -> its target; anything else -> the default chat model.
    return AMIGOCHAT_MODEL_ALIASES.get(resolved, 'gpt-4o-mini')
|
|
|
def get_persona_id(model: str) -> str:
    """Return the ``personaId`` to send upstream for *model*.

    Args:
        model: Full AmigoChat model identifier.

    Returns:
        The matching entry of ``PERSONA_IDS``, or ``"amigo"`` when the model
        has no entry.
    """
    try:
        return PERSONA_IDS[model]
    except KeyError:
        return "amigo"
|
|
|
def is_image_model(model: str) -> bool:
    """Tell whether *model* is one of the image-generation models.

    Args:
        model: Full AmigoChat model identifier.

    Returns:
        ``True`` when the identifier is listed in ``AMIGOCHAT_IMAGE_MODELS``.
    """
    image_models = AMIGOCHAT_IMAGE_MODELS
    return model in image_models
|
|
|
def _extract_delta_content(line):
    """Return the text delta carried by one upstream ``data: `` SSE line.

    Args:
        line: A decoded, stripped line that starts with ``"data: "``.

    Returns:
        The non-empty ``choices[0].delta.content`` value, or ``None`` when the
        payload is not valid JSON, has no choices, or carries no content.
    """
    try:
        chunk = json.loads(line[6:])  # drop the "data: " prefix
    except json.JSONDecodeError:
        # Best effort: a malformed line is skipped, but no longer silently.
        logger.warning(f"Skipping malformed stream chunk: {line[:200]!r}")
        return None

    choices = chunk.get('choices') if isinstance(chunk, dict) else None
    if not choices:
        return None
    delta = choices[0].get('delta') or {}
    return delta.get('content') or None


async def process_streaming_response(request_data):
    """Stream a chat completion from AmigoChat as OpenAI-style SSE chunks.

    Args:
        request_data: Mapping read with ``.get``; uses the ``'model'`` and
            ``'messages'`` keys. Each message must be subscriptable with
            ``"role"`` and ``"content"``.

    Returns:
        For chat models, an async generator yielding ``"data: {...}\\n\\n"``
        strings (``chat.completion.chunk`` objects) followed by a final
        ``"data: [DONE]\\n\\n"``. For image models, which are served through
        ``process_non_streaming_response``, a plain iterator over a single
        JSON string holding that function's result.

    Raises:
        HTTPException: 400 immediately when ``messages`` is missing or empty.
            While the generator is being consumed: the upstream status when
            the API answers with anything other than 200/201, or 500 for any
            other failure.
    """
    model = get_amigochat_model(request_data.get('model'))
    messages = request_data.get('messages')
    if not messages:
        raise HTTPException(status_code=400, detail="Messages are required")

    if is_image_model(model):
        # Image generation is handled by the non-streaming path; its result is
        # wrapped in a one-element iterator so callers can still iterate.
        response = await process_non_streaming_response(request_data)
        return iter([json.dumps(response)])

    # NOTE(review): "origin" carries a path and "referer" ends in "//" because
    # AMIGOCHAT_URL already ends with "/". Kept byte-identical since the
    # upstream's tolerance is unknown - confirm before normalising.
    headers = {
        "accept": "*/*",
        "accept-language": "en-US,en;q=0.9",
        "authorization": "Bearer",
        "cache-control": "no-cache",
        "content-type": "application/json",
        "origin": AMIGOCHAT_URL,
        "pragma": "no-cache",
        "priority": "u=1, i",
        "referer": f"{AMIGOCHAT_URL}/",
        "sec-ch-ua": '"Chromium";v="129", "Not=A?Brand";v="8"',
        "sec-ch-ua-mobile": "?0",
        "sec-ch-ua-platform": '"Linux"',
        "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36",
        "x-device-language": "en-US",
        "x-device-platform": "web",
        # A fresh random device id is generated for every call.
        "x-device-uuid": str(uuid.uuid4()),
        "x-device-version": "1.0.32"
    }

    # Built eagerly so a malformed message fails at call time, not mid-stream.
    data = {
        "messages": [{"role": m["role"], "content": m["content"]} for m in messages],
        "model": model,
        "personaId": get_persona_id(model),
        "frequency_penalty": 0,
        "max_tokens": 4000,
        "presence_penalty": 0,
        "stream": True,
        "temperature": 0.5,
        "top_p": 0.95
    }

    timeout = ClientTimeout(total=300)

    # One id / timestamp for the whole completion, shared by all of its chunks.
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    created = int(datetime.now().timestamp())

    async def event_stream():
        try:
            # The session is opened here, inside the generator, so that it is
            # still alive while the stream is consumed. Opening it in the
            # enclosing function would close it as soon as that function
            # returned the generator.
            async with ClientSession(headers=headers) as session:
                async with session.post(CHAT_API_ENDPOINT, json=data, timeout=timeout) as resp:
                    if resp.status not in (200, 201):
                        error_text = await resp.text()
                        raise HTTPException(status_code=resp.status, detail=error_text)

                    async for raw_line in resp.content:
                        line = raw_line.decode('utf-8').strip()
                        if not line.startswith('data: '):
                            continue
                        if line == 'data: [DONE]':
                            break

                        content = _extract_delta_content(line)
                        if not content:
                            continue

                        response_data = {
                            "id": completion_id,
                            "object": "chat.completion.chunk",
                            "created": created,
                            "model": model,
                            "choices": [
                                {
                                    "delta": {"content": content},
                                    "index": 0,
                                    "finish_reason": None,
                                }
                            ],
                        }
                        yield f"data: {json.dumps(response_data)}\n\n"

            yield "data: [DONE]\n\n"
        except HTTPException:
            # Keep the upstream status instead of re-wrapping it as a 500.
            raise
        except Exception as e:
            logger.error(f"Error in streaming response: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    return event_stream()
|
|
|
def _image_generation_result(model, response_data):
    """Build the image-generation reply from the upstream JSON body.

    Raises:
        HTTPException: 500 when the body contains no image URL.
    """
    items = response_data.get("data") if isinstance(response_data, dict) else None
    image_urls = [
        item["url"]
        for item in (items or [])
        if isinstance(item, dict) and "url" in item
    ]
    if not image_urls:
        raise HTTPException(status_code=500, detail="Image generation failed")

    return {
        "id": f"imggen-{uuid.uuid4()}",
        "object": "image_generation",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "data": image_urls,
    }


def _chat_completion_result(model, response_data):
    """Build the OpenAI-style ``chat.completion`` reply from the upstream JSON body."""
    # A missing or empty "choices" list yields empty output rather than an IndexError.
    choices = response_data.get('choices') or [{}]
    message = choices[0].get('message') or {}
    output = message.get('content', '')

    return {
        "id": f"chatcmpl-{uuid.uuid4()}",
        "object": "chat.completion",
        "created": int(datetime.now().timestamp()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": output},
                "finish_reason": "stop",
            }
        ],
        "usage": None,
    }


async def process_non_streaming_response(request_data):
    """Run one non-streaming AmigoChat request and return an OpenAI-style reply.

    Image models are sent to the image endpoint, using the content of the last
    message as the prompt. All other models are sent to the chat endpoint with
    the full message list.

    Args:
        request_data: Mapping read with ``.get``; uses the ``'model'`` and
            ``'messages'`` keys. Each message must be subscriptable with
            ``"role"`` and ``"content"``.

    Returns:
        For image models, a dict with ``object == "image_generation"`` whose
        ``"data"`` is the list of image URLs. For chat models, a dict with
        ``object == "chat.completion"`` holding a single assistant message.

    Raises:
        HTTPException: 400 when ``messages`` is missing or empty; the upstream
            status when the API answers with an HTTP error; 500 when image
            generation returns no URL or on any other failure.
    """
    model = get_amigochat_model(request_data.get('model'))
    messages = request_data.get('messages')
    if not messages:
        raise HTTPException(status_code=400, detail="Messages are required")

    # NOTE(review): "origin" carries a path and "referer" ends in "//" because
    # AMIGOCHAT_URL already ends with "/". Kept byte-identical since the
    # upstream's tolerance is unknown - confirm before normalising.
    headers = {
        "accept": "*/*",
        "accept-language": "en-US,en;q=0.9",
        "authorization": "Bearer",
        "cache-control": "no-cache",
        "content-type": "application/json",
        "origin": AMIGOCHAT_URL,
        "pragma": "no-cache",
        "priority": "u=1, i",
        "referer": f"{AMIGOCHAT_URL}/",
        "sec-ch-ua": '"Chromium";v="129", "Not=A?Brand";v="8"',
        "sec-ch-ua-mobile": "?0",
        "sec-ch-ua-platform": '"Linux"',
        "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36",
        "x-device-language": "en-US",
        "x-device-platform": "web",
        # A fresh random device id is generated for every call.
        "x-device-uuid": str(uuid.uuid4()),
        "x-device-version": "1.0.32"
    }

    if is_image_model(model):
        endpoint = IMAGE_API_ENDPOINT
        build_result = _image_generation_result
        failure_label = "image generation"
        data = {
            # Only the most recent message is used as the image prompt.
            "prompt": messages[-1]['content'],
            "model": model,
            "personaId": get_persona_id(model)
        }
    else:
        endpoint = CHAT_API_ENDPOINT
        build_result = _chat_completion_result
        failure_label = "chat completion"
        data = {
            "messages": [{"role": m["role"], "content": m["content"]} for m in messages],
            "model": model,
            "personaId": get_persona_id(model),
            "frequency_penalty": 0,
            "max_tokens": 4000,
            "presence_penalty": 0,
            "stream": False,
            "temperature": 0.5,
            "top_p": 0.95
        }

    try:
        async with ClientSession(headers=headers) as session:
            async with session.post(endpoint, json=data) as response:
                response.raise_for_status()
                response_data = await response.json()
        return build_result(model, response_data)
    except HTTPException:
        # Raised deliberately above (e.g. "Image generation failed"); pass it
        # through untouched instead of re-wrapping it.
        raise
    except ClientResponseError as e:
        # Propagate the upstream HTTP status, as the streaming path does.
        logger.error(f"Error in {failure_label}: {e}")
        raise HTTPException(status_code=e.status, detail=str(e)) from e
    except Exception as e:
        logger.error(f"Error in {failure_label}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|