Update api/providers.py
Browse files- api/providers.py +29 -43
api/providers.py
CHANGED
@@ -4,8 +4,8 @@ from __future__ import annotations
|
|
4 |
|
5 |
import json
|
6 |
import uuid
|
7 |
-
from aiohttp import ClientSession, ClientTimeout
|
8 |
-
from typing import AsyncGenerator, List, Dict, Any
|
9 |
|
10 |
from api.logger import setup_logger
|
11 |
|
@@ -77,7 +77,7 @@ class AmigoChat:
|
|
77 |
messages: List[Dict[str, Any]],
|
78 |
stream: bool = False,
|
79 |
proxy: str = None,
|
80 |
-
) -> AsyncGenerator[str, None]:
|
81 |
model = cls.get_model(model)
|
82 |
device_uuid = str(uuid.uuid4())
|
83 |
|
@@ -103,7 +103,6 @@ class AmigoChat:
|
|
103 |
|
104 |
async with ClientSession(headers=headers) as session:
|
105 |
if model in cls.chat_models:
|
106 |
-
# Chat completion
|
107 |
data = {
|
108 |
"messages": messages,
|
109 |
"model": model,
|
@@ -124,50 +123,37 @@ class AmigoChat:
|
|
124 |
raise Exception(f"Error {response.status}: {error_text}")
|
125 |
|
126 |
if stream:
|
127 |
-
async
|
128 |
-
line
|
129 |
-
|
130 |
-
if line
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
|
|
|
|
146 |
else:
|
147 |
response_data = await response.json()
|
148 |
if 'choices' in response_data and len(response_data['choices']) > 0:
|
149 |
content = response_data['choices'][0]['message']['content']
|
150 |
-
|
|
|
|
|
151 |
except Exception as e:
|
152 |
logger.error(f"Error during request: {e}")
|
153 |
raise
|
154 |
-
|
155 |
else:
|
156 |
-
#
|
157 |
-
|
158 |
-
data = {
|
159 |
-
"prompt": prompt,
|
160 |
-
"model": model,
|
161 |
-
"personaId": cls.get_persona_id(model)
|
162 |
-
}
|
163 |
-
try:
|
164 |
-
async with session.post(cls.image_api_endpoint, json=data, proxy=proxy) as response:
|
165 |
-
response.raise_for_status()
|
166 |
-
response_data = await response.json()
|
167 |
-
if "data" in response_data:
|
168 |
-
image_urls = [item["url"] for item in response_data["data"] if "url" in item]
|
169 |
-
if image_urls:
|
170 |
-
yield json.dumps({"images": image_urls, "prompt": prompt})
|
171 |
-
except Exception as e:
|
172 |
-
logger.error(f"Error during image generation: {e}")
|
173 |
-
raise
|
|
|
4 |
|
5 |
import json
|
6 |
import uuid
|
7 |
+
from aiohttp import ClientSession, ClientTimeout
|
8 |
+
from typing import AsyncGenerator, List, Dict, Any, Union
|
9 |
|
10 |
from api.logger import setup_logger
|
11 |
|
|
|
77 |
messages: List[Dict[str, Any]],
|
78 |
stream: bool = False,
|
79 |
proxy: str = None,
|
80 |
+
) -> Union[AsyncGenerator[str, None], str]:
|
81 |
model = cls.get_model(model)
|
82 |
device_uuid = str(uuid.uuid4())
|
83 |
|
|
|
103 |
|
104 |
async with ClientSession(headers=headers) as session:
|
105 |
if model in cls.chat_models:
|
|
|
106 |
data = {
|
107 |
"messages": messages,
|
108 |
"model": model,
|
|
|
123 |
raise Exception(f"Error {response.status}: {error_text}")
|
124 |
|
125 |
if stream:
|
126 |
+
async def stream_content():
|
127 |
+
async for line in response.content:
|
128 |
+
line = line.decode('utf-8').strip()
|
129 |
+
if line.startswith('data: '):
|
130 |
+
if line == 'data: [DONE]':
|
131 |
+
break
|
132 |
+
try:
|
133 |
+
chunk = json.loads(line[6:])
|
134 |
+
if 'choices' in chunk and len(chunk['choices']) > 0:
|
135 |
+
choice = chunk['choices'][0]
|
136 |
+
if 'delta' in choice:
|
137 |
+
content = choice['delta'].get('content')
|
138 |
+
elif 'text' in choice:
|
139 |
+
content = choice['text']
|
140 |
+
else:
|
141 |
+
content = None
|
142 |
+
if content:
|
143 |
+
yield content
|
144 |
+
except json.JSONDecodeError:
|
145 |
+
pass
|
146 |
+
return stream_content()
|
147 |
else:
|
148 |
response_data = await response.json()
|
149 |
if 'choices' in response_data and len(response_data['choices']) > 0:
|
150 |
content = response_data['choices'][0]['message']['content']
|
151 |
+
return content
|
152 |
+
else:
|
153 |
+
return ""
|
154 |
except Exception as e:
|
155 |
logger.error(f"Error during request: {e}")
|
156 |
raise
|
|
|
157 |
else:
|
158 |
+
# Handle image models or other cases if necessary
|
159 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|