euler314 commited on
Commit
c08b46a
Β·
verified Β·
1 Parent(s): f2ac9f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -83
app.py CHANGED
@@ -82,14 +82,12 @@ def prepare_api_params(messages, model_name):
82
  "model": model_name
83
  }
84
 
85
- # Add the appropriate token parameter based on model
86
- param_name = config["param_name"]
87
- if param_name == "max_completion_tokens":
88
- # Use the max_completion_tokens value from config
89
- api_params["max_completion_tokens"] = config["max_completion_tokens"]
90
- else:
91
- # Use the max_tokens value for other models
92
- api_params["max_tokens"] = config["max_tokens"]
93
 
94
  return api_params, config
95
 
@@ -329,8 +327,8 @@ def init_ai_models_direct():
329
 
330
  def suggest_code_completion(code_snippet, models):
331
  """Generate code completion using the AI model"""
332
- if not models or "client" not in models:
333
- st.error("AI models not properly initialized. Please use the Debug Connection section to test API connectivity.")
334
  return None
335
 
336
  try:
@@ -348,45 +346,61 @@ Here's the complete Manim code:
348
  """
349
 
350
  with st.spinner("AI is generating your animation code..."):
351
- from azure.ai.inference.models import UserMessage
352
-
353
- # Get the current model name
354
  model_name = models["model_name"]
355
 
356
- # Prepare API parameters based on model
357
- messages = [UserMessage(prompt)]
358
- api_params, config = prepare_api_params(messages, model_name)
359
 
360
- # Check if we need to specify API version
361
- if config["api_version"]:
362
- # If we need a specific API version, we need to create a new client with that version
363
- logger.info(f"Using API version {config['api_version']} for model {model_name}")
364
 
365
- # Get token from session state
366
  token = get_secret("github_token_api")
367
- if not token:
368
- st.error("GitHub token not found in secrets")
369
- return None
370
-
371
- # Import required modules for creating client with specific API version
372
- from azure.ai.inference import ChatCompletionsClient
373
- from azure.core.credentials import AzureKeyCredential
374
-
375
- # Create client with specific API version
376
- version_specific_client = ChatCompletionsClient(
377
- endpoint=models["endpoint"],
378
- credential=AzureKeyCredential(token),
379
- api_version=config["api_version"]
380
- )
381
 
382
- # Make the API call with the version-specific client
383
- response = version_specific_client.complete(**api_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  else:
385
- # Use the existing client
 
 
 
 
 
 
 
386
  response = models["client"].complete(**api_params)
387
-
388
- # Process the response
389
- completed_code = response.choices[0].message.content
390
 
391
  # Process the code
392
  if "```python" in completed_code:
@@ -2132,54 +2146,114 @@ class MyScene(Scene):
2132
  st.error("GitHub token not found in secrets")
2133
  st.stop()
2134
 
2135
- # Import required modules
2136
- import os
2137
- from azure.ai.inference import ChatCompletionsClient
2138
- from azure.ai.inference.models import SystemMessage, UserMessage
2139
- from azure.core.credentials import AzureKeyCredential
2140
-
2141
- # Define endpoint
2142
- endpoint = "https://models.inference.ai.azure.com"
2143
  model_name = st.session_state.custom_model
 
 
2144
 
2145
- # Prepare API parameters
2146
- messages = [UserMessage("Hello, this is a connection test.")]
2147
- api_params, config = prepare_api_params(messages, model_name)
2148
-
2149
- # Create client with appropriate API version
2150
- api_version = config.get("api_version")
2151
- if api_version:
2152
- client = ChatCompletionsClient(
2153
- endpoint=endpoint,
2154
- credential=AzureKeyCredential(token),
2155
- api_version=api_version
2156
- )
2157
- else:
2158
- client = ChatCompletionsClient(
2159
- endpoint=endpoint,
2160
- credential=AzureKeyCredential(token),
2161
  )
2162
-
2163
- # Test with the prepared parameters
2164
- response = client.complete(**api_params)
2165
-
2166
- # Check if response is valid
2167
- if response and response.choices and len(response.choices) > 0:
2168
- test_response = response.choices[0].message.content
2169
- st.success(f"βœ… Connection successful! Response: {test_response[:50]}...")
2170
 
2171
- # Save working connection to session state
2172
- st.session_state.ai_models = {
2173
- "client": client,
2174
- "model_name": model_name,
2175
- "endpoint": endpoint,
2176
- "last_loaded": datetime.now().isoformat()
 
2177
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2178
  else:
2179
- st.error("❌ API returned an empty response")
 
2180
  except ImportError as ie:
2181
  st.error(f"Module import error: {str(ie)}")
2182
- st.info("Try installing required packages: azure-ai-inference and azure-core")
2183
  except Exception as e:
2184
  st.error(f"❌ API test failed: {str(e)}")
2185
  import traceback
@@ -2215,12 +2289,13 @@ class MyScene(Scene):
2215
  <div class="model-card {'selected-model' if is_selected else ''}">
2216
  <h4>{model_name}</h4>
2217
  <div class="model-details">
2218
- <p>Max Tokens: {config.get('max_tokens', config.get('max_completion_tokens', 'Unknown')):,}</p>
 
2219
  <p>API Version: {config['api_version'] if config['api_version'] else 'Default'}</p>
2220
  {warning_html}
2221
  </div>
2222
  </div>
2223
- """, unsafe_allow_html=True)
2224
 
2225
  # Button to select this model
2226
  button_label = "Selected βœ“" if is_selected else "Select Model"
 
82
  "model": model_name
83
  }
84
 
85
+ # Add the appropriate token parameter based on model's parameter name
86
+ token_param = config["param_name"]
87
+ token_value = config[token_param] # Get the actual value from the config
88
+
89
+ # Add the parameter to the API params
90
+ api_params[token_param] = token_value
 
 
91
 
92
  return api_params, config
93
 
 
327
 
328
  def suggest_code_completion(code_snippet, models):
329
  """Generate code completion using the AI model"""
330
+ if not models:
331
+ st.error("AI models not properly initialized.")
332
  return None
333
 
334
  try:
 
346
  """
347
 
348
  with st.spinner("AI is generating your animation code..."):
349
+ # Get the current model name and base URL
 
 
350
  model_name = models["model_name"]
351
 
352
+ # Convert message to the appropriate format based on model category
353
+ config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
354
+ category = config.get("category", "Other")
355
 
356
+ if category == "OpenAI":
357
+ # Import OpenAI client
358
+ from openai import OpenAI
 
359
 
360
+ # Get token
361
  token = get_secret("github_token_api")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
 
363
+ # Create or get client
364
+ if "openai_client" not in models:
365
+ client = OpenAI(
366
+ base_url="https://models.github.ai/inference",
367
+ api_key=token
368
+ )
369
+ models["openai_client"] = client
370
+ else:
371
+ client = models["openai_client"]
372
+
373
+ # For OpenAI models, we need role-based messages
374
+ messages = [
375
+ {"role": "system", "content": "You are an expert in Manim animations."},
376
+ {"role": "user", "content": prompt}
377
+ ]
378
+
379
+ # Create params
380
+ params = {
381
+ "messages": messages,
382
+ "model": model_name
383
+ }
384
+
385
+ # Add token parameter
386
+ token_param = config["param_name"]
387
+ params[token_param] = config[token_param]
388
+
389
+ # Make API call
390
+ response = client.chat.completions.create(**params)
391
+ completed_code = response.choices[0].message.content
392
+
393
  else:
394
+ # Use Azure client
395
+ from azure.ai.inference.models import UserMessage
396
+
397
+ # Convert message format for Azure
398
+ messages = [UserMessage(prompt)]
399
+ api_params, _ = prepare_api_params(messages, model_name)
400
+
401
+ # Make API call with Azure client
402
  response = models["client"].complete(**api_params)
403
+ completed_code = response.choices[0].message.content
 
 
404
 
405
  # Process the code
406
  if "```python" in completed_code:
 
2146
  st.error("GitHub token not found in secrets")
2147
  st.stop()
2148
 
2149
+ # Get model details
 
 
 
 
 
 
 
2150
  model_name = st.session_state.custom_model
2151
+ config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
2152
+ category = config.get("category", "Other")
2153
 
2154
+ if category == "OpenAI":
2155
+ # Use OpenAI client for GitHub AI models
2156
+ try:
2157
+ from openai import OpenAI
2158
+ except ImportError:
2159
+ st.error("OpenAI package not installed. Please run 'pip install openai'")
2160
+ st.stop()
2161
+
2162
+ # Create OpenAI client with GitHub AI endpoint
2163
+ client = OpenAI(
2164
+ base_url="https://models.github.ai/inference",
2165
+ api_key=token,
 
 
 
 
2166
  )
 
 
 
 
 
 
 
 
2167
 
2168
+ # Prepare parameters based on model configuration
2169
+ params = {
2170
+ "messages": [
2171
+ {"role": "system", "content": "You are a helpful assistant."},
2172
+ {"role": "user", "content": "Hello, this is a connection test."}
2173
+ ],
2174
+ "model": model_name
2175
  }
2176
+
2177
+ # Add appropriate token parameter
2178
+ token_param = config["param_name"]
2179
+ params[token_param] = config[token_param]
2180
+
2181
+ # Make API call
2182
+ response = client.chat.completions.create(**params)
2183
+
2184
+ # Check if response is valid
2185
+ if response and response.choices and len(response.choices) > 0:
2186
+ test_response = response.choices[0].message.content
2187
+ st.success(f"βœ… Connection successful! Response: {test_response[:50]}...")
2188
+
2189
+ # Save working connection to session state
2190
+ st.session_state.ai_models = {
2191
+ "openai_client": client,
2192
+ "model_name": model_name,
2193
+ "endpoint": "https://models.github.ai/inference",
2194
+ "last_loaded": datetime.now().isoformat(),
2195
+ "category": category
2196
+ }
2197
+ else:
2198
+ st.error("❌ API returned an empty response")
2199
+
2200
+ elif category == "Azure" or category in ["DeepSeek", "Meta", "Microsoft", "Mistral", "Other"]:
2201
+ # Use Azure client for Azure API models
2202
+ try:
2203
+ from azure.ai.inference import ChatCompletionsClient
2204
+ from azure.ai.inference.models import SystemMessage, UserMessage
2205
+ from azure.core.credentials import AzureKeyCredential
2206
+ except ImportError:
2207
+ st.error("Azure AI packages not installed. Please run 'pip install azure-ai-inference azure-core'")
2208
+ st.stop()
2209
+
2210
+ # Define endpoint
2211
+ endpoint = "https://models.inference.ai.azure.com"
2212
+
2213
+ # Prepare API parameters
2214
+ messages = [UserMessage("Hello, this is a connection test.")]
2215
+ api_params, config = prepare_api_params(messages, model_name)
2216
+
2217
+ # Create client with appropriate API version
2218
+ api_version = config.get("api_version")
2219
+ if api_version:
2220
+ client = ChatCompletionsClient(
2221
+ endpoint=endpoint,
2222
+ credential=AzureKeyCredential(token),
2223
+ api_version=api_version
2224
+ )
2225
+ else:
2226
+ client = ChatCompletionsClient(
2227
+ endpoint=endpoint,
2228
+ credential=AzureKeyCredential(token),
2229
+ )
2230
+
2231
+ # Test with the prepared parameters
2232
+ response = client.complete(**api_params)
2233
+
2234
+ # Check if response is valid
2235
+ if response and response.choices and len(response.choices) > 0:
2236
+ test_response = response.choices[0].message.content
2237
+ st.success(f"βœ… Connection successful! Response: {test_response[:50]}...")
2238
+
2239
+ # Save working connection to session state
2240
+ st.session_state.ai_models = {
2241
+ "client": client,
2242
+ "model_name": model_name,
2243
+ "endpoint": endpoint,
2244
+ "last_loaded": datetime.now().isoformat(),
2245
+ "category": category,
2246
+ "api_version": api_version
2247
+ }
2248
+ else:
2249
+ st.error("❌ API returned an empty response")
2250
+
2251
  else:
2252
+ st.error(f"Unsupported model category: {category}")
2253
+
2254
  except ImportError as ie:
2255
  st.error(f"Module import error: {str(ie)}")
2256
+ st.info("Try installing required packages: openai, azure-ai-inference and azure-core")
2257
  except Exception as e:
2258
  st.error(f"❌ API test failed: {str(e)}")
2259
  import traceback
 
2289
  <div class="model-card {'selected-model' if is_selected else ''}">
2290
  <h4>{model_name}</h4>
2291
  <div class="model-details">
2292
+ <p>Max Tokens: {config.get(config['param_name'], 'Unknown')}</p>
2293
+ <p>Category: {config['category']}</p>
2294
  <p>API Version: {config['api_version'] if config['api_version'] else 'Default'}</p>
2295
  {warning_html}
2296
  </div>
2297
  </div>
2298
+ """, unsafe_allow_html=True)
2299
 
2300
  # Button to select this model
2301
  button_label = "Selected βœ“" if is_selected else "Select Model"