Niansuh committed
Commit b27d93f · verified · 1 parent: 34226fa

Update main.py

Files changed (1):
  1. main.py (+15 −14)
main.py CHANGED
@@ -3,8 +3,8 @@ from __future__ import annotations
 import re
 import random
 import string
-import json
 import uuid
+import json
 from aiohttp import ClientSession
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
@@ -41,7 +41,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
     supports_stream = True
     supports_system_message = True
     supports_message_history = True
-
+
     default_model = 'blackbox'
     models = [
         'blackbox',
@@ -66,13 +66,13 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
         'llama-3.1-70b': {'mode': True, 'id': "llama-3.1-70b"},
         'llama-3.1-405b': {'mode': True, 'id': "llama-3.1-405b"},
     }
-
+
     userSelectedModel = {
         "gpt-4o": "gpt-4o",
         "gemini-pro": "gemini-pro",
         'claude-sonnet-3.5': "claude-sonnet-3.5",
     }
-
+
     model_aliases = {
         "gemini-flash": "gemini-1.5-flash",
         "flux": "ImageGenerationLV45LJp",
@@ -124,16 +124,16 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
 
         if model in cls.userSelectedModel:
             prefix = f"@{cls.userSelectedModel[model]}"
-            if messages and not messages[0]['content'].startswith(prefix):
+            if not messages[0]['content'].startswith(prefix):
                 messages[0]['content'] = f"{prefix} {messages[0]['content']}"
-
+
         async with ClientSession(headers=headers) as session:
-            if image is not None and messages:
+            if image is not None:
                 messages[-1]["data"] = {
                     "fileText": image_name,
                     "imageBase64": to_data_uri(image)
                 }
-
+
             random_id = ''.join(random.choices(string.ascii_letters + string.digits, k=7))
 
             data = {
@@ -147,7 +147,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
                 "userSelectedModel": None,
                 "userSystemPrompt": None,
                 "isMicMode": False,
-                "maxTokens": 4096,
+                "maxTokens": 8192,
                 "playgroundTopP": 0.9,
                 "playgroundTemperature": 0.5,
                 "isChromeExt": False,
@@ -166,7 +166,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
                 data["trendingAgentMode"] = cls.trendingAgentMode[model]
             elif model in cls.userSelectedModel:
                 data["userSelectedModel"] = cls.userSelectedModel[model]
-
+
             async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
                 response.raise_for_status()
                 if model == 'ImageGenerationLV45LJp':
@@ -180,7 +180,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
                 else:
                     async for chunk in response.content.iter_any():
                         if chunk:
-                            decoded_chunk = chunk.decode(errors='ignore')
+                            decoded_chunk = chunk.decode(errors='ignore')  # Handle decoding errors
                             decoded_chunk = re.sub(r'\$@\$v=[^$]+\$@\$', '', decoded_chunk)
                             if decoded_chunk.strip():
                                 yield decoded_chunk
@@ -195,7 +195,7 @@ class Message(BaseModel):
 class ChatRequest(BaseModel):
     model: str
     messages: List[Message]
-    stream: Optional[bool] = False
+    stream: Optional[bool] = False  # Add this for streaming
 
 def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
     return {
@@ -215,6 +215,7 @@ def create_response(content: str, model: str, finish_reason: Optional[str] = None) -> Dict[str, Any]:
 
 @app.post("/niansuhai/v1/chat/completions")
 async def chat_completions(request: ChatRequest):
+    # Validate the model
     valid_models = Blackbox.models + list(Blackbox.userSelectedModel.keys()) + list(Blackbox.model_aliases.keys())
     if request.model not in valid_models:
         raise HTTPException(status_code=400, detail=f"Invalid model name: {request.model}. Valid models are: {valid_models}")
@@ -222,7 +223,7 @@ async def chat_completions(request: ChatRequest):
     messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
 
     try:
-        async_generator = await Blackbox.create_async_generator(
+        async_generator = Blackbox.create_async_generator(
             model=request.model,
             messages=messages,
             image=None,
@@ -270,4 +271,4 @@ async def chat_completions(request: ChatRequest):
 
 @app.get("/niansuhai/v1/models")
 async def get_models():
-    return {"models": Blackbox.models}
+    return {"models": Blackbox.models}
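Note on the hunk at line 124: the new condition indexes messages[0] directly, so it assumes the request carries at least one message; the removed `messages and ...` guard previously covered the empty case. A standalone sketch of the prefix logic as it stands after this commit, with a plain dict and list standing in for the class attributes:

    userSelectedModel = {"gpt-4o": "gpt-4o"}  # stand-in for cls.userSelectedModel
    model = "gpt-4o"
    messages = [{"role": "user", "content": "Hello"}]

    prefix = f"@{userSelectedModel[model]}"
    if not messages[0]['content'].startswith(prefix):  # the new condition from this commit
        messages[0]['content'] = f"{prefix} {messages[0]['content']}"

    print(messages[0]['content'])  # -> "@gpt-4o Hello"; a second pass is a no-op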
 
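Note on the hunk at line 180: the regex strips version markers shaped like $@$v=...$@$ from each decoded chunk before it is yielded. A quick standalone check of that pattern; the marker value here is invented for illustration:

    import re

    chunk = "$@$v=1.0-sample$@$Hello from the stream"
    cleaned = re.sub(r'\$@\$v=[^$]+\$@\$', '', chunk)
    print(cleaned)  # -> "Hello from the stream"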
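Note on the hunk at line 225: dropping the await is consistent with create_async_generator being an async generator function. Calling such a function returns an async generator object synchronously, and awaiting that object raises a TypeError. A minimal sketch with a stand-in function, not the provider's actual code:

    import asyncio

    async def create_chunks(model: str):
        # `async def` plus `yield` makes this an async generator function:
        # calling it returns an async generator object, with no await needed.
        yield f"chunk from {model}"

    async def main():
        gen = create_chunks("blackbox")  # correct: no await
        async for chunk in gen:
            print(chunk)
        # `await create_chunks("blackbox")` would raise:
        # TypeError: object async_generator can't be used in 'await' expression

    asyncio.run(main())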
 
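For reference, a minimal client sketch against the chat endpoint defined in this file, assuming the app is served locally on port 8000; the host, port, and the requests dependency are assumptions, not part of this commit:

    import requests

    resp = requests.post(
        "http://localhost:8000/niansuhai/v1/chat/completions",
        json={
            "model": "blackbox",  # any model listed by /niansuhai/v1/models
            "messages": [{"role": "user", "content": "Hi"}],
            "stream": False,  # the optional flag added on ChatRequest
        },
    )
    print(resp.json())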