bibibi12345 commited on
Commit
c096220
·
1 Parent(s): f3fdd7d
Files changed (1) hide show
  1. app/direct_vertex_client.py +125 -25
app/direct_vertex_client.py CHANGED
@@ -111,21 +111,115 @@ class DirectVertexClient:
111
  print(f"ERROR: Failed to discover project ID: {e}")
112
  raise
113
 
114
- def _convert_sdk_to_dict(self, obj: Any) -> Any:
115
- """Convert SDK objects to dictionaries for JSON serialization"""
116
- if hasattr(obj, '__dict__'):
117
- # Handle SDK objects with __dict__
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  result = {}
119
  for key, value in obj.__dict__.items():
120
- if not key.startswith('_'): # Skip private attributes
121
- result[key] = self._convert_sdk_to_dict(value)
122
  return result
123
- elif isinstance(obj, list):
124
- return [self._convert_sdk_to_dict(item) for item in obj]
125
- elif isinstance(obj, dict):
126
- return {key: self._convert_sdk_to_dict(value) for key, value in obj.items()}
127
  else:
128
- # Return primitive types as-is
129
  return obj
130
 
131
  async def _generate_content(self, model: str, contents: Any, config: Dict[str, Any], stream: bool = False) -> Any:
@@ -139,21 +233,24 @@ class DirectVertexClient:
139
  endpoint = "streamGenerateContent" if stream else "generateContent"
140
  url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:{endpoint}?key={self.api_key}"
141
 
142
- # Convert SDK objects to dictionaries for JSON serialization
143
- # The contents might be SDK Content objects that need conversion
144
  payload = {
145
- "contents": self._convert_sdk_to_dict(contents)
146
  }
147
 
148
- # Extract specific config sections and convert SDK objects
149
  if "system_instruction" in config:
150
- payload["systemInstruction"] = self._convert_sdk_to_dict(config["system_instruction"])
 
 
 
 
151
 
152
  if "safety_settings" in config:
153
- payload["safetySettings"] = self._convert_sdk_to_dict(config["safety_settings"])
154
 
155
  if "tools" in config:
156
- payload["tools"] = self._convert_sdk_to_dict(config["tools"])
157
 
158
  # All other config goes under generationConfig
159
  generation_config = {}
@@ -214,21 +311,24 @@ class DirectVertexClient:
214
  # Build URL for streaming
215
  url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:streamGenerateContent?key={self.api_key}"
216
 
217
- # Convert SDK objects to dictionaries for JSON serialization
218
- # The contents might be SDK Content objects that need conversion
219
  payload = {
220
- "contents": self._convert_sdk_to_dict(contents)
221
  }
222
 
223
- # Extract specific config sections and convert SDK objects
224
  if "system_instruction" in config:
225
- payload["systemInstruction"] = self._convert_sdk_to_dict(config["system_instruction"])
 
 
 
 
226
 
227
  if "safety_settings" in config:
228
- payload["safetySettings"] = self._convert_sdk_to_dict(config["safety_settings"])
229
 
230
  if "tools" in config:
231
- payload["tools"] = self._convert_sdk_to_dict(config["tools"])
232
 
233
  # All other config goes under generationConfig
234
  generation_config = {}
 
111
  print(f"ERROR: Failed to discover project ID: {e}")
112
  raise
113
 
114
+ def _convert_contents(self, contents: Any) -> List[Dict[str, Any]]:
115
+ """Convert SDK Content objects to REST API format"""
116
+ if isinstance(contents, list):
117
+ return [self._convert_content_item(item) for item in contents]
118
+ else:
119
+ return [self._convert_content_item(contents)]
120
+
121
+ def _convert_content_item(self, content: Any) -> Dict[str, Any]:
122
+ """Convert a single content item to REST API format"""
123
+ if isinstance(content, dict):
124
+ return content
125
+
126
+ # Handle SDK Content objects
127
+ result = {}
128
+ if hasattr(content, 'role'):
129
+ result['role'] = content.role
130
+ if hasattr(content, 'parts'):
131
+ result['parts'] = []
132
+ for part in content.parts:
133
+ if isinstance(part, dict):
134
+ result['parts'].append(part)
135
+ elif hasattr(part, 'text'):
136
+ result['parts'].append({'text': part.text})
137
+ elif hasattr(part, 'inline_data'):
138
+ result['parts'].append({
139
+ 'inline_data': {
140
+ 'mime_type': part.inline_data.mime_type,
141
+ 'data': part.inline_data.data
142
+ }
143
+ })
144
+ return result
145
+
146
+ def _convert_safety_settings(self, safety_settings: Any) -> List[Dict[str, str]]:
147
+ """Convert SDK SafetySetting objects to REST API format"""
148
+ if not safety_settings:
149
+ return []
150
+
151
+ result = []
152
+ for setting in safety_settings:
153
+ if isinstance(setting, dict):
154
+ result.append(setting)
155
+ elif hasattr(setting, 'category') and hasattr(setting, 'threshold'):
156
+ # Convert SDK SafetySetting to dict
157
+ result.append({
158
+ 'category': setting.category,
159
+ 'threshold': setting.threshold
160
+ })
161
+ return result
162
+
163
+ def _convert_tools(self, tools: Any) -> List[Dict[str, Any]]:
164
+ """Convert SDK Tool objects to REST API format"""
165
+ if not tools:
166
+ return []
167
+
168
+ result = []
169
+ for tool in tools:
170
+ if isinstance(tool, dict):
171
+ result.append(tool)
172
+ else:
173
+ # Convert SDK Tool object to dict
174
+ result.append(self._convert_tool_item(tool))
175
+ return result
176
+
177
+ def _convert_tool_item(self, tool: Any) -> Dict[str, Any]:
178
+ """Convert a single tool item to REST API format"""
179
+ if isinstance(tool, dict):
180
+ return tool
181
+
182
+ tool_dict = {}
183
+
184
+ # Convert all non-private attributes
185
+ if hasattr(tool, '__dict__'):
186
+ for attr_name, attr_value in tool.__dict__.items():
187
+ if not attr_name.startswith('_'):
188
+ # Convert attribute names from snake_case to camelCase for REST API
189
+ rest_api_name = self._to_camel_case(attr_name)
190
+
191
+ # Special handling for known types
192
+ if attr_name == 'google_search' and attr_value is not None:
193
+ tool_dict[rest_api_name] = {} # GoogleSearch is empty object in REST
194
+ elif attr_name == 'function_declarations' and attr_value is not None:
195
+ tool_dict[rest_api_name] = attr_value
196
+ elif attr_value is not None:
197
+ # Recursively convert any other SDK objects
198
+ tool_dict[rest_api_name] = self._convert_sdk_object(attr_value)
199
+
200
+ return tool_dict
201
+
202
+ def _to_camel_case(self, snake_str: str) -> str:
203
+ """Convert snake_case to camelCase"""
204
+ components = snake_str.split('_')
205
+ return components[0] + ''.join(x.title() for x in components[1:])
206
+
207
+ def _convert_sdk_object(self, obj: Any) -> Any:
208
+ """Generic SDK object converter"""
209
+ if isinstance(obj, (str, int, float, bool, type(None))):
210
+ return obj
211
+ elif isinstance(obj, dict):
212
+ return {k: self._convert_sdk_object(v) for k, v in obj.items()}
213
+ elif isinstance(obj, list):
214
+ return [self._convert_sdk_object(item) for item in obj]
215
+ elif hasattr(obj, '__dict__'):
216
+ # Convert SDK object to dict
217
  result = {}
218
  for key, value in obj.__dict__.items():
219
+ if not key.startswith('_'):
220
+ result[self._to_camel_case(key)] = self._convert_sdk_object(value)
221
  return result
 
 
 
 
222
  else:
 
223
  return obj
224
 
225
  async def _generate_content(self, model: str, contents: Any, config: Dict[str, Any], stream: bool = False) -> Any:
 
233
  endpoint = "streamGenerateContent" if stream else "generateContent"
234
  url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:{endpoint}?key={self.api_key}"
235
 
236
+ # Convert contents to REST API format
 
237
  payload = {
238
+ "contents": self._convert_contents(contents)
239
  }
240
 
241
+ # Extract specific config sections
242
  if "system_instruction" in config:
243
+ # System instruction should be a content object
244
+ if isinstance(config["system_instruction"], dict):
245
+ payload["systemInstruction"] = config["system_instruction"]
246
+ else:
247
+ payload["systemInstruction"] = self._convert_content_item(config["system_instruction"])
248
 
249
  if "safety_settings" in config:
250
+ payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])
251
 
252
  if "tools" in config:
253
+ payload["tools"] = self._convert_tools(config["tools"])
254
 
255
  # All other config goes under generationConfig
256
  generation_config = {}
 
311
  # Build URL for streaming
312
  url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:streamGenerateContent?key={self.api_key}"
313
 
314
+ # Convert contents to REST API format
 
315
  payload = {
316
+ "contents": self._convert_contents(contents)
317
  }
318
 
319
+ # Extract specific config sections
320
  if "system_instruction" in config:
321
+ # System instruction should be a content object
322
+ if isinstance(config["system_instruction"], dict):
323
+ payload["systemInstruction"] = config["system_instruction"]
324
+ else:
325
+ payload["systemInstruction"] = self._convert_content_item(config["system_instruction"])
326
 
327
  if "safety_settings" in config:
328
+ payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])
329
 
330
  if "tools" in config:
331
+ payload["tools"] = self._convert_tools(config["tools"])
332
 
333
  # All other config goes under generationConfig
334
  generation_config = {}