euler314 committed on
Commit c5d0ad7 · verified · 1 Parent(s): a1879ff

Update app.py

Files changed (1)
  1. app.py +41 -68
app.py CHANGED
@@ -2141,6 +2141,12 @@ class MyScene(Scene):
 # Get model configuration
 config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
 
+# Check if this is a model that requires max_completion_tokens
+if config["param_name"] == "max_completion_tokens":
+    st.warning(f"Model {model_name} requires special handling that isn't fully supported. Testing with gpt-4o instead.")
+    model_name = "gpt-4o"
+    config = MODEL_CONFIGS["gpt-4o"]
+
 # Create client with appropriate API version
 api_version = config.get("api_version")
 if api_version:
@@ -2156,16 +2162,11 @@ class MyScene(Scene):
     )
 
 # Test with a simple prompt
-api_params = {
-    "messages": [UserMessage("Hello, this is a connection test.")],
-    "model": model_name
-}
-
-# Use appropriate parameter name
-api_params[config["param_name"]] = 100  # Just enough for a short response
-
-# Make the API call
-response = client.complete(**api_params)
+response = client.complete(
+    messages=[UserMessage("Hello, this is a connection test.")],
+    model=model_name,
+    max_tokens=50  # Small value for quick testing
+)
 
 # Check if response is valid
 if response and response.choices and len(response.choices) > 0:
@@ -2210,14 +2211,18 @@ class MyScene(Scene):
 for model_name in sorted(model_categories[category]):
     config = MODEL_CONFIGS[model_name]
     is_selected = model_name == st.session_state.custom_model
+    warning = config.get("warning")
 
     # Create styled card for each model
+    warning_html = f'<p style="color: #ff9800; font-size: 0.8rem; margin-top: 5px;">⚠️ {warning}</p>' if warning else ""
+
     st.markdown(f"""
     <div class="model-card {'selected-model' if is_selected else ''}">
        <h4>{model_name}</h4>
        <div class="model-details">
            <p>Max Tokens: {config['max_tokens']:,}</p>
            <p>API Version: {config['api_version'] if config['api_version'] else 'Default'}</p>
+           {warning_html}
        </div>
    </div>
    """, unsafe_allow_html=True)
@@ -2242,28 +2247,19 @@ class MyScene(Scene):
 model_name = st.session_state.custom_model
 config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
 
-# Use appropriate parameters based on model configuration
-api_params = {
-    "messages": [UserMessage("Hello")],
-    "model": model_name
-}
-api_params[config["param_name"]] = 10  # Just request 10 tokens for quick test
+# Check if this is a model that requires max_completion_tokens
+if config["param_name"] == "max_completion_tokens":
+    st.warning(f"Model {model_name} requires special handling that isn't fully supported. Using gpt-4o instead.")
+    model_name = "gpt-4o"
+    config = MODEL_CONFIGS["gpt-4o"]
 
-if config["api_version"]:
-    # Create version-specific client if needed
-    token = get_secret("github_token_api")
-    from azure.ai.inference import ChatCompletionsClient
-    from azure.core.credentials import AzureKeyCredential
-
-    client = ChatCompletionsClient(
-        endpoint=st.session_state.ai_models["endpoint"],
-        credential=AzureKeyCredential(token),
-        api_version=config["api_version"]
-    )
-    response = client.complete(**api_params)
-else:
-    response = st.session_state.ai_models["client"].complete(**api_params)
-
+# Use appropriate parameters based on model configuration
+response = st.session_state.ai_models["client"].complete(
+    messages=[UserMessage("Hello")],
+    model=model_name,
+    max_tokens=10  # Just request 10 tokens for quick test
+)
+
 st.success(f"✅ Connection to {model_name} successful!")
 st.session_state.ai_models["model_name"] = model_name
 
@@ -2305,13 +2301,19 @@ class MyScene(Scene):
 if code_input:
     with st.spinner("AI is generating your animation code..."):
         try:
-            # Direct implementation of code generation
+            # Get the client and model name
             client = st.session_state.ai_models["client"]
             model_name = st.session_state.ai_models["model_name"]
 
             # Get configuration for this model
             config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
 
+            # Check if this model requires max_completion_tokens
+            if config["param_name"] == "max_completion_tokens":
+                st.warning(f"Model {model_name} requires special handling that isn't fully supported. Using gpt-4o instead.")
+                model_name = "gpt-4o"
+                config = MODEL_CONFIGS["gpt-4o"]
+
             # Create the prompt
             prompt = f"""Write a complete Manim animation scene based on this code or idea:
 {code_input}
@@ -2325,42 +2327,13 @@ class MyScene(Scene):
 Here's the complete Manim code:
 """
 
-            # Prepare API call parameters based on model requirements
-            api_params = {
-                "messages": [UserMessage(prompt)],
-                "model": model_name
-            }
-
-            # Add the appropriate token parameter
-            api_params[config["param_name"]] = config["max_tokens"]
-
-            # Check if we need to specify API version
-            if config["api_version"]:
-                # If we need a specific API version, create a new client with that version
-                logger.info(f"Using API version {config['api_version']} for model {model_name}")
-
-                # Get token from session state
-                token = get_secret("github_token_api")
-                if not token:
-                    st.error("GitHub token not found in secrets")
-                    return None
-
-                # Import required modules for creating client with specific API version
-                from azure.ai.inference import ChatCompletionsClient
-                from azure.core.credentials import AzureKeyCredential
-
-                # Create client with specific API version
-                version_specific_client = ChatCompletionsClient(
-                    endpoint=st.session_state.ai_models["endpoint"],
-                    credential=AzureKeyCredential(token),
-                    api_version=config["api_version"]
-                )
-
-                # Make the API call with the version-specific client
-                response = version_specific_client.complete(**api_params)
-            else:
-                # Use the existing client
-                response = client.complete(**api_params)
+            # Make the API call with proper parameters
+            from azure.ai.inference.models import UserMessage
+            response = client.complete(
+                messages=[UserMessage(prompt)],
+                model=model_name,
+                max_tokens=config["max_tokens"]
+            )
 
             # Process the response
             if response and response.choices and len(response.choices) > 0:
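For context, a minimal sketch of the shape MODEL_CONFIGS is assumed to have, inferred only from the keys the diff reads (param_name, max_tokens, api_version, warning) and the models it names. The entries, model names, and values below are illustrative guesses, not the actual definitions in app.py:

# Illustrative sketch only: field names follow the keys used in the diff above;
# model names, token limits, and API versions are assumed, not copied from app.py.
MODEL_CONFIGS = {
    "default": {
        "param_name": "max_tokens",        # token-limit parameter the model accepts
        "max_tokens": 4096,                # assumed request cap
        "api_version": None,               # None -> use the client's default API version
        "warning": None,                   # optional note shown on the model card
    },
    "gpt-4o": {
        "param_name": "max_tokens",
        "max_tokens": 16384,               # assumed value
        "api_version": None,
        "warning": None,
    },
    # Hypothetical entry for a model that only accepts max_completion_tokens;
    # the updated code falls back to gpt-4o whenever it sees this param_name.
    "example-reasoning-model": {
        "param_name": "max_completion_tokens",
        "max_tokens": 8192,
        "api_version": "2024-12-01-preview",  # assumed version string
        "warning": "Requires max_completion_tokens; this build falls back to gpt-4o.",
    },
}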