euler314 commited on
Commit
cd66e4d
·
verified ·
1 Parent(s): 0242199

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -80
app.py CHANGED
@@ -71,6 +71,25 @@ except ImportError:
71
  ACE_EDITOR_AVAILABLE = False
72
  logger.warning("streamlit-ace not available, falling back to standard text editor")
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # New functions for accessing secrets and password verification
75
  def get_secret(github_token_api):
76
  """Retrieve a secret from HuggingFace Spaces environment variables"""
@@ -331,12 +350,9 @@ Here's the complete Manim code:
331
  # Get the current model name
332
  model_name = models["model_name"]
333
 
334
- # Get configuration for this model (or use default)
335
- config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
336
-
337
- # Check if this is a model that requires max_completion_tokens
338
- # Prepare common API parameters
339
  messages = [UserMessage(prompt)]
 
340
 
341
  # Check if we need to specify API version
342
  if config["api_version"]:
@@ -360,21 +376,11 @@ Here's the complete Manim code:
360
  api_version=config["api_version"]
361
  )
362
 
363
- # Make the API call with the version-specific client and normal max_tokens
364
- response = version_specific_client.complete(
365
- messages=messages,
366
- model=model_name,
367
- max_tokens=config["max_tokens"],
368
- max_completion_tokens=config["max_completion_tokens"]
369
- )
370
  else:
371
- # Use the existing client with normal max_tokens
372
- response = models["client"].complete(
373
- messages=messages,
374
- model=model_name,
375
- max_tokens=config["max_tokens"],
376
- max_completion_tokens=config["max_completion_tokens"]
377
- )
378
 
379
  # Process the response
380
  completed_code = response.choices[0].message.content
@@ -2133,14 +2139,9 @@ class MyScene(Scene):
2133
  endpoint = "https://models.inference.ai.azure.com"
2134
  model_name = st.session_state.custom_model
2135
 
2136
- # Get model configuration
2137
- config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
2138
-
2139
- # Check if this is a model that requires max_completion_tokens
2140
- if config["param_name"] == "max_com343434pletion_tokens":
2141
- st.warning(f"Model {model_name} requires special handling that isn't fully supported. Testing with gpt-4o instead.")
2142
- model_name = "gpt-4o"
2143
- config = MODEL_CONFIGS["gpt-4o"]
2144
 
2145
  # Create client with appropriate API version
2146
  api_version = config.get("api_version")
@@ -2156,13 +2157,8 @@ class MyScene(Scene):
2156
  credential=AzureKeyCredential(token),
2157
  )
2158
 
2159
- # Test with a simple prompt
2160
- response = client.complete(
2161
- messages=[UserMessage("Hello, this is a connection test.")],
2162
- model=model_name,
2163
- max_tokens=50 , # Small value for quick testing
2164
- max_completion_tokens=50
2165
- )
2166
 
2167
  # Check if response is valid
2168
  if response and response.choices and len(response.choices) > 0:
@@ -2185,7 +2181,7 @@ class MyScene(Scene):
2185
  st.error(f"❌ API test failed: {str(e)}")
2186
  import traceback
2187
  st.code(traceback.format_exc())
2188
-
2189
  # Model selection with enhanced UI
2190
  st.markdown("### 🤖 Model Selection")
2191
  st.markdown("Select an AI model for generating animation code:")
@@ -2216,7 +2212,7 @@ class MyScene(Scene):
2216
  <div class="model-card {'selected-model' if is_selected else ''}">
2217
  <h4>{model_name}</h4>
2218
  <div class="model-details">
2219
- <p>Max Tokens: {config['max_tokens']:,}</p>
2220
  <p>API Version: {config['api_version'] if config['api_version'] else 'Default'}</p>
2221
  {warning_html}
2222
  </div>
@@ -2241,21 +2237,31 @@ class MyScene(Scene):
2241
  # Test connection with minimal prompt
2242
  from azure.ai.inference.models import UserMessage
2243
  model_name = st.session_state.custom_model
2244
- config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
2245
 
2246
- # Check if this is a model that requires max_completion_tokens
2247
- if config["param_name"] == "max_completion_tokens":
2248
- st.warning(f"Model {model_name} requires special handling that isn't fully supported. Using gpt-4o instead.")
2249
- model_name = "gpt-4o"
2250
- config = MODEL_CONFIGS["gpt-4o"]
2251
 
2252
- # Use appropriate parameters based on model configuration
2253
- response = st.session_state.ai_models["client"].complete(
2254
- messages=[UserMessage("Hello")],
2255
- model=model_name,
2256
- max_tokens=10 # Just request 10 tokens for quick test
2257
- )
2258
-
 
 
 
 
 
 
 
 
 
 
 
 
 
2259
  st.success(f"✅ Connection to {model_name} successful!")
2260
  st.session_state.ai_models["model_name"] = model_name
2261
 
@@ -2301,35 +2307,26 @@ class MyScene(Scene):
2301
  client = st.session_state.ai_models["client"]
2302
  model_name = st.session_state.ai_models["model_name"]
2303
 
2304
- # Get configuration for this model
2305
- config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
2306
-
2307
- # Check if this model requires max_completion_tokens
2308
- if config["param_name"] == "max_completion_tokens":
2309
- st.warning(f"Model {model_name} requires special handling that isn't fully supported. Using gpt-4o instead.")
2310
- model_name = "gpt-4o"
2311
- config = MODEL_CONFIGS["gpt-4o"]
2312
-
2313
  # Create the prompt
2314
  prompt = f"""Write a complete Manim animation scene based on this code or idea:
2315
- {code_input}
2316
-
2317
- The code should be a complete, working Manim animation that includes:
2318
- - Proper Scene class definition
2319
- - Constructor with animations
2320
- - Proper use of self.play() for animations
2321
- - Proper wait times between animations
2322
-
2323
- Here's the complete Manim code:
2324
- """
2325
 
2326
- # Make the API call with proper parameters
2327
  from azure.ai.inference.models import UserMessage
2328
- response = client.complete(
2329
- messages=[UserMessage(prompt)],
2330
- model=model_name,
2331
- max_tokens=config["max_tokens"]
2332
- )
2333
 
2334
  # Process the response
2335
  if response and response.choices and len(response.choices) > 0:
@@ -2344,10 +2341,10 @@ Here's the complete Manim code:
2344
  # Add Scene class if missing
2345
  if "Scene" not in completed_code:
2346
  completed_code = f"""from manim import *
2347
-
2348
- class MyScene(Scene):
2349
- def construct(self):
2350
- {completed_code}"""
2351
 
2352
  # Store the generated code
2353
  st.session_state.generated_code = completed_code
@@ -2359,8 +2356,7 @@ class MyScene(Scene):
2359
  st.code(traceback.format_exc())
2360
  else:
2361
  st.warning("Please enter a description or prompt first")
2362
-
2363
- st.markdown("</div>", unsafe_allow_html=True)
2364
 
2365
  # AI generated code display and actions
2366
  if "generated_code" in st.session_state and st.session_state.generated_code:
 
71
  ACE_EDITOR_AVAILABLE = False
72
  logger.warning("streamlit-ace not available, falling back to standard text editor")
73
 
74
+ def prepare_api_params(messages, model_name):
75
+ """Create appropriate API parameters based on model configuration"""
76
+ # Get model configuration
77
+ config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
78
+
79
+ # Base parameters common to all models
80
+ api_params = {
81
+ "messages": messages,
82
+ "model": model_name
83
+ }
84
+
85
+ # Add the appropriate token parameter based on model
86
+ if config["param_name"] == "max_completion_tokens":
87
+ api_params["max_completion_tokens"] = config["max_tokens"]
88
+ else:
89
+ api_params["max_tokens"] = config["max_tokens"]
90
+
91
+ return api_params, config
92
+
93
  # New functions for accessing secrets and password verification
94
  def get_secret(github_token_api):
95
  """Retrieve a secret from HuggingFace Spaces environment variables"""
 
350
  # Get the current model name
351
  model_name = models["model_name"]
352
 
353
+ # Prepare API parameters based on model
 
 
 
 
354
  messages = [UserMessage(prompt)]
355
+ api_params, config = prepare_api_params(messages, model_name)
356
 
357
  # Check if we need to specify API version
358
  if config["api_version"]:
 
376
  api_version=config["api_version"]
377
  )
378
 
379
+ # Make the API call with the version-specific client
380
+ response = version_specific_client.complete(**api_params)
 
 
 
 
 
381
  else:
382
+ # Use the existing client
383
+ response = models["client"].complete(**api_params)
 
 
 
 
 
384
 
385
  # Process the response
386
  completed_code = response.choices[0].message.content
 
2139
  endpoint = "https://models.inference.ai.azure.com"
2140
  model_name = st.session_state.custom_model
2141
 
2142
+ # Prepare API parameters
2143
+ messages = [UserMessage("Hello, this is a connection test.")]
2144
+ api_params, config = prepare_api_params(messages, model_name)
 
 
 
 
 
2145
 
2146
  # Create client with appropriate API version
2147
  api_version = config.get("api_version")
 
2157
  credential=AzureKeyCredential(token),
2158
  )
2159
 
2160
+ # Test with the prepared parameters
2161
+ response = client.complete(**api_params)
 
 
 
 
 
2162
 
2163
  # Check if response is valid
2164
  if response and response.choices and len(response.choices) > 0:
 
2181
  st.error(f"❌ API test failed: {str(e)}")
2182
  import traceback
2183
  st.code(traceback.format_exc())
2184
+
2185
  # Model selection with enhanced UI
2186
  st.markdown("### 🤖 Model Selection")
2187
  st.markdown("Select an AI model for generating animation code:")
 
2212
  <div class="model-card {'selected-model' if is_selected else ''}">
2213
  <h4>{model_name}</h4>
2214
  <div class="model-details">
2215
+ <p>Max Tokens: {config['max_tokens']:,config[max]}</p>
2216
  <p>API Version: {config['api_version'] if config['api_version'] else 'Default'}</p>
2217
  {warning_html}
2218
  </div>
 
2237
  # Test connection with minimal prompt
2238
  from azure.ai.inference.models import UserMessage
2239
  model_name = st.session_state.custom_model
 
2240
 
2241
+ # Prepare parameters
2242
+ messages = [UserMessage("Hello")]
2243
+ api_params, config = prepare_api_params(messages, model_name)
 
 
2244
 
2245
+ # Check if we need a new client with specific API version
2246
+ if config["api_version"] and config["api_version"] != st.session_state.ai_models.get("api_version"):
2247
+ # Create version-specific client if needed
2248
+ token = get_secret("github_token_api")
2249
+ from azure.ai.inference import ChatCompletionsClient
2250
+ from azure.core.credentials import AzureKeyCredential
2251
+
2252
+ client = ChatCompletionsClient(
2253
+ endpoint=st.session_state.ai_models["endpoint"],
2254
+ credential=AzureKeyCredential(token),
2255
+ api_version=config["api_version"]
2256
+ )
2257
+ response = client.complete(**api_params)
2258
+
2259
+ # Update session state with the new client
2260
+ st.session_state.ai_models["client"] = client
2261
+ st.session_state.ai_models["api_version"] = config["api_version"]
2262
+ else:
2263
+ response = st.session_state.ai_models["client"].complete(**api_params)
2264
+
2265
  st.success(f"✅ Connection to {model_name} successful!")
2266
  st.session_state.ai_models["model_name"] = model_name
2267
 
 
2307
  client = st.session_state.ai_models["client"]
2308
  model_name = st.session_state.ai_models["model_name"]
2309
 
 
 
 
 
 
 
 
 
 
2310
  # Create the prompt
2311
  prompt = f"""Write a complete Manim animation scene based on this code or idea:
2312
+ {code_input}
2313
+
2314
+ The code should be a complete, working Manim animation that includes:
2315
+ - Proper Scene class definition
2316
+ - Constructor with animations
2317
+ - Proper use of self.play() for animations
2318
+ - Proper wait times between animations
2319
+
2320
+ Here's the complete Manim code:
2321
+ """
2322
 
2323
+ # Prepare API parameters
2324
  from azure.ai.inference.models import UserMessage
2325
+ messages = [UserMessage(prompt)]
2326
+ api_params, config = prepare_api_params(messages, model_name)
2327
+
2328
+ # Make the API call with proper parameters
2329
+ response = client.complete(**api_params)
2330
 
2331
  # Process the response
2332
  if response and response.choices and len(response.choices) > 0:
 
2341
  # Add Scene class if missing
2342
  if "Scene" not in completed_code:
2343
  completed_code = f"""from manim import *
2344
+
2345
+ class MyScene(Scene):
2346
+ def construct(self):
2347
+ {completed_code}"""
2348
 
2349
  # Store the generated code
2350
  st.session_state.generated_code = completed_code
 
2356
  st.code(traceback.format_exc())
2357
  else:
2358
  st.warning("Please enter a description or prompt first")
2359
+
 
2360
 
2361
  # AI generated code display and actions
2362
  if "generated_code" in st.session_state and st.session_state.generated_code: