Niansuh commited on
Commit
261bb88
·
verified ·
1 Parent(s): e4ee51f

Create utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +179 -0
api/utils.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import json
3
+ from typing import AsyncGenerator, Union
4
+ import uuid
5
+
6
+ import aiohttp
7
+ from fastapi import HTTPException
8
+ from fastapi.responses import JSONResponse
9
+ from api.config import GIZAI_API_ENDPOINT, GIZAI_BASE_URL
10
+ from api.models import ChatRequest, ImageResponseModel, ChatCompletionResponse
11
+ from api.logger import setup_logger
12
+
13
+ logger = setup_logger(__name__)
14
+
15
+ class GizAI:
16
+ # Chat models
17
+ default_model = 'chat-gemini-flash'
18
+ chat_models = [
19
+ default_model,
20
+ 'chat-gemini-pro',
21
+ 'chat-gpt4m',
22
+ 'chat-gpt4',
23
+ 'claude-sonnet',
24
+ 'claude-haiku',
25
+ 'llama-3-70b',
26
+ 'llama-3-8b',
27
+ 'mistral-large',
28
+ 'chat-o1-mini'
29
+ ]
30
+
31
+ # Image models
32
+ image_models = [
33
+ 'flux1',
34
+ 'sdxl',
35
+ 'sd',
36
+ 'sd35',
37
+ ]
38
+
39
+ models = [*chat_models, *image_models]
40
+
41
+ model_aliases = {
42
+ # Chat model aliases
43
+ "gemini-flash": "chat-gemini-flash",
44
+ "gemini-pro": "chat-gemini-pro",
45
+ "gpt-4o-mini": "chat-gpt4m",
46
+ "gpt-4o": "chat-gpt4",
47
+ "claude-3.5-sonnet": "claude-sonnet",
48
+ "claude-3-haiku": "claude-haiku",
49
+ "llama-3.1-70b": "llama-3-70b",
50
+ "llama-3.1-8b": "llama-3-8b",
51
+ "o1-mini": "chat-o1-mini",
52
+ # Image model aliases
53
+ "sd-1.5": "sd",
54
+ "sd-3.5": "sd35",
55
+ "flux-schnell": "flux1",
56
+ }
57
+
58
+ @classmethod
59
+ def get_model(cls, model: str) -> str:
60
+ if model in cls.models:
61
+ return model
62
+ elif model in cls.model_aliases:
63
+ return cls.model_aliases[model]
64
+ else:
65
+ return cls.default_model
66
+
67
+ @classmethod
68
+ def is_image_model(cls, model: str) -> bool:
69
+ return model in cls.image_models
70
+
71
+ async def process_gizai_response(request: ChatRequest, model: str) -> Union[AsyncGenerator[str, None], JSONResponse]:
72
+ async with aiohttp.ClientSession() as session:
73
+ if GizAI.is_image_model(model):
74
+ # Image generation
75
+ prompt = request.messages[-1].content if isinstance(request.messages[-1].content, str) else request.messages[-1].content[0].get("text", "")
76
+ data = {
77
+ "model": model,
78
+ "input": {
79
+ "width": "1024",
80
+ "height": "1024",
81
+ "steps": 4,
82
+ "output_format": "webp",
83
+ "batch_size": 1,
84
+ "mode": "plan",
85
+ "prompt": prompt
86
+ }
87
+ }
88
+ try:
89
+ async with session.post(
90
+ GIZAI_API_ENDPOINT,
91
+ headers={
92
+ 'Accept': 'application/json, text/plain, */*',
93
+ 'Accept-Language': 'en-US,en;q=0.9',
94
+ 'Cache-Control': 'no-cache',
95
+ 'Connection': 'keep-alive',
96
+ 'Content-Type': 'application/json',
97
+ 'Origin': 'https://app.giz.ai',
98
+ 'Pragma': 'no-cache',
99
+ 'Sec-Fetch-Dest': 'empty',
100
+ 'Sec-Fetch-Mode': 'cors',
101
+ 'Sec-Fetch-Site': 'same-origin',
102
+ 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
103
+ 'sec-ch-ua': '"Not?A_Brand";v="99", "Chromium";v="130"',
104
+ 'sec-ch-ua-mobile': '?0',
105
+ 'sec-ch-ua-platform': '"Linux"'
106
+ },
107
+ json=data
108
+ ) as response:
109
+ response.raise_for_status()
110
+ response_data = await response.json()
111
+ if response_data.get('status') == 'completed' and response_data.get('output'):
112
+ images = response_data['output']
113
+ return {"images": images, "alt": "Generated Image"}
114
+ else:
115
+ raise HTTPException(status_code=500, detail="Image generation failed.")
116
+ except aiohttp.ClientResponseError as e:
117
+ logger.error(f"HTTP error occurred: {e.status} - {e.message}")
118
+ raise HTTPException(status_code=e.status, detail=str(e))
119
+ except Exception as e:
120
+ logger.error(f"Unexpected error: {str(e)}")
121
+ raise HTTPException(status_code=500, detail=str(e))
122
+ else:
123
+ # Chat completion
124
+ messages_formatted = [
125
+ {
126
+ "type": "human",
127
+ "content": msg.content if isinstance(msg.content, str) else msg.content[0].get("text", "")
128
+ } for msg in request.messages
129
+ ]
130
+ data = {
131
+ "model": model,
132
+ "input": {
133
+ "messages": messages_formatted,
134
+ "mode": "plan"
135
+ },
136
+ "noStream": not request.stream
137
+ }
138
+ try:
139
+ async with session.post(
140
+ GIZAI_API_ENDPOINT,
141
+ headers={
142
+ 'Accept': 'application/json, text/plain, */*',
143
+ 'Accept-Language': 'en-US,en;q=0.9',
144
+ 'Cache-Control': 'no-cache',
145
+ 'Connection': 'keep-alive',
146
+ 'Content-Type': 'application/json',
147
+ 'Origin': 'https://app.giz.ai',
148
+ 'Pragma': 'no-cache',
149
+ 'Sec-Fetch-Dest': 'empty',
150
+ 'Sec-Fetch-Mode': 'cors',
151
+ 'Sec-Fetch-Site': 'same-origin',
152
+ 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
153
+ 'sec-ch-ua': '"Not?A_Brand";v="99", "Chromium";v="130"',
154
+ 'sec-ch-ua-mobile': '?0',
155
+ 'sec-ch-ua-platform': '"Linux"'
156
+ },
157
+ json=data
158
+ ) as response:
159
+ response.raise_for_status()
160
+ if request.stream:
161
+ # Handle streaming response
162
+ async def stream_response():
163
+ async for line in response.content:
164
+ if line:
165
+ decoded_line = line.decode('utf-8').strip()
166
+ if decoded_line.startswith("data:"):
167
+ content = decoded_line.replace("data: ", "")
168
+ yield f"data: {content}\n\n"
169
+ return stream_response()
170
+ else:
171
+ # Handle non-streaming response
172
+ result = await response.json()
173
+ return result.get('output', '')
174
+ except aiohttp.ClientResponseError as e:
175
+ logger.error(f"HTTP error occurred: {e.status} - {e.message}")
176
+ raise HTTPException(status_code=e.status, detail=str(e))
177
+ except Exception as e:
178
+ logger.error(f"Unexpected error: {str(e)}")
179
+ raise HTTPException(status_code=500, detail=str(e))