Niansuh commited on
Commit
e53cb1f
·
verified ·
1 Parent(s): 6a8c132

Update api/providers.py

Browse files
Files changed (1) hide show
  1. api/providers.py +29 -43
api/providers.py CHANGED
@@ -4,8 +4,8 @@ from __future__ import annotations
4
 
5
  import json
6
  import uuid
7
- from aiohttp import ClientSession, ClientTimeout, ClientResponseError
8
- from typing import AsyncGenerator, List, Dict, Any
9
 
10
  from api.logger import setup_logger
11
 
@@ -77,7 +77,7 @@ class AmigoChat:
77
  messages: List[Dict[str, Any]],
78
  stream: bool = False,
79
  proxy: str = None,
80
- ) -> AsyncGenerator[str, None]:
81
  model = cls.get_model(model)
82
  device_uuid = str(uuid.uuid4())
83
 
@@ -103,7 +103,6 @@ class AmigoChat:
103
 
104
  async with ClientSession(headers=headers) as session:
105
  if model in cls.chat_models:
106
- # Chat completion
107
  data = {
108
  "messages": messages,
109
  "model": model,
@@ -124,50 +123,37 @@ class AmigoChat:
124
  raise Exception(f"Error {response.status}: {error_text}")
125
 
126
  if stream:
127
- async for line in response.content:
128
- line = line.decode('utf-8').strip()
129
- if line.startswith('data: '):
130
- if line == 'data: [DONE]':
131
- break
132
- try:
133
- chunk = json.loads(line[6:])
134
- if 'choices' in chunk and len(chunk['choices']) > 0:
135
- choice = chunk['choices'][0]
136
- if 'delta' in choice:
137
- content = choice['delta'].get('content')
138
- elif 'text' in choice:
139
- content = choice['text']
140
- else:
141
- content = None
142
- if content:
143
- yield content
144
- except json.JSONDecodeError:
145
- pass
 
 
146
  else:
147
  response_data = await response.json()
148
  if 'choices' in response_data and len(response_data['choices']) > 0:
149
  content = response_data['choices'][0]['message']['content']
150
- yield content
 
 
151
  except Exception as e:
152
  logger.error(f"Error during request: {e}")
153
  raise
154
-
155
  else:
156
- # Image generation
157
- prompt = messages[-1]['content']
158
- data = {
159
- "prompt": prompt,
160
- "model": model,
161
- "personaId": cls.get_persona_id(model)
162
- }
163
- try:
164
- async with session.post(cls.image_api_endpoint, json=data, proxy=proxy) as response:
165
- response.raise_for_status()
166
- response_data = await response.json()
167
- if "data" in response_data:
168
- image_urls = [item["url"] for item in response_data["data"] if "url" in item]
169
- if image_urls:
170
- yield json.dumps({"images": image_urls, "prompt": prompt})
171
- except Exception as e:
172
- logger.error(f"Error during image generation: {e}")
173
- raise
 
4
 
5
  import json
6
  import uuid
7
+ from aiohttp import ClientSession, ClientTimeout
8
+ from typing import AsyncGenerator, List, Dict, Any, Union
9
 
10
  from api.logger import setup_logger
11
 
 
77
  messages: List[Dict[str, Any]],
78
  stream: bool = False,
79
  proxy: str = None,
80
+ ) -> Union[AsyncGenerator[str, None], str]:
81
  model = cls.get_model(model)
82
  device_uuid = str(uuid.uuid4())
83
 
 
103
 
104
  async with ClientSession(headers=headers) as session:
105
  if model in cls.chat_models:
 
106
  data = {
107
  "messages": messages,
108
  "model": model,
 
123
  raise Exception(f"Error {response.status}: {error_text}")
124
 
125
  if stream:
126
+ async def stream_content():
127
+ async for line in response.content:
128
+ line = line.decode('utf-8').strip()
129
+ if line.startswith('data: '):
130
+ if line == 'data: [DONE]':
131
+ break
132
+ try:
133
+ chunk = json.loads(line[6:])
134
+ if 'choices' in chunk and len(chunk['choices']) > 0:
135
+ choice = chunk['choices'][0]
136
+ if 'delta' in choice:
137
+ content = choice['delta'].get('content')
138
+ elif 'text' in choice:
139
+ content = choice['text']
140
+ else:
141
+ content = None
142
+ if content:
143
+ yield content
144
+ except json.JSONDecodeError:
145
+ pass
146
+ return stream_content()
147
  else:
148
  response_data = await response.json()
149
  if 'choices' in response_data and len(response_data['choices']) > 0:
150
  content = response_data['choices'][0]['message']['content']
151
+ return content
152
+ else:
153
+ return ""
154
  except Exception as e:
155
  logger.error(f"Error during request: {e}")
156
  raise
 
157
  else:
158
+ # Handle image models or other cases if necessary
159
+ pass