Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -41,6 +41,29 @@ logging.basicConfig(
|
|
41 |
)
|
42 |
logger = logging.getLogger(__name__)
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
# Try to import Streamlit Ace
|
45 |
try:
|
46 |
from streamlit_ace import st_ace
|
@@ -306,16 +329,50 @@ Here's the complete Manim code:
|
|
306 |
with st.spinner("AI is generating your animation code..."):
|
307 |
from azure.ai.inference.models import UserMessage
|
308 |
|
309 |
-
#
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
|
|
|
|
|
|
317 |
|
318 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
completed_code = response.choices[0].message.content
|
320 |
|
321 |
# Process the code
|
@@ -338,6 +395,7 @@ class MyScene(Scene):
|
|
338 |
st.error(f"Error generating code: {str(e)}")
|
339 |
st.code(traceback.format_exc())
|
340 |
return None
|
|
|
341 |
def check_model_freshness():
|
342 |
"""Check if models need to be reloaded based on TTL"""
|
343 |
if 'ai_models' not in st.session_state or st.session_state.ai_models is None:
|
@@ -1635,21 +1693,86 @@ def main():
|
|
1635 |
margin-bottom: 1rem;
|
1636 |
text-align: center;
|
1637 |
}
|
1638 |
-
|
|
|
1639 |
.card {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1640 |
background-color: #f8f9fa;
|
1641 |
border-radius: 10px;
|
1642 |
-
padding:
|
1643 |
-
|
1644 |
-
|
|
|
1645 |
}
|
1646 |
-
|
1647 |
-
|
1648 |
-
|
1649 |
-
border-radius: 5px;
|
1650 |
-
height: 2.5rem;
|
1651 |
}
|
1652 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1653 |
.preview-container {
|
1654 |
border: 1px solid #e0e0e0;
|
1655 |
border-radius: 10px;
|
@@ -1657,7 +1780,7 @@ def main():
|
|
1657 |
margin-bottom: 1rem;
|
1658 |
min-height: 200px;
|
1659 |
}
|
1660 |
-
|
1661 |
.latex-preview {
|
1662 |
background-color: #f8f9fa;
|
1663 |
border-radius: 5px;
|
@@ -1665,12 +1788,12 @@ def main():
|
|
1665 |
margin-top: 0.5rem;
|
1666 |
min-height: 100px;
|
1667 |
}
|
1668 |
-
|
1669 |
.small-text {
|
1670 |
font-size: 0.8rem;
|
1671 |
color: #6c757d;
|
1672 |
}
|
1673 |
-
|
1674 |
.asset-card {
|
1675 |
background-color: #f0f2f5;
|
1676 |
border-radius: 8px;
|
@@ -1678,14 +1801,14 @@ def main():
|
|
1678 |
margin-bottom: 1rem;
|
1679 |
border-left: 4px solid #4F46E5;
|
1680 |
}
|
1681 |
-
|
1682 |
.timeline-container {
|
1683 |
background-color: #f8f9fa;
|
1684 |
border-radius: 10px;
|
1685 |
padding: 1.5rem;
|
1686 |
margin-bottom: 1.5rem;
|
1687 |
}
|
1688 |
-
|
1689 |
.keyframe {
|
1690 |
width: 12px;
|
1691 |
height: 12px;
|
@@ -1695,7 +1818,7 @@ def main():
|
|
1695 |
transform: translate(-50%, -50%);
|
1696 |
cursor: pointer;
|
1697 |
}
|
1698 |
-
|
1699 |
.educational-export-container {
|
1700 |
background-color: #f0f7ff;
|
1701 |
border-radius: 10px;
|
@@ -1703,7 +1826,7 @@ def main():
|
|
1703 |
margin-bottom: 1.5rem;
|
1704 |
border: 1px solid #c2e0ff;
|
1705 |
}
|
1706 |
-
|
1707 |
.code-output {
|
1708 |
background-color: #f8f9fa;
|
1709 |
border-radius: 8px;
|
@@ -1713,7 +1836,7 @@ def main():
|
|
1713 |
max-height: 400px;
|
1714 |
overflow-y: auto;
|
1715 |
}
|
1716 |
-
|
1717 |
.error-output {
|
1718 |
background-color: #fef2f2;
|
1719 |
border-radius: 8px;
|
@@ -2004,22 +2127,36 @@ class MyScene(Scene):
|
|
2004 |
|
2005 |
# Define endpoint
|
2006 |
endpoint = "https://models.inference.ai.azure.com"
|
2007 |
-
model_name =
|
2008 |
|
2009 |
-
#
|
2010 |
-
|
2011 |
-
|
2012 |
-
|
2013 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2014 |
|
2015 |
# Test with a simple prompt
|
2016 |
-
|
2017 |
-
messages
|
2018 |
-
|
2019 |
-
|
2020 |
-
|
2021 |
-
|
2022 |
-
|
|
|
|
|
|
|
2023 |
|
2024 |
# Check if response is valid
|
2025 |
if response and response.choices and len(response.choices) > 0:
|
@@ -2043,44 +2180,87 @@ class MyScene(Scene):
|
|
2043 |
import traceback
|
2044 |
st.code(traceback.format_exc())
|
2045 |
|
2046 |
-
# Model selection
|
2047 |
-
st.markdown("
|
2048 |
-
|
2049 |
-
|
2050 |
-
|
2051 |
-
|
2052 |
-
|
2053 |
-
"
|
2054 |
-
|
2055 |
-
|
2056 |
-
|
2057 |
-
|
2058 |
-
|
2059 |
-
|
2060 |
-
|
2061 |
-
|
2062 |
-
|
2063 |
-
|
2064 |
-
|
2065 |
-
|
2066 |
-
|
2067 |
-
|
2068 |
-
|
2069 |
-
|
2070 |
-
|
2071 |
-
|
2072 |
-
|
2073 |
-
|
2074 |
-
|
2075 |
-
|
2076 |
-
|
2077 |
-
|
2078 |
-
|
2079 |
-
|
2080 |
-
|
2081 |
-
|
2082 |
-
|
2083 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2084 |
|
2085 |
# AI code generation
|
2086 |
if st.session_state.ai_models and "client" in st.session_state.ai_models:
|
@@ -2120,6 +2300,9 @@ class MyScene(Scene):
|
|
2120 |
client = st.session_state.ai_models["client"]
|
2121 |
model_name = st.session_state.ai_models["model_name"]
|
2122 |
|
|
|
|
|
|
|
2123 |
# Create the prompt
|
2124 |
prompt = f"""Write a complete Manim animation scene based on this code or idea:
|
2125 |
{code_input}
|
@@ -2133,15 +2316,42 @@ The code should be a complete, working Manim animation that includes:
|
|
2133 |
Here's the complete Manim code:
|
2134 |
"""
|
2135 |
|
2136 |
-
#
|
2137 |
-
|
2138 |
-
|
2139 |
-
|
2140 |
-
|
2141 |
-
|
2142 |
-
|
2143 |
-
|
2144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2145 |
|
2146 |
# Process the response
|
2147 |
if response and response.choices and len(response.choices) > 0:
|
|
|
41 |
)
|
42 |
logger = logging.getLogger(__name__)
|
43 |
|
44 |
+
# Model configuration mapping for different API requirements and limits
|
45 |
+
MODEL_CONFIGS = {
|
46 |
+
"DeepSeek-V3-0324": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "DeepSeek"},
|
47 |
+
"DeepSeek-R1": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "DeepSeek"},
|
48 |
+
"Meta-Llama-3.1-405B-Instruct": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Meta"},
|
49 |
+
"Llama-3.2-90B-Vision-Instruct": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Meta"},
|
50 |
+
"Llama-3.3-70B-Instruct": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Meta"},
|
51 |
+
"Llama-4-Scout-17B-16E-Instruct": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Meta"},
|
52 |
+
"Llama-4-Maverick-17B-128E-Instruct-FP8": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Meta"},
|
53 |
+
"gpt-4o-mini": {"max_tokens": 16000, "param_name": "max_tokens", "api_version": None, "category": "OpenAI"},
|
54 |
+
"gpt-4.1": {"max_tokens": 16000, "param_name": "max_tokens", "api_version": None, "category": "OpenAI"},
|
55 |
+
"gpt-4o": {"max_tokens": 16000, "param_name": "max_tokens", "api_version": None, "category": "OpenAI"},
|
56 |
+
"o3-mini": {"max_tokens": 16000, "param_name": "max_tokens", "api_version": "2024-12-01-preview", "category": "Anthropic"},
|
57 |
+
"o1": {"max_tokens": 16000, "param_name": "max_completion_tokens", "api_version": "2024-12-01-preview", "category": "Anthropic"},
|
58 |
+
"o1-mini": {"max_tokens": 16000, "param_name": "max_completion_tokens", "api_version": "2024-12-01-preview", "category": "Anthropic"},
|
59 |
+
"o1-preview": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Anthropic"},
|
60 |
+
"Phi-4-multimodal-instruct": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Microsoft"},
|
61 |
+
"Mistral-large-2407": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Mistral"},
|
62 |
+
"Codestral-2501": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Mistral"},
|
63 |
+
# Default configuration for other models
|
64 |
+
"default": {"max_tokens": 4000, "param_name": "max_tokens", "api_version": None, "category": "Other"}
|
65 |
+
}
|
66 |
+
|
67 |
# Try to import Streamlit Ace
|
68 |
try:
|
69 |
from streamlit_ace import st_ace
|
|
|
329 |
with st.spinner("AI is generating your animation code..."):
|
330 |
from azure.ai.inference.models import UserMessage
|
331 |
|
332 |
+
# Get the current model name
|
333 |
+
model_name = models["model_name"]
|
334 |
+
|
335 |
+
# Get configuration for this model (or use default)
|
336 |
+
config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
|
337 |
+
|
338 |
+
# Prepare API call parameters based on model requirements
|
339 |
+
api_params = {
|
340 |
+
"messages": [UserMessage(prompt)],
|
341 |
+
"model": model_name
|
342 |
+
}
|
343 |
|
344 |
+
# Add the appropriate token parameter
|
345 |
+
api_params[config["param_name"]] = config["max_tokens"]
|
346 |
+
|
347 |
+
# Check if we need to specify API version
|
348 |
+
if config["api_version"]:
|
349 |
+
# If we need a specific API version, we need to create a new client with that version
|
350 |
+
logger.info(f"Using API version {config['api_version']} for model {model_name}")
|
351 |
+
|
352 |
+
# Get token from session state
|
353 |
+
token = get_secret("github_token_api")
|
354 |
+
if not token:
|
355 |
+
st.error("GitHub token not found in secrets")
|
356 |
+
return None
|
357 |
+
|
358 |
+
# Import required modules for creating client with specific API version
|
359 |
+
from azure.ai.inference import ChatCompletionsClient
|
360 |
+
from azure.core.credentials import AzureKeyCredential
|
361 |
+
|
362 |
+
# Create client with specific API version
|
363 |
+
version_specific_client = ChatCompletionsClient(
|
364 |
+
endpoint=models["endpoint"],
|
365 |
+
credential=AzureKeyCredential(token),
|
366 |
+
api_version=config["api_version"]
|
367 |
+
)
|
368 |
+
|
369 |
+
# Make the API call with the version-specific client
|
370 |
+
response = version_specific_client.complete(**api_params)
|
371 |
+
else:
|
372 |
+
# Use the existing client
|
373 |
+
response = models["client"].complete(**api_params)
|
374 |
+
|
375 |
+
# Process the response
|
376 |
completed_code = response.choices[0].message.content
|
377 |
|
378 |
# Process the code
|
|
|
395 |
st.error(f"Error generating code: {str(e)}")
|
396 |
st.code(traceback.format_exc())
|
397 |
return None
|
398 |
+
|
399 |
def check_model_freshness():
|
400 |
"""Check if models need to be reloaded based on TTL"""
|
401 |
if 'ai_models' not in st.session_state or st.session_state.ai_models is None:
|
|
|
1693 |
margin-bottom: 1rem;
|
1694 |
text-align: center;
|
1695 |
}
|
1696 |
+
|
1697 |
+
/* Improved Cards */
|
1698 |
.card {
|
1699 |
+
background-color: #ffffff;
|
1700 |
+
border-radius: 12px;
|
1701 |
+
padding: 1.8rem;
|
1702 |
+
box-shadow: 0 6px 12px rgba(0, 0, 0, 0.08);
|
1703 |
+
margin-bottom: 1.8rem;
|
1704 |
+
border-left: 5px solid #4F46E5;
|
1705 |
+
transition: all 0.3s ease;
|
1706 |
+
}
|
1707 |
+
.card:hover {
|
1708 |
+
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.12);
|
1709 |
+
transform: translateY(-2px);
|
1710 |
+
}
|
1711 |
+
|
1712 |
+
/* Tab styling */
|
1713 |
+
.stTabs [data-baseweb="tab-list"] {
|
1714 |
+
gap: 2px;
|
1715 |
+
}
|
1716 |
+
.stTabs [data-baseweb="tab"] {
|
1717 |
+
height: 45px;
|
1718 |
+
white-space: pre-wrap;
|
1719 |
+
border-radius: 4px 4px 0 0;
|
1720 |
+
font-weight: 500;
|
1721 |
+
}
|
1722 |
+
.stTabs [aria-selected="true"] {
|
1723 |
+
background-color: #f0f4fd;
|
1724 |
+
border-bottom: 2px solid #4F46E5;
|
1725 |
+
}
|
1726 |
+
|
1727 |
+
/* Buttons */
|
1728 |
+
.stButton button {
|
1729 |
+
border-radius: 6px;
|
1730 |
+
font-weight: 500;
|
1731 |
+
transition: all 0.2s ease;
|
1732 |
+
}
|
1733 |
+
.stButton button:hover {
|
1734 |
+
transform: translateY(-1px);
|
1735 |
+
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
1736 |
+
}
|
1737 |
+
|
1738 |
+
/* Model selection */
|
1739 |
+
.model-group {
|
1740 |
+
margin-bottom: 1.5rem;
|
1741 |
+
padding: 15px;
|
1742 |
+
border-radius: 8px;
|
1743 |
+
background-color: #f8f9fa;
|
1744 |
+
}
|
1745 |
+
|
1746 |
+
.model-card {
|
1747 |
background-color: #f8f9fa;
|
1748 |
border-radius: 10px;
|
1749 |
+
padding: 15px;
|
1750 |
+
margin-bottom: 10px;
|
1751 |
+
border-left: 4px solid #4F46E5;
|
1752 |
+
transition: all 0.3s ease;
|
1753 |
}
|
1754 |
+
.model-card:hover {
|
1755 |
+
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
|
1756 |
+
transform: translateY(-2px);
|
|
|
|
|
1757 |
}
|
1758 |
+
.model-category {
|
1759 |
+
font-size: 1.2rem;
|
1760 |
+
font-weight: 600;
|
1761 |
+
padding: 10px 5px;
|
1762 |
+
margin-top: 15px;
|
1763 |
+
border-bottom: 2px solid #e9ecef;
|
1764 |
+
color: #333;
|
1765 |
+
}
|
1766 |
+
.model-details {
|
1767 |
+
font-size: 0.8rem;
|
1768 |
+
color: #666;
|
1769 |
+
margin-top: 5px;
|
1770 |
+
}
|
1771 |
+
.selected-model {
|
1772 |
+
background-color: #e8f4fe;
|
1773 |
+
border-left: 4px solid #0d6efd;
|
1774 |
+
}
|
1775 |
+
|
1776 |
.preview-container {
|
1777 |
border: 1px solid #e0e0e0;
|
1778 |
border-radius: 10px;
|
|
|
1780 |
margin-bottom: 1rem;
|
1781 |
min-height: 200px;
|
1782 |
}
|
1783 |
+
|
1784 |
.latex-preview {
|
1785 |
background-color: #f8f9fa;
|
1786 |
border-radius: 5px;
|
|
|
1788 |
margin-top: 0.5rem;
|
1789 |
min-height: 100px;
|
1790 |
}
|
1791 |
+
|
1792 |
.small-text {
|
1793 |
font-size: 0.8rem;
|
1794 |
color: #6c757d;
|
1795 |
}
|
1796 |
+
|
1797 |
.asset-card {
|
1798 |
background-color: #f0f2f5;
|
1799 |
border-radius: 8px;
|
|
|
1801 |
margin-bottom: 1rem;
|
1802 |
border-left: 4px solid #4F46E5;
|
1803 |
}
|
1804 |
+
|
1805 |
.timeline-container {
|
1806 |
background-color: #f8f9fa;
|
1807 |
border-radius: 10px;
|
1808 |
padding: 1.5rem;
|
1809 |
margin-bottom: 1.5rem;
|
1810 |
}
|
1811 |
+
|
1812 |
.keyframe {
|
1813 |
width: 12px;
|
1814 |
height: 12px;
|
|
|
1818 |
transform: translate(-50%, -50%);
|
1819 |
cursor: pointer;
|
1820 |
}
|
1821 |
+
|
1822 |
.educational-export-container {
|
1823 |
background-color: #f0f7ff;
|
1824 |
border-radius: 10px;
|
|
|
1826 |
margin-bottom: 1.5rem;
|
1827 |
border: 1px solid #c2e0ff;
|
1828 |
}
|
1829 |
+
|
1830 |
.code-output {
|
1831 |
background-color: #f8f9fa;
|
1832 |
border-radius: 8px;
|
|
|
1836 |
max-height: 400px;
|
1837 |
overflow-y: auto;
|
1838 |
}
|
1839 |
+
|
1840 |
.error-output {
|
1841 |
background-color: #fef2f2;
|
1842 |
border-radius: 8px;
|
|
|
2127 |
|
2128 |
# Define endpoint
|
2129 |
endpoint = "https://models.inference.ai.azure.com"
|
2130 |
+
model_name = st.session_state.custom_model
|
2131 |
|
2132 |
+
# Get model configuration
|
2133 |
+
config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
|
2134 |
+
|
2135 |
+
# Create client with appropriate API version
|
2136 |
+
api_version = config.get("api_version")
|
2137 |
+
if api_version:
|
2138 |
+
client = ChatCompletionsClient(
|
2139 |
+
endpoint=endpoint,
|
2140 |
+
credential=AzureKeyCredential(token),
|
2141 |
+
api_version=api_version
|
2142 |
+
)
|
2143 |
+
else:
|
2144 |
+
client = ChatCompletionsClient(
|
2145 |
+
endpoint=endpoint,
|
2146 |
+
credential=AzureKeyCredential(token),
|
2147 |
+
)
|
2148 |
|
2149 |
# Test with a simple prompt
|
2150 |
+
api_params = {
|
2151 |
+
"messages": [UserMessage("Hello, this is a connection test.")],
|
2152 |
+
"model": model_name
|
2153 |
+
}
|
2154 |
+
|
2155 |
+
# Use appropriate parameter name
|
2156 |
+
api_params[config["param_name"]] = 100 # Just enough for a short response
|
2157 |
+
|
2158 |
+
# Make the API call
|
2159 |
+
response = client.complete(**api_params)
|
2160 |
|
2161 |
# Check if response is valid
|
2162 |
if response and response.choices and len(response.choices) > 0:
|
|
|
2180 |
import traceback
|
2181 |
st.code(traceback.format_exc())
|
2182 |
|
2183 |
+
# Model selection with enhanced UI
|
2184 |
+
st.markdown("### π€ Model Selection")
|
2185 |
+
st.markdown("Select an AI model for generating animation code:")
|
2186 |
+
|
2187 |
+
# Group models by category for better organization
|
2188 |
+
model_categories = {}
|
2189 |
+
for model_name in MODEL_CONFIGS:
|
2190 |
+
if model_name != "default":
|
2191 |
+
category = MODEL_CONFIGS[model_name].get("category", "Other")
|
2192 |
+
if category not in model_categories:
|
2193 |
+
model_categories[category] = []
|
2194 |
+
model_categories[category].append(model_name)
|
2195 |
+
|
2196 |
+
# Create tabbed interface for model categories
|
2197 |
+
category_tabs = st.tabs(sorted(model_categories.keys()))
|
2198 |
+
|
2199 |
+
for i, category in enumerate(sorted(model_categories.keys())):
|
2200 |
+
with category_tabs[i]:
|
2201 |
+
for model_name in sorted(model_categories[category]):
|
2202 |
+
config = MODEL_CONFIGS[model_name]
|
2203 |
+
is_selected = model_name == st.session_state.custom_model
|
2204 |
+
|
2205 |
+
# Create styled card for each model
|
2206 |
+
st.markdown(f"""
|
2207 |
+
<div class="model-card {'selected-model' if is_selected else ''}">
|
2208 |
+
<h4>{model_name}</h4>
|
2209 |
+
<div class="model-details">
|
2210 |
+
<p>Max Tokens: {config['max_tokens']:,}</p>
|
2211 |
+
<p>API Version: {config['api_version'] if config['api_version'] else 'Default'}</p>
|
2212 |
+
</div>
|
2213 |
+
</div>
|
2214 |
+
""", unsafe_allow_html=True)
|
2215 |
+
|
2216 |
+
# Button to select this model
|
2217 |
+
button_label = "Selected β" if is_selected else "Select Model"
|
2218 |
+
if st.button(button_label, key=f"model_{model_name}", disabled=is_selected):
|
2219 |
+
st.session_state.custom_model = model_name
|
2220 |
+
if st.session_state.ai_models and 'model_name' in st.session_state.ai_models:
|
2221 |
+
st.session_state.ai_models['model_name'] = model_name
|
2222 |
+
st.rerun()
|
2223 |
+
|
2224 |
+
# Display current model selection
|
2225 |
+
st.info(f"π€ **Currently using: {st.session_state.custom_model}**")
|
2226 |
+
|
2227 |
+
# Add a refresh button to update model connection
|
2228 |
+
if st.button("π Refresh Model Connection", key="refresh_model_connection"):
|
2229 |
+
if st.session_state.ai_models and 'client' in st.session_state.ai_models:
|
2230 |
+
try:
|
2231 |
+
# Test connection with minimal prompt
|
2232 |
+
from azure.ai.inference.models import UserMessage
|
2233 |
+
model_name = st.session_state.custom_model
|
2234 |
+
config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
|
2235 |
+
|
2236 |
+
# Use appropriate parameters based on model configuration
|
2237 |
+
api_params = {
|
2238 |
+
"messages": [UserMessage("Hello")],
|
2239 |
+
"model": model_name
|
2240 |
+
}
|
2241 |
+
api_params[config["param_name"]] = 10 # Just request 10 tokens for quick test
|
2242 |
+
|
2243 |
+
if config["api_version"]:
|
2244 |
+
# Create version-specific client if needed
|
2245 |
+
token = get_secret("github_token_api")
|
2246 |
+
from azure.ai.inference import ChatCompletionsClient
|
2247 |
+
from azure.core.credentials import AzureKeyCredential
|
2248 |
+
|
2249 |
+
client = ChatCompletionsClient(
|
2250 |
+
endpoint=st.session_state.ai_models["endpoint"],
|
2251 |
+
credential=AzureKeyCredential(token),
|
2252 |
+
api_version=config["api_version"]
|
2253 |
+
)
|
2254 |
+
response = client.complete(**api_params)
|
2255 |
+
else:
|
2256 |
+
response = st.session_state.ai_models["client"].complete(**api_params)
|
2257 |
+
|
2258 |
+
st.success(f"β
Connection to {model_name} successful!")
|
2259 |
+
st.session_state.ai_models["model_name"] = model_name
|
2260 |
+
|
2261 |
+
except Exception as e:
|
2262 |
+
st.error(f"β Connection error: {str(e)}")
|
2263 |
+
st.info("Please try the Debug Connection section to re-initialize the API connection.")
|
2264 |
|
2265 |
# AI code generation
|
2266 |
if st.session_state.ai_models and "client" in st.session_state.ai_models:
|
|
|
2300 |
client = st.session_state.ai_models["client"]
|
2301 |
model_name = st.session_state.ai_models["model_name"]
|
2302 |
|
2303 |
+
# Get configuration for this model
|
2304 |
+
config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS["default"])
|
2305 |
+
|
2306 |
# Create the prompt
|
2307 |
prompt = f"""Write a complete Manim animation scene based on this code or idea:
|
2308 |
{code_input}
|
|
|
2316 |
Here's the complete Manim code:
|
2317 |
"""
|
2318 |
|
2319 |
+
# Prepare API call parameters based on model requirements
|
2320 |
+
api_params = {
|
2321 |
+
"messages": [UserMessage(prompt)],
|
2322 |
+
"model": model_name
|
2323 |
+
}
|
2324 |
+
|
2325 |
+
# Add the appropriate token parameter
|
2326 |
+
api_params[config["param_name"]] = config["max_tokens"]
|
2327 |
+
|
2328 |
+
# Check if we need to specify API version
|
2329 |
+
if config["api_version"]:
|
2330 |
+
# If we need a specific API version, create a new client with that version
|
2331 |
+
logger.info(f"Using API version {config['api_version']} for model {model_name}")
|
2332 |
+
|
2333 |
+
# Get token from session state
|
2334 |
+
token = get_secret("github_token_api")
|
2335 |
+
if not token:
|
2336 |
+
st.error("GitHub token not found in secrets")
|
2337 |
+
return None
|
2338 |
+
|
2339 |
+
# Import required modules for creating client with specific API version
|
2340 |
+
from azure.ai.inference import ChatCompletionsClient
|
2341 |
+
from azure.core.credentials import AzureKeyCredential
|
2342 |
+
|
2343 |
+
# Create client with specific API version
|
2344 |
+
version_specific_client = ChatCompletionsClient(
|
2345 |
+
endpoint=st.session_state.ai_models["endpoint"],
|
2346 |
+
credential=AzureKeyCredential(token),
|
2347 |
+
api_version=config["api_version"]
|
2348 |
+
)
|
2349 |
+
|
2350 |
+
# Make the API call with the version-specific client
|
2351 |
+
response = version_specific_client.complete(**api_params)
|
2352 |
+
else:
|
2353 |
+
# Use the existing client
|
2354 |
+
response = client.complete(**api_params)
|
2355 |
|
2356 |
# Process the response
|
2357 |
if response and response.choices and len(response.choices) > 0:
|