Spaces:
Running
Running
Commit
·
c096220
1
Parent(s):
f3fdd7d
bug fix
Browse files- app/direct_vertex_client.py +125 -25
app/direct_vertex_client.py
CHANGED
@@ -111,21 +111,115 @@ class DirectVertexClient:
|
|
111 |
print(f"ERROR: Failed to discover project ID: {e}")
|
112 |
raise
|
113 |
|
114 |
-
def
|
115 |
-
"""Convert SDK objects to
|
116 |
-
if
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
result = {}
|
119 |
for key, value in obj.__dict__.items():
|
120 |
-
if not key.startswith('_'):
|
121 |
-
result[key] = self.
|
122 |
return result
|
123 |
-
elif isinstance(obj, list):
|
124 |
-
return [self._convert_sdk_to_dict(item) for item in obj]
|
125 |
-
elif isinstance(obj, dict):
|
126 |
-
return {key: self._convert_sdk_to_dict(value) for key, value in obj.items()}
|
127 |
else:
|
128 |
-
# Return primitive types as-is
|
129 |
return obj
|
130 |
|
131 |
async def _generate_content(self, model: str, contents: Any, config: Dict[str, Any], stream: bool = False) -> Any:
|
@@ -139,21 +233,24 @@ class DirectVertexClient:
|
|
139 |
endpoint = "streamGenerateContent" if stream else "generateContent"
|
140 |
url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:{endpoint}?key={self.api_key}"
|
141 |
|
142 |
-
# Convert
|
143 |
-
# The contents might be SDK Content objects that need conversion
|
144 |
payload = {
|
145 |
-
"contents": self.
|
146 |
}
|
147 |
|
148 |
-
# Extract specific config sections
|
149 |
if "system_instruction" in config:
|
150 |
-
|
|
|
|
|
|
|
|
|
151 |
|
152 |
if "safety_settings" in config:
|
153 |
-
payload["safetySettings"] = self.
|
154 |
|
155 |
if "tools" in config:
|
156 |
-
payload["tools"] = self.
|
157 |
|
158 |
# All other config goes under generationConfig
|
159 |
generation_config = {}
|
@@ -214,21 +311,24 @@ class DirectVertexClient:
|
|
214 |
# Build URL for streaming
|
215 |
url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:streamGenerateContent?key={self.api_key}"
|
216 |
|
217 |
-
# Convert
|
218 |
-
# The contents might be SDK Content objects that need conversion
|
219 |
payload = {
|
220 |
-
"contents": self.
|
221 |
}
|
222 |
|
223 |
-
# Extract specific config sections
|
224 |
if "system_instruction" in config:
|
225 |
-
|
|
|
|
|
|
|
|
|
226 |
|
227 |
if "safety_settings" in config:
|
228 |
-
payload["safetySettings"] = self.
|
229 |
|
230 |
if "tools" in config:
|
231 |
-
payload["tools"] = self.
|
232 |
|
233 |
# All other config goes under generationConfig
|
234 |
generation_config = {}
|
|
|
111 |
print(f"ERROR: Failed to discover project ID: {e}")
|
112 |
raise
|
113 |
|
114 |
+
def _convert_contents(self, contents: Any) -> List[Dict[str, Any]]:
|
115 |
+
"""Convert SDK Content objects to REST API format"""
|
116 |
+
if isinstance(contents, list):
|
117 |
+
return [self._convert_content_item(item) for item in contents]
|
118 |
+
else:
|
119 |
+
return [self._convert_content_item(contents)]
|
120 |
+
|
121 |
+
def _convert_content_item(self, content: Any) -> Dict[str, Any]:
|
122 |
+
"""Convert a single content item to REST API format"""
|
123 |
+
if isinstance(content, dict):
|
124 |
+
return content
|
125 |
+
|
126 |
+
# Handle SDK Content objects
|
127 |
+
result = {}
|
128 |
+
if hasattr(content, 'role'):
|
129 |
+
result['role'] = content.role
|
130 |
+
if hasattr(content, 'parts'):
|
131 |
+
result['parts'] = []
|
132 |
+
for part in content.parts:
|
133 |
+
if isinstance(part, dict):
|
134 |
+
result['parts'].append(part)
|
135 |
+
elif hasattr(part, 'text'):
|
136 |
+
result['parts'].append({'text': part.text})
|
137 |
+
elif hasattr(part, 'inline_data'):
|
138 |
+
result['parts'].append({
|
139 |
+
'inline_data': {
|
140 |
+
'mime_type': part.inline_data.mime_type,
|
141 |
+
'data': part.inline_data.data
|
142 |
+
}
|
143 |
+
})
|
144 |
+
return result
|
145 |
+
|
146 |
+
def _convert_safety_settings(self, safety_settings: Any) -> List[Dict[str, str]]:
|
147 |
+
"""Convert SDK SafetySetting objects to REST API format"""
|
148 |
+
if not safety_settings:
|
149 |
+
return []
|
150 |
+
|
151 |
+
result = []
|
152 |
+
for setting in safety_settings:
|
153 |
+
if isinstance(setting, dict):
|
154 |
+
result.append(setting)
|
155 |
+
elif hasattr(setting, 'category') and hasattr(setting, 'threshold'):
|
156 |
+
# Convert SDK SafetySetting to dict
|
157 |
+
result.append({
|
158 |
+
'category': setting.category,
|
159 |
+
'threshold': setting.threshold
|
160 |
+
})
|
161 |
+
return result
|
162 |
+
|
163 |
+
def _convert_tools(self, tools: Any) -> List[Dict[str, Any]]:
|
164 |
+
"""Convert SDK Tool objects to REST API format"""
|
165 |
+
if not tools:
|
166 |
+
return []
|
167 |
+
|
168 |
+
result = []
|
169 |
+
for tool in tools:
|
170 |
+
if isinstance(tool, dict):
|
171 |
+
result.append(tool)
|
172 |
+
else:
|
173 |
+
# Convert SDK Tool object to dict
|
174 |
+
result.append(self._convert_tool_item(tool))
|
175 |
+
return result
|
176 |
+
|
177 |
+
def _convert_tool_item(self, tool: Any) -> Dict[str, Any]:
|
178 |
+
"""Convert a single tool item to REST API format"""
|
179 |
+
if isinstance(tool, dict):
|
180 |
+
return tool
|
181 |
+
|
182 |
+
tool_dict = {}
|
183 |
+
|
184 |
+
# Convert all non-private attributes
|
185 |
+
if hasattr(tool, '__dict__'):
|
186 |
+
for attr_name, attr_value in tool.__dict__.items():
|
187 |
+
if not attr_name.startswith('_'):
|
188 |
+
# Convert attribute names from snake_case to camelCase for REST API
|
189 |
+
rest_api_name = self._to_camel_case(attr_name)
|
190 |
+
|
191 |
+
# Special handling for known types
|
192 |
+
if attr_name == 'google_search' and attr_value is not None:
|
193 |
+
tool_dict[rest_api_name] = {} # GoogleSearch is empty object in REST
|
194 |
+
elif attr_name == 'function_declarations' and attr_value is not None:
|
195 |
+
tool_dict[rest_api_name] = attr_value
|
196 |
+
elif attr_value is not None:
|
197 |
+
# Recursively convert any other SDK objects
|
198 |
+
tool_dict[rest_api_name] = self._convert_sdk_object(attr_value)
|
199 |
+
|
200 |
+
return tool_dict
|
201 |
+
|
202 |
+
def _to_camel_case(self, snake_str: str) -> str:
|
203 |
+
"""Convert snake_case to camelCase"""
|
204 |
+
components = snake_str.split('_')
|
205 |
+
return components[0] + ''.join(x.title() for x in components[1:])
|
206 |
+
|
207 |
+
def _convert_sdk_object(self, obj: Any) -> Any:
|
208 |
+
"""Generic SDK object converter"""
|
209 |
+
if isinstance(obj, (str, int, float, bool, type(None))):
|
210 |
+
return obj
|
211 |
+
elif isinstance(obj, dict):
|
212 |
+
return {k: self._convert_sdk_object(v) for k, v in obj.items()}
|
213 |
+
elif isinstance(obj, list):
|
214 |
+
return [self._convert_sdk_object(item) for item in obj]
|
215 |
+
elif hasattr(obj, '__dict__'):
|
216 |
+
# Convert SDK object to dict
|
217 |
result = {}
|
218 |
for key, value in obj.__dict__.items():
|
219 |
+
if not key.startswith('_'):
|
220 |
+
result[self._to_camel_case(key)] = self._convert_sdk_object(value)
|
221 |
return result
|
|
|
|
|
|
|
|
|
222 |
else:
|
|
|
223 |
return obj
|
224 |
|
225 |
async def _generate_content(self, model: str, contents: Any, config: Dict[str, Any], stream: bool = False) -> Any:
|
|
|
233 |
endpoint = "streamGenerateContent" if stream else "generateContent"
|
234 |
url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:{endpoint}?key={self.api_key}"
|
235 |
|
236 |
+
# Convert contents to REST API format
|
|
|
237 |
payload = {
|
238 |
+
"contents": self._convert_contents(contents)
|
239 |
}
|
240 |
|
241 |
+
# Extract specific config sections
|
242 |
if "system_instruction" in config:
|
243 |
+
# System instruction should be a content object
|
244 |
+
if isinstance(config["system_instruction"], dict):
|
245 |
+
payload["systemInstruction"] = config["system_instruction"]
|
246 |
+
else:
|
247 |
+
payload["systemInstruction"] = self._convert_content_item(config["system_instruction"])
|
248 |
|
249 |
if "safety_settings" in config:
|
250 |
+
payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])
|
251 |
|
252 |
if "tools" in config:
|
253 |
+
payload["tools"] = self._convert_tools(config["tools"])
|
254 |
|
255 |
# All other config goes under generationConfig
|
256 |
generation_config = {}
|
|
|
311 |
# Build URL for streaming
|
312 |
url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:streamGenerateContent?key={self.api_key}"
|
313 |
|
314 |
+
# Convert contents to REST API format
|
|
|
315 |
payload = {
|
316 |
+
"contents": self._convert_contents(contents)
|
317 |
}
|
318 |
|
319 |
+
# Extract specific config sections
|
320 |
if "system_instruction" in config:
|
321 |
+
# System instruction should be a content object
|
322 |
+
if isinstance(config["system_instruction"], dict):
|
323 |
+
payload["systemInstruction"] = config["system_instruction"]
|
324 |
+
else:
|
325 |
+
payload["systemInstruction"] = self._convert_content_item(config["system_instruction"])
|
326 |
|
327 |
if "safety_settings" in config:
|
328 |
+
payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])
|
329 |
|
330 |
if "tools" in config:
|
331 |
+
payload["tools"] = self._convert_tools(config["tools"])
|
332 |
|
333 |
# All other config goes under generationConfig
|
334 |
generation_config = {}
|