Update api/providers.py
Browse files- api/providers.py +50 -40
api/providers.py
CHANGED
@@ -5,10 +5,11 @@ from __future__ import annotations
|
|
5 |
import json
|
6 |
import uuid
|
7 |
from aiohttp import ClientSession, ClientTimeout, ClientResponseError
|
8 |
-
|
9 |
from typing import AsyncGenerator, List, Dict, Any
|
10 |
-
|
11 |
-
from api.
|
|
|
|
|
12 |
|
13 |
class AmigoChat:
|
14 |
url = "https://amigochat.io"
|
@@ -116,36 +117,41 @@ class AmigoChat:
|
|
116 |
}
|
117 |
|
118 |
timeout = ClientTimeout(total=300)
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
line
|
127 |
-
|
128 |
-
if line
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
|
|
|
|
|
|
|
|
|
|
149 |
else:
|
150 |
# Image generation
|
151 |
prompt = messages[-1]['content']
|
@@ -154,10 +160,14 @@ class AmigoChat:
|
|
154 |
"model": model,
|
155 |
"personaId": cls.get_persona_id(model)
|
156 |
}
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
5 |
import json
|
6 |
import uuid
|
7 |
from aiohttp import ClientSession, ClientTimeout, ClientResponseError
|
|
|
8 |
from typing import AsyncGenerator, List, Dict, Any
|
9 |
+
|
10 |
+
from api.logger import setup_logger
|
11 |
+
|
12 |
+
logger = setup_logger(__name__)
|
13 |
|
14 |
class AmigoChat:
|
15 |
url = "https://amigochat.io"
|
|
|
117 |
}
|
118 |
|
119 |
timeout = ClientTimeout(total=300)
|
120 |
+
try:
|
121 |
+
async with session.post(cls.chat_api_endpoint, json=data, proxy=proxy, timeout=timeout) as response:
|
122 |
+
if response.status not in (200, 201):
|
123 |
+
error_text = await response.text()
|
124 |
+
raise Exception(f"Error {response.status}: {error_text}")
|
125 |
+
|
126 |
+
if stream:
|
127 |
+
async for line in response.content:
|
128 |
+
line = line.decode('utf-8').strip()
|
129 |
+
if line.startswith('data: '):
|
130 |
+
if line == 'data: [DONE]':
|
131 |
+
break
|
132 |
+
try:
|
133 |
+
chunk = json.loads(line[6:])
|
134 |
+
if 'choices' in chunk and len(chunk['choices']) > 0:
|
135 |
+
choice = chunk['choices'][0]
|
136 |
+
if 'delta' in choice:
|
137 |
+
content = choice['delta'].get('content')
|
138 |
+
elif 'text' in choice:
|
139 |
+
content = choice['text']
|
140 |
+
else:
|
141 |
+
content = None
|
142 |
+
if content:
|
143 |
+
yield content
|
144 |
+
except json.JSONDecodeError:
|
145 |
+
pass
|
146 |
+
else:
|
147 |
+
response_data = await response.json()
|
148 |
+
if 'choices' in response_data and len(response_data['choices']) > 0:
|
149 |
+
content = response_data['choices'][0]['message']['content']
|
150 |
+
yield content
|
151 |
+
except Exception as e:
|
152 |
+
logger.error(f"Error during request: {e}")
|
153 |
+
raise
|
154 |
+
|
155 |
else:
|
156 |
# Image generation
|
157 |
prompt = messages[-1]['content']
|
|
|
160 |
"model": model,
|
161 |
"personaId": cls.get_persona_id(model)
|
162 |
}
|
163 |
+
try:
|
164 |
+
async with session.post(cls.image_api_endpoint, json=data, proxy=proxy) as response:
|
165 |
+
response.raise_for_status()
|
166 |
+
response_data = await response.json()
|
167 |
+
if "data" in response_data:
|
168 |
+
image_urls = [item["url"] for item in response_data["data"] if "url" in item]
|
169 |
+
if image_urls:
|
170 |
+
yield json.dumps({"images": image_urls, "prompt": prompt})
|
171 |
+
except Exception as e:
|
172 |
+
logger.error(f"Error during image generation: {e}")
|
173 |
+
raise
|