# Captured from a Hugging Face Spaces page.
# Space status at capture time: Runtime error
import threading | |
import time | |
import gradio as gr | |
import logging | |
import json | |
import re | |
import torch | |
import tempfile | |
import subprocess | |
import ast | |
from pathlib import Path | |
from typing import Dict, List, Tuple, Optional, Any, Union | |
from dataclasses import dataclass, field | |
from enum import Enum | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
pipeline, | |
AutoProcessor, | |
AutoModel | |
) | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import numpy as np | |
from PIL import Image | |
# Logging: every record goes to the console and to a local log file.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('gradio_builder.log'),
    ],
)
logger = logging.getLogger(__name__)

# Default server port and the working directories used by the module.
DEFAULT_PORT = 7860
MODEL_CACHE_DIR = Path("model_cache")
TEMPLATE_DIR = Path("templates")
TEMP_DIR = Path("temp")

# Create the working directories at import time so later code can rely on them.
for directory in (MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR):
    directory.mkdir(exist_ok=True)
@dataclass
class Template:
    """A reusable Gradio interface template.

    The class is a dataclass so that it can be built with keyword arguments
    (``Template(code=..., description=..., ...)`` and ``Template(**json_data)``)
    and so that the ``field(default_factory=dict)`` default gives every
    instance its own ``metadata`` dictionary.
    """

    # Python source of the interface.
    code: str
    # Human-readable summary of what the interface does.
    description: str
    # Names of the Gradio components the interface uses.
    components: List[str]
    # Free-form extra data; other code in this file reads the "category" key.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Version string stored alongside the template.
    version: str = "1.0"
class ComponentType(Enum):
    """Gradio component kinds recognised by the builder.

    Each member's value is the class name of the component as it is spelled
    on the ``gr`` namespace (for example ``gr.DataFrame``).
    """

    # Media and text inputs/outputs
    IMAGE = "Image"
    TEXTBOX = "Textbox"
    BUTTON = "Button"
    NUMBER = "Number"

    # Rich display components
    MARKDOWN = "Markdown"
    JSON = "JSON"
    HTML = "HTML"
    CODE = "Code"

    # Selection widgets
    DROPDOWN = "Dropdown"
    SLIDER = "Slider"
    CHECKBOX = "Checkbox"
    RADIO = "Radio"

    # Files and streams
    AUDIO = "Audio"
    VIDEO = "Video"
    FILE = "File"

    # Structured outputs
    DATAFRAME = "DataFrame"
    LABEL = "Label"
    PLOT = "Plot"
@dataclass
class ComponentConfig:
    """Description of one Gradio component found in interface code.

    Declared as a dataclass because instances are created with keyword
    arguments (``ComponentConfig(type=..., label=..., ...)``) and the
    ``field(default_factory=...)`` defaults need dataclass processing to give
    each instance its own containers.
    """

    # Kind of component (a ``ComponentType`` member).
    type: "ComponentType"
    # Label shown for the component.
    label: str
    # Keyword arguments passed to the component constructor.
    properties: Dict[str, Any] = field(default_factory=dict)
    # Names of the methods called on the component (e.g. "click").
    events: List[str] = field(default_factory=list)
class BuilderError(Exception):
    """Root of the Gradio Builder exception hierarchy."""


class ValidationError(BuilderError):
    """Signals that a validation step failed."""


class GenerationError(BuilderError):
    """Signals that code generation failed."""


class ModelError(BuilderError):
    """Signals that a model operation failed."""
def setup_gpu_memory():
    """Prepare CUDA memory settings when a GPU is present.

    With CUDA available, the cached allocator memory is released and this
    process is limited to 80% of the device memory. Without CUDA, only an
    informational message is logged. Any exception is logged as a warning
    and not propagated.
    """
    try:
        if not torch.cuda.is_available():
            logger.info("No GPU available, using CPU")
            return
        # Drop cached blocks before applying the per-process limit.
        torch.cuda.empty_cache()
        torch.cuda.set_per_process_memory_fraction(0.8)
        logger.info("GPU memory configured successfully")
    except Exception as exc:
        logger.warning(f"Error configuring GPU memory: {exc}")
def validate_code(code: str) -> Tuple[bool, str]:
    """Check that ``code`` is syntactically valid Python.

    Args:
        code: Source text to parse.

    Returns:
        ``(True, "Code is valid")`` when ``ast.parse`` accepts the text.
        Otherwise ``(False, message)``: for a syntax error the message
        contains the parser's description and, when the offending source
        line is available, that line with a ``^`` marker under the reported
        column. Any other exception is reported as a validation error.
    """
    try:
        ast.parse(code)
    except SyntaxError as e:
        # The parser keeps the trailing newline on ``text``; strip it so the
        # caret lands directly under the offending line.
        line = (e.text or "").rstrip("\r\n")
        if line:
            # ``offset`` is 1-based and may be None; clamp so a missing or
            # zero offset cannot raise inside this handler.
            column = max((e.offset or 1) - 1, 0)
            pointer = " " * column + "^"
            error_detail = f"\nLine {e.lineno}:\n{line}\n{pointer}"
        else:
            error_detail = f" at line {e.lineno}"
        return False, f"Syntax error: {str(e)}{error_detail}"
    except Exception as e:
        return False, f"Validation error: {str(e)}"
    return True, "Code is valid"
class CodeFormatter:
    """Formatting and clean-up helpers for generated code.

    All methods are static; they are called on the class itself
    (``CodeFormatter.format_code(code)``).
    """

    # Top-level modules whose imports are stripped from generated code.
    _UNSAFE_MODULES = frozenset({"os", "subprocess", "sys"})
    # ``import X`` / ``from X`` for an unsafe module, matched on word
    # boundaries so that e.g. ``import ostrich`` is not affected.
    _UNSAFE_IMPORT_RE = re.compile(r"\b(?:import|from)\s+(?:os|subprocess|sys)\b")
    # A plain ``import a, b as c`` statement: (indent, name list, trailing comment).
    _PLAIN_IMPORT_RE = re.compile(r"^(\s*)import\s+([^#]+?)\s*(#.*)?$")

    @staticmethod
    def format_code(code: str) -> str:
        """Format ``code`` with black, returning it unchanged on any failure.

        Args:
            code: Python source text.

        Returns:
            The formatted source, or the original text when black is not
            installed or raises (both cases are logged).
        """
        try:
            import black
            return black.format_str(code, mode=black.FileMode())
        except ImportError:
            logger.warning("black not installed, returning unformatted code")
            return code
        except Exception as e:
            logger.error(f"Error formatting code: {e}")
            return code

    @staticmethod
    def cleanup_code(code: str) -> str:
        """Remove imports of ``os``, ``subprocess`` and ``sys`` from ``code``.

        Lines that import only unsafe modules are dropped. In a
        comma-separated ``import`` statement only the unsafe names are
        removed, so ``import json, sys`` becomes ``import json``. Modules
        whose names merely start with an unsafe name are left alone.

        Args:
            code: Python source text.

        Returns:
            The source with the unsafe import lines removed or rewritten.
        """
        kept_lines = []
        for line in code.split('\n'):
            cleaned = CodeFormatter._strip_unsafe_imports(line)
            if cleaned is not None:
                kept_lines.append(cleaned)
        return '\n'.join(kept_lines)

    @staticmethod
    def _strip_unsafe_imports(line: str) -> Optional[str]:
        """Return ``line`` without unsafe imports, or None to drop the line."""
        plain = CodeFormatter._PLAIN_IMPORT_RE.match(line)
        if plain:
            indent, names, comment = plain.groups()
            parts = [part.strip() for part in names.split(",") if part.strip()]
            # The root module is the first dotted component, before any "as".
            safe = [
                part for part in parts
                if part.split()[0].split(".")[0] not in CodeFormatter._UNSAFE_MODULES
            ]
            if len(safe) != len(parts):
                if not safe:
                    return None
                rebuilt = f"{indent}import {', '.join(safe)}"
                return f"{rebuilt}  {comment}" if comment else rebuilt
        # Covers ``from X import ...`` and imports embedded in compound lines.
        if CodeFormatter._UNSAFE_IMPORT_RE.search(line):
            return None
        return line
def create_temp_module(code: str) -> str:
    """Write ``code`` to a new, uniquely named ``.py`` file under ``TEMP_DIR``.

    The file name is generated by ``tempfile`` so that two calls can never
    target the same path. It keeps the ``temp_module_`` prefix and ``.py``
    suffix.

    Args:
        code: Python source text to store.

    Returns:
        Path of the created file as a string.

    Raises:
        BuilderError: If the directory or file cannot be created or written.
    """
    try:
        TEMP_DIR.mkdir(parents=True, exist_ok=True)
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            prefix="temp_module_",
            suffix=".py",
            dir=TEMP_DIR,
            delete=False,  # the caller runs the file after this function returns
        ) as f:
            f.write(code)
        return str(f.name)
    except Exception as e:
        raise BuilderError(f"Failed to create temporary module: {e}") from e
# Apply the GPU memory configuration once, at import time.
setup_gpu_memory()
class ModelManager:
    """Loads, caches and unloads the models used by the builder.

    ``model_configs`` maps a model type name to the Hugging Face model id,
    the classes used to load it, and the keyword arguments for the model's
    ``from_pretrained`` call. Loaded models are kept in ``loaded_models`` as
    ``(model, tokenizer_or_processor)`` tuples.
    """

    def __init__(self, cache_dir: Path = MODEL_CACHE_DIR):
        """Create the cache directory and set up the model configurations.

        Args:
            cache_dir: Directory passed to ``from_pretrained`` as ``cache_dir``.
        """
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.loaded_models = {}
        self.model_configs = {
            "code_generator": {
                "model_id": "bigcode/starcoder",
                "tokenizer": AutoTokenizer,
                "model": AutoModelForCausalLM,
                "kwargs": {
                    "torch_dtype": torch.float16,
                    "device_map": "auto",
                    "cache_dir": str(cache_dir)
                }
            },
            "image_processor": {
                "model_id": "Salesforce/blip-image-captioning-base",
                "processor": AutoProcessor,
                "model": AutoModel,
                "kwargs": {
                    "cache_dir": str(cache_dir)
                }
            }
        }

    def load_model(self, model_type: str):
        """Return the ``(model, tokenizer_or_processor)`` pair for ``model_type``.

        The pair is loaded on first use and served from ``loaded_models``
        afterwards.

        Args:
            model_type: A key of ``model_configs``.

        Returns:
            Tuple of the model and its tokenizer (``code_generator``) or
            processor (``image_processor``).

        Raises:
            ModelError: If ``model_type`` is unknown or loading fails.
        """
        if model_type not in self.model_configs:
            raise ModelError(f"Unknown model type: {model_type}")
        if model_type in self.loaded_models:
            return self.loaded_models[model_type]

        config = self.model_configs[model_type]
        # The config stores the text/image preprocessor under one of two keys.
        companion_key = "tokenizer" if "tokenizer" in config else "processor"
        model_kwargs = config.get("kwargs", {})
        # torch_dtype / device_map are model-loading options; the tokenizer or
        # processor only needs to know where the cache lives.
        companion_kwargs = {}
        if "cache_dir" in model_kwargs:
            companion_kwargs["cache_dir"] = model_kwargs["cache_dir"]

        try:
            logger.info(f"Loading {model_type} model...")
            companion = config[companion_key].from_pretrained(
                config["model_id"],
                **companion_kwargs
            )
            model = config["model"].from_pretrained(
                config["model_id"],
                **model_kwargs
            )
        except Exception as e:
            raise ModelError(f"Error loading {model_type} model: {str(e)}") from e

        self.loaded_models[model_type] = (model, companion)
        logger.info(f"{model_type} model loaded successfully")
        return self.loaded_models[model_type]

    def unload_model(self, model_type: str):
        """Drop the cached pair for ``model_type`` and clear the CUDA cache.

        Does nothing when the model type is not currently loaded.
        """
        if model_type in self.loaded_models:
            del self.loaded_models[model_type]
            torch.cuda.empty_cache()
            logger.info(f"{model_type} model unloaded")
class MultimodalRAG:
    """Multimodal retrieval-augmented generation for Gradio interfaces.

    A description (and optionally a screenshot) is embedded, the closest
    stored template is looked up in a FAISS index, and StarCoder is prompted
    to customise that template.
    """

    def __init__(self):
        """Load the text encoder, create the index and load stored embeddings.

        Raises:
            ModelError: If any part of the initialisation fails.
        """
        try:
            self.model_manager = ModelManager()
            # Load text encoder
            self.text_encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            # Initialize vector store
            self.vector_store = self._initialize_vector_store()
            # Template name -> embedding, as stored on disk
            self.template_embeddings = {}
            # Template names in the order their vectors were added to the index
            self._indexed_template_names: List[str] = []
            self._initialize_template_embeddings()
        except Exception as e:
            raise ModelError(f"Error initializing MultimodalRAG: {str(e)}") from e

    def _initialize_vector_store(self) -> faiss.IndexFlatL2:
        """Create an empty L2 index sized for image + text embeddings."""
        # NOTE(review): 768 assumes the width of the BLIP image features;
        # confirm against model.get_image_features output for this checkpoint.
        combined_dim = 768 + 384  # BLIP (768) + text (384)
        return faiss.IndexFlatL2(combined_dim)

    def _initialize_template_embeddings(self):
        """Load stored embeddings from disk and add them to the index.

        Errors are logged and leave the system with no stored embeddings.
        """
        try:
            template_path = TEMPLATE_DIR / "template_embeddings.npz"
            if not template_path.exists():
                return
            # NpzFile holds an open file handle; read everything and close it.
            with np.load(template_path) as data:
                self.template_embeddings = {
                    name: data[name] for name in data.files
                }
            self._rebuild_index()
        except Exception as e:
            logger.error(f"Error loading template embeddings: {e}")

    def _rebuild_index(self):
        """Refill the vector store from ``template_embeddings``.

        Embeddings whose flattened length differs from the index dimension
        are skipped with a warning, so index positions always correspond to
        ``_indexed_template_names``.
        """
        self.vector_store.reset()
        self._indexed_template_names = []
        index_dim = self.vector_store.d
        for name, embedding in self.template_embeddings.items():
            vector = np.asarray(embedding, dtype=np.float32).reshape(-1)
            if vector.shape[0] != index_dim:
                logger.warning(
                    f"Skipping template embedding {name}: "
                    f"expected {index_dim} values, got {vector.shape[0]}"
                )
                continue
            self.vector_store.add(vector.reshape(1, -1))
            self._indexed_template_names.append(name)

    def save_template_embeddings(self):
        """Write ``template_embeddings`` to disk; errors are only logged."""
        try:
            template_path = TEMPLATE_DIR / "template_embeddings.npz"
            np.savez(
                template_path,
                **self.template_embeddings
            )
        except Exception as e:
            logger.error(f"Error saving template embeddings: {e}")

    def encode_image(self, image: Image.Image) -> np.ndarray:
        """Return the BLIP image features for ``image`` as a NumPy array.

        Raises:
            ModelError: If the model cannot be loaded or the encoding fails.
        """
        try:
            model, processor = self.model_manager.load_model("image_processor")
            inputs = processor(images=image, return_tensors="pt")
            with torch.no_grad():
                image_features = model.get_image_features(**inputs)
            # Move to host memory first; .numpy() is not available on CUDA tensors.
            return image_features.detach().cpu().numpy()
        except Exception as e:
            raise ModelError(f"Error encoding image: {str(e)}") from e

    def encode_text(self, text: str) -> np.ndarray:
        """Return the sentence-transformers embedding of ``text``.

        Raises:
            ModelError: If the encoder raises.
        """
        try:
            return self.text_encoder.encode(text)
        except Exception as e:
            raise ModelError(f"Error encoding text: {str(e)}") from e

    def generate_code(self, description: str, template_code: str) -> str:
        """Prompt StarCoder to customise ``template_code`` for ``description``.

        Args:
            description: What the interface should do.
            template_code: Source of the template to start from.

        Returns:
            The generated code after clean-up and formatting.

        Raises:
            GenerationError: If loading the model or generating fails.
        """
        try:
            model, tokenizer = self.model_manager.load_model("code_generator")
            prompt = f"""
# Task: Generate a Gradio interface based on the description
# Description: {description}
# Base template:
{template_code}
# Generate a customized version of the template that implements the description.
# Only output the Python code, no explanations.
```python
"""
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    inputs.input_ids,
                    max_length=2048,
                    temperature=0.2,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )
            # The decoded text still contains the prompt; the clean-up step
            # relies on the prompt's opening code fence to find the answer.
            generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Clean and format the generated code
            generated_code = self._clean_generated_code(generated_code)
            return CodeFormatter.format_code(generated_code)
        except Exception as e:
            raise GenerationError(f"Error generating code: {str(e)}") from e

    def _clean_generated_code(self, code: str) -> str:
        """Extract the fenced code from model output and strip unsafe imports."""
        # Extract code between triple backticks if present
        if "```python" in code:
            code = code.split("```python")[1].split("```")[0]
        elif "```" in code:
            code = code.split("```")[1].split("```")[0]
        code = code.strip()
        return CodeFormatter.cleanup_code(code)

    def find_similar_template(
        self,
        screenshot: Optional[Image.Image],
        description: str
    ) -> Tuple[str, Template]:
        """Find the stored template closest to the screenshot and description.

        The query vector is the image embedding followed by the text
        embedding. Without a screenshot the image part is filled with zeros so
        that the query still has the index dimension.

        Args:
            screenshot: Optional image of the desired interface.
            description: Text description of the desired interface.

        Returns:
            ``(template_name, template)`` for the nearest stored template.

        Raises:
            ModelError: If encoding fails, the query does not match the index
                dimension, the index is empty, or the template file cannot be
                loaded.
        """
        try:
            index_dim = self.vector_store.d
            text_embedding = np.asarray(
                self.encode_text(description), dtype=np.float32
            ).reshape(-1)
            if screenshot is not None:
                img_embedding = np.asarray(
                    self.encode_image(screenshot), dtype=np.float32
                ).reshape(-1)
            else:
                # No image: zero-fill the image part to keep the query width.
                img_embedding = np.zeros(
                    max(index_dim - text_embedding.shape[0], 0),
                    dtype=np.float32
                )
            query_embedding = np.concatenate([img_embedding, text_embedding])
            if query_embedding.shape[0] != index_dim:
                raise ModelError(
                    f"Query embedding has {query_embedding.shape[0]} values, "
                    f"index expects {index_dim}"
                )
            if self.vector_store.ntotal == 0:
                raise ModelError("No template embeddings are indexed")

            # Search in vector store
            D, I = self.vector_store.search(
                query_embedding.reshape(1, -1),
                k=1
            )

            # Map the index position back to a template name. If vectors were
            # added to the index outside _rebuild_index, fall back to the
            # insertion order of template_embeddings.
            template_names = self._indexed_template_names
            if len(template_names) != self.vector_store.ntotal:
                template_names = list(self.template_embeddings.keys())
            position = int(I[0][0])
            if not 0 <= position < len(template_names):
                raise ModelError("Vector search returned no matching template")
            template_name = template_names[position]

            # Load template
            template_path = TEMPLATE_DIR / f"{template_name}.json"
            with open(template_path, 'r', encoding='utf-8') as f:
                template_data = json.load(f)
            template = Template(**template_data)
            return template_name, template
        except Exception as e:
            raise ModelError(f"Error finding similar template: {str(e)}") from e

    def generate_interface(
        self,
        screenshot: Optional[Image.Image],
        description: str
    ) -> str:
        """Pick the nearest template and generate customised code from it.

        Raises:
            GenerationError: If template lookup or code generation fails.
        """
        try:
            # Find similar template
            template_name, template = self.find_similar_template(
                screenshot,
                description
            )
            # Generate customized code
            custom_code = self.generate_code(
                description,
                template.code
            )
            return custom_code
        except Exception as e:
            raise GenerationError(f"Error generating interface: {str(e)}") from e

    def cleanup(self):
        """Save embeddings, unload both models and clear the CUDA cache.

        Errors are logged and not propagated.
        """
        try:
            # Save template embeddings
            self.save_template_embeddings()
            # Unload models
            self.model_manager.unload_model("code_generator")
            self.model_manager.unload_model("image_processor")
            # Clear CUDA cache
            torch.cuda.empty_cache()
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")
class TemplateManager:
    """Keeps the built-in and user-saved Gradio interface templates.

    Templates are held in ``self.templates`` keyed by name. Custom templates
    are stored as ``<name>.json`` files in ``template_dir``.
    """

    def __init__(self, template_dir: Path = TEMPLATE_DIR):
        """Create the template directory if needed and load all templates.

        Args:
            template_dir: Directory holding the custom template JSON files.
        """
        self.template_dir = template_dir
        self.template_dir.mkdir(parents=True, exist_ok=True)
        self.templates: Dict[str, Template] = {}
        self.load_templates()

    def load_templates(self):
        """Load the built-in templates, then every ``*.json`` file on disk.

        A file that cannot be read or converted is logged and skipped. A
        custom template with the same name as a built-in one replaces it.
        """
        try:
            # Load built-in templates
            self.templates.update(self._get_builtin_templates())
            # Load custom templates
            for template_file in self.template_dir.glob("*.json"):
                try:
                    with open(template_file, 'r', encoding='utf-8') as f:
                        template_data = json.load(f)
                    name = template_file.stem
                    self.templates[name] = Template(**template_data)
                except Exception as e:
                    logger.error(f"Error loading template {template_file}: {e}")
        except Exception as e:
            logger.error(f"Error loading templates: {e}")

    def _get_builtin_templates(self) -> Dict[str, Template]:
        """Return the templates that ship with the builder, keyed by name."""
        return {
            "image_classifier": Template(
                code="""
import gradio as gr
import numpy as np
from PIL import Image

def classify_image(image):
    if image is None:
        return {"error": 1.0}
    # Add classification logic here
    return {"class1": 0.8, "class2": 0.2}

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Image Classifier")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            classify_btn = gr.Button("Classify")
        with gr.Column():
            output_labels = gr.Label()
    classify_btn.click(
        fn=classify_image,
        inputs=input_image,
        outputs=output_labels
    )

if __name__ == "__main__":
    demo.launch()
""",
                description="Basic image classification interface",
                components=["Image", "Button", "Label"],
                metadata={"category": "computer_vision"}
            ),
            "text_analyzer": Template(
                code="""
import gradio as gr
import numpy as np

def analyze_text(text, options):
    if not text:
        return "Please enter some text"
    results = []
    if "word_count" in options:
        results.append(f"Word count: {len(text.split())}")
    if "char_count" in options:
        results.append(f"Character count: {len(text)}")
    if "sentiment" in options:
        # Add sentiment analysis logic here
        results.append("Sentiment: Neutral")
    return "\\n".join(results)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Text Analysis Tool")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text to analyze...",
                lines=5
            )
            options = gr.CheckboxGroup(
                choices=["word_count", "char_count", "sentiment"],
                label="Analysis Options",
                value=["word_count"]
            )
            analyze_btn = gr.Button("Analyze")
        with gr.Column():
            output_text = gr.Textbox(
                label="Analysis Results",
                lines=5
            )
    analyze_btn.click(
        fn=analyze_text,
        inputs=[input_text, options],
        outputs=output_text
    )

if __name__ == "__main__":
    demo.launch()
""",
                description="Text analysis interface with multiple options",
                components=["Textbox", "CheckboxGroup", "Button"],
                metadata={"category": "nlp"}
            )
        }

    def save_template(self, name: str, template: Template) -> bool:
        """Store ``template`` as ``<name>.json`` and register it in memory.

        The template is serialized to a string before the file is opened, so
        a template that cannot be converted to JSON does not leave a
        truncated file behind.

        Args:
            name: Template name, also used as the file stem.
            template: Template to store.

        Returns:
            True on success, False if serialization or writing failed (the
            error is logged).
        """
        try:
            template_path = self.template_dir / f"{name}.json"
            template_dict = {
                "code": template.code,
                "description": template.description,
                "components": template.components,
                "metadata": template.metadata,
                "version": template.version
            }
            serialized = json.dumps(template_dict, indent=4)
            with open(template_path, 'w', encoding='utf-8') as f:
                f.write(serialized)
            self.templates[name] = template
            return True
        except Exception as e:
            logger.error(f"Error saving template {name}: {e}")
            return False

    def get_template(self, name: str) -> Optional[Template]:
        """Return the template registered under ``name``, or None."""
        return self.templates.get(name)

    def list_templates(self, category: Optional[str] = None) -> List[Dict[str, Any]]:
        """Summarise the known templates.

        Args:
            category: When given, only templates whose metadata ``category``
                equals this value are returned.

        Returns:
            One dictionary per template with the keys ``name``,
            ``description``, ``components`` and ``category`` (``"general"``
            when the template has no category).
        """
        templates_list = []
        for name, template in self.templates.items():
            if category and template.metadata.get("category") != category:
                continue
            templates_list.append({
                "name": name,
                "description": template.description,
                "components": template.components,
                "category": template.metadata.get("category", "general")
            })
        return templates_list
class InterfaceAnalyzer:
    """Static analysis of Gradio interface source code.

    All methods are static and are called on the class itself.
    """

    @staticmethod
    def extract_components(code: str) -> "List[ComponentConfig]":
        """Collect the ``gr.<Component>(...)`` calls found in ``code``.

        Only calls whose attribute name, upper-cased, is a ``ComponentType``
        member are reported. Keyword arguments that are Python literals are
        recorded; other keyword values are stored as None.

        Args:
            code: Python source text.

        Returns:
            One ``ComponentConfig`` per recognised call. On any error the
            error is logged and the components found so far are returned.
        """
        components = []
        try:
            tree = ast.parse(code)
            for node in ast.walk(tree):
                component_type = InterfaceAnalyzer._gradio_call_name(node)
                if component_type is None:
                    continue
                member_name = component_type.upper()
                if member_name not in ComponentType.__members__:
                    continue

                # Extract component properties
                properties = {}
                label = None
                events = []
                # Get properties from keywords
                for keyword in node.keywords:
                    if keyword.arg is None:
                        # ``**kwargs`` expansion: no property name to record
                        continue
                    value = InterfaceAnalyzer._literal_or_none(keyword.value)
                    if keyword.arg == 'label':
                        label = value
                    else:
                        properties[keyword.arg] = value

                # Look for event handlers
                parent = InterfaceAnalyzer._find_parent_assign(tree, node)
                if parent:
                    events = InterfaceAnalyzer._find_component_events(tree, parent)

                components.append(ComponentConfig(
                    type=ComponentType[member_name],
                    label=label or component_type,
                    properties=properties,
                    events=events
                ))
        except Exception as e:
            logger.error(f"Error extracting components: {e}")
        return components

    @staticmethod
    def _gradio_call_name(node: ast.AST) -> Optional[str]:
        """Return ``X`` when ``node`` is a call of the form ``gr.X(...)``."""
        if not isinstance(node, ast.Call):
            return None
        func = node.func
        if not isinstance(func, ast.Attribute):
            return None
        if isinstance(func.value, ast.Name) and func.value.id == 'gr':
            return func.attr
        return None

    @staticmethod
    def _literal_or_none(node: ast.AST) -> Any:
        """Evaluate ``node`` as a Python literal, or return None if it is not one."""
        try:
            return ast.literal_eval(node)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None

    @staticmethod
    def _find_parent_assign(tree: ast.AST, node: ast.Call) -> Optional[ast.AST]:
        """Return the assignment whose right-hand side contains ``node``."""
        for potential_parent in ast.walk(tree):
            if isinstance(potential_parent, ast.Assign):
                for child in ast.walk(potential_parent.value):
                    if child is node:
                        return potential_parent
        return None

    @staticmethod
    def _find_component_events(tree: ast.AST, assign_node: ast.Assign) -> List[str]:
        """List the methods called on the variable assigned by ``assign_node``.

        Only a plain variable target (``btn = ...``) is supported; for
        attribute, subscript or tuple targets an empty list is returned.
        """
        target = assign_node.targets[0]
        if not isinstance(target, ast.Name):
            return []
        component_name = target.id
        events = []
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            func = node.func
            if (
                isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)
                and func.value.id == component_name
            ):
                events.append(func.attr)
        return events

    @staticmethod
    def analyze_interface_structure(code: str) -> Dict[str, Any]:
        """Summarise the components, functions and imports of ``code``.

        Args:
            code: Python source text.

        Returns:
            A dictionary with the keys ``components`` (type, label,
            properties, events), ``functions`` (name, positional argument
            names, return annotation) and ``dependencies`` (sorted imported
            module names). An empty dictionary is returned, and the error
            logged, when the analysis fails.
        """
        try:
            # Extract components
            components = InterfaceAnalyzer.extract_components(code)

            # Analyze functions
            functions = []
            tree = ast.parse(code)
            for node in ast.walk(tree):
                if isinstance(node, ast.FunctionDef):
                    functions.append({
                        "name": node.name,
                        "args": [arg.arg for arg in node.args.args],
                        "returns": InterfaceAnalyzer._get_return_type(node)
                    })

            # Analyze dependencies
            dependencies = set()
            for node in ast.walk(tree):
                if isinstance(node, ast.Import):
                    for name in node.names:
                        dependencies.add(name.name)
                elif isinstance(node, ast.ImportFrom):
                    if node.module:
                        dependencies.add(node.module)

            return {
                "components": [
                    {
                        "type": comp.type.value,
                        "label": comp.label,
                        "properties": comp.properties,
                        "events": comp.events
                    }
                    for comp in components
                ],
                "functions": functions,
                # Sorted so the result does not depend on set iteration order
                "dependencies": sorted(dependencies)
            }
        except Exception as e:
            logger.error(f"Error analyzing interface: {e}")
            return {}

    @staticmethod
    def _get_return_type(node: ast.FunctionDef) -> str:
        """Return the function's return annotation as source text, or "Any"."""
        if node.returns:
            return ast.unparse(node.returns)
        return "Any"
class PreviewManager:
    """Runs generated interface code in a child process for previewing.

    At most one preview process exists at a time; starting a new preview
    stops the previous one.
    """

    def __init__(self):
        """Set up an idle manager."""
        self.current_process: Optional[subprocess.Popen] = None
        # The builder UI itself is launched on Gradio's default port, so the
        # preview gets the next port; it is passed to the child process via
        # the GRADIO_SERVER_PORT environment variable.
        self.preview_port = DEFAULT_PORT + 1
        self._lock = threading.Lock()

    def start_preview(self, code: str) -> Tuple[bool, str]:
        """Write ``code`` to a temporary module and run it in a new process.

        Args:
            code: Python source of the interface to preview.

        Returns:
            ``(True, url)`` when the process is still running two seconds
            after start, otherwise ``(False, error_message)``.
        """
        import os
        import sys

        with self._lock:
            try:
                self.stop_preview()
                # Create temporary module
                module_path = create_temp_module(code)
                # stderr goes to a file rather than a pipe: nobody reads a
                # pipe while the preview runs, and a full pipe would block
                # the child process.
                log_path = Path(module_path).with_suffix(".log")
                env = dict(os.environ, GRADIO_SERVER_PORT=str(self.preview_port))
                # Start new process with the interpreter running this program
                with open(log_path, "wb") as log_file:
                    self.current_process = subprocess.Popen(
                        [sys.executable or 'python', module_path],
                        stdout=subprocess.DEVNULL,
                        stderr=log_file,
                        env=env
                    )
                # Wait for server to start
                time.sleep(2)
                # Check if process is still running
                if self.current_process.poll() is not None:
                    self.current_process = None
                    error_msg = log_path.read_text(encoding='utf-8', errors='replace')
                    raise RuntimeError(f"Preview failed to start: {error_msg}")
                return True, f"http://localhost:{self.preview_port}"
            except Exception as e:
                return False, str(e)

    def stop_preview(self):
        """Terminate the current preview process, if any.

        The process gets five seconds to exit after ``terminate()`` and is
        killed and reaped otherwise. Does not take the lock, because
        ``start_preview`` calls it while holding the lock.
        """
        if self.current_process:
            self.current_process.terminate()
            try:
                self.current_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.current_process.kill()
                # Reap the killed process so it does not linger as a zombie
                self.current_process.wait()
            self.current_process = None

    def cleanup(self):
        """Stop the preview and delete temporary modules and their logs."""
        self.stop_preview()
        # Clean up temporary files
        for pattern in ("*.py", "*.log"):
            for temp_file in TEMP_DIR.glob(pattern):
                try:
                    temp_file.unlink()
                except Exception as e:
                    logger.error(f"Error deleting temporary file {temp_file}: {e}")
class GradioInterface:
    """Main application: the Gradio UI for designing, previewing and
    analysing generated interfaces.
    """

    # Import names whose PyPI distribution is named differently.
    _PACKAGE_ALIASES = {
        "PIL": "Pillow",
        "cv2": "opencv-python",
        "sklearn": "scikit-learn",
        "yaml": "PyYAML",
    }

    def __init__(self):
        """Create the RAG system, the managers and the UI.

        Raises:
            Exception: Any initialisation error is logged and re-raised.
        """
        try:
            self.rag_system = MultimodalRAG()
            self.template_manager = TemplateManager()
            self.preview_manager = PreviewManager()
            self.current_code = ""
            self.error_log = []
            self.interface = self._create_interface()
        except Exception as e:
            logger.error(f"Error initializing GradioInterface: {str(e)}")
            raise

    def _create_interface(self) -> gr.Blocks:
        """Build the three-tab UI (Design, Preview, Analysis) and wire events."""
        with gr.Blocks(theme=gr.themes.Soft()) as interface:
            gr.Markdown("# 🚀 Gradio Interface Builder")
            with gr.Tabs():
                # Design Tab
                with gr.Tab("Design"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            # Input Section
                            gr.Markdown("## 📝 Design Your Interface")
                            description = gr.Textbox(
                                label="Description",
                                placeholder="Describe the interface you want to create...",
                                lines=3
                            )
                            screenshot = gr.Image(
                                label="Screenshot (optional)",
                                type="pil"
                            )
                            with gr.Row():
                                generate_btn = gr.Button("🎨 Generate Interface", variant="primary")
                                clear_btn = gr.Button("🗑️ Clear")
                            # Template Selection
                            gr.Markdown("### 📚 Templates")
                            template_dropdown = gr.Dropdown(
                                choices=self._get_template_choices(),
                                label="Base Template",
                                interactive=True
                            )
                        with gr.Column(scale=3):
                            # Code Editor
                            code_editor = gr.Code(
                                label="Generated Code",
                                language="python",
                                interactive=True
                            )
                            with gr.Row():
                                validate_btn = gr.Button("✅ Validate")
                                format_btn = gr.Button("📋 Format")
                                save_template_btn = gr.Button("💾 Save as Template")
                            validation_output = gr.Markdown()
                # Preview Tab
                with gr.Tab("Preview"):
                    with gr.Row():
                        preview_btn = gr.Button("▶️ Start Preview", variant="primary")
                        stop_preview_btn = gr.Button("⏹️ Stop Preview")
                    preview_frame = gr.HTML(
                        label="Preview",
                        value="<p>Click 'Start Preview' to see your interface</p>"
                    )
                    preview_status = gr.Markdown()
                # Analysis Tab
                with gr.Tab("Analysis"):
                    analyze_btn = gr.Button("🔍 Analyze Interface")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### 🧩 Components")
                            components_json = gr.JSON(label="Detected Components")
                        with gr.Column():
                            gr.Markdown("### 🔄 Functions")
                            functions_json = gr.JSON(label="Interface Functions")
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### 📦 Dependencies")
                            dependencies_json = gr.JSON(label="Required Dependencies")
                        with gr.Column():
                            gr.Markdown("### 📄 Requirements")
                            requirements_text = gr.Textbox(
                                label="requirements.txt",
                                lines=10
                            )

            # Event handlers
            generate_btn.click(
                fn=self._generate_interface,
                inputs=[description, screenshot, template_dropdown],
                outputs=[code_editor, validation_output]
            )
            clear_btn.click(
                fn=self._clear_interface,
                outputs=[description, screenshot, code_editor, validation_output]
            )
            validate_btn.click(
                fn=self._validate_code,
                inputs=[code_editor],
                outputs=[validation_output]
            )
            format_btn.click(
                fn=self._format_code,
                inputs=[code_editor],
                outputs=[code_editor]
            )
            save_template_btn.click(
                fn=self._save_as_template,
                inputs=[code_editor, description],
                outputs=[template_dropdown, validation_output]
            )
            preview_btn.click(
                fn=self._start_preview,
                inputs=[code_editor],
                outputs=[preview_frame, preview_status]
            )
            stop_preview_btn.click(
                fn=self._stop_preview,
                outputs=[preview_frame, preview_status]
            )
            analyze_btn.click(
                fn=self._analyze_interface,
                inputs=[code_editor],
                outputs=[
                    components_json,
                    functions_json,
                    dependencies_json,
                    requirements_text
                ]
            )
            # Load the selected template's code into the editor
            template_dropdown.change(
                fn=self._load_template,
                inputs=[template_dropdown],
                outputs=[code_editor]
            )
        return interface

    def _get_template_choices(self) -> List[str]:
        """Return the dropdown choices: an empty entry plus all template names."""
        templates = self.template_manager.list_templates()
        return [""] + [t["name"] for t in templates]

    def _generate_interface(
        self,
        description: str,
        screenshot: Optional[Image.Image],
        template_name: str
    ) -> Tuple[str, str]:
        """Generate interface code from the description.

        With a template selected, that template is customised; otherwise the
        RAG system picks a template from the screenshot and description.

        Returns:
            ``(code, status_message)``; ``code`` is empty on failure.
        """
        try:
            if template_name:
                template = self.template_manager.get_template(template_name)
                if template:
                    code = self.rag_system.generate_code(description, template.code)
                else:
                    raise ValueError(f"Template {template_name} not found")
            else:
                code = self.rag_system.generate_interface(screenshot, description)
            self.current_code = code
            return code, "✅ Code generated successfully"
        except Exception as e:
            error_msg = f"❌ Error generating interface: {str(e)}"
            logger.error(error_msg)
            return "", error_msg

    def _clear_interface(self) -> Tuple[str, None, str, str]:
        """Reset the description, screenshot, editor and status outputs."""
        self.current_code = ""
        return "", None, "", ""

    def _validate_code(self, code: str) -> str:
        """Return the validation result for ``code`` with a status icon."""
        is_valid, message = validate_code(code)
        return f"{'✅' if is_valid else '❌'} {message}"

    def _format_code(self, code: str) -> str:
        """Return ``code`` formatted, or unchanged if formatting fails."""
        try:
            return CodeFormatter.format_code(code)
        except Exception as e:
            logger.error(f"Error formatting code: {e}")
            return code

    def _save_as_template(self, code: str, description: str) -> Tuple[Any, str]:
        """Save the editor content as a new custom template.

        The template is named ``custom_template`` or, if taken,
        ``custom_template_<n>`` with the first free number. Its component
        list holds the distinct component type names detected in ``code``,
        which keeps the template JSON-serializable.

        Returns:
            A dropdown update carrying the refreshed choices, and a status
            message.
        """
        try:
            # Generate template name
            base_name = "custom_template"
            counter = 1
            name = base_name
            while self.template_manager.get_template(name):
                name = f"{base_name}_{counter}"
                counter += 1

            # Component type names in first-seen order, without duplicates
            detected = InterfaceAnalyzer.extract_components(code)
            component_names = list(dict.fromkeys(comp.type.value for comp in detected))

            # Create template
            template = Template(
                code=code,
                description=description,
                components=component_names,
                metadata={"category": "custom"}
            )

            # Save template
            if self.template_manager.save_template(name, template):
                return (
                    gr.update(choices=self._get_template_choices()),
                    f"✅ Template saved as {name}"
                )
            raise Exception("Failed to save template")
        except Exception as e:
            error_msg = f"❌ Error saving template: {str(e)}"
            logger.error(error_msg)
            return gr.update(choices=self._get_template_choices()), error_msg

    def _start_preview(self, code: str) -> Tuple[str, str]:
        """Start the preview and return the iframe HTML and a status message."""
        success, result = self.preview_manager.start_preview(code)
        if success:
            return f'<iframe src="{result}" width="100%" height="600px"></iframe>', "✅ Preview started"
        else:
            return "", f"❌ Preview failed: {result}"

    def _stop_preview(self) -> Tuple[str, str]:
        """Stop the preview and return placeholder HTML and a status message."""
        self.preview_manager.stop_preview()
        return "<p>Preview stopped</p>", "✅ Preview stopped"

    def _load_template(self, template_name: str) -> str:
        """Return the code of the named template, or "" if none is selected or found."""
        if not template_name:
            return ""
        template = self.template_manager.get_template(template_name)
        if template:
            return template.code
        return ""

    @staticmethod
    def _generate_requirements(dependencies: List[str]) -> str:
        """Build requirements.txt content from imported module names.

        Each dotted name is reduced to its top-level package, standard
        library modules are left out (only when the running interpreter
        provides ``sys.stdlib_module_names``), and known import-name/package
        mismatches are translated through ``_PACKAGE_ALIASES``.

        Returns:
            One package name per line, sorted case-insensitively.
        """
        import sys

        stdlib_modules = getattr(sys, "stdlib_module_names", frozenset())
        packages = set()
        for dependency in dependencies:
            root = dependency.split(".")[0]
            if not root or root in stdlib_modules:
                continue
            packages.add(GradioInterface._PACKAGE_ALIASES.get(root, root))
        return "\n".join(sorted(packages, key=str.lower))

    def _analyze_interface(self, code: str) -> Tuple[Dict, Dict, Dict, str]:
        """Analyse ``code`` for the Analysis tab.

        Returns:
            ``(components, functions, {"dependencies": [...]}, requirements)``;
            empty values when the analysis fails.
        """
        try:
            analysis = InterfaceAnalyzer.analyze_interface_structure(code)
            # Generate requirements.txt
            dependencies = analysis.get("dependencies", [])
            requirements = self._generate_requirements(dependencies)
            return (
                analysis.get("components", {}),
                analysis.get("functions", {}),
                {"dependencies": dependencies},
                requirements
            )
        except Exception as e:
            logger.error(f"Error analyzing interface: {e}")
            return {}, {}, {}, ""

    def launch(self, **kwargs):
        """Launch the UI, passing ``kwargs`` to ``gr.Blocks.launch``.

        Resources are cleaned up when ``launch`` returns or raises.
        """
        try:
            self.interface.launch(**kwargs)
        finally:
            self.cleanup()

    def cleanup(self):
        """Release the preview manager and the RAG system.

        Each resource is cleaned up independently so that a failure in one
        does not skip the other; errors are logged.
        """
        for attribute in ("preview_manager", "rag_system"):
            resource = getattr(self, attribute, None)
            if resource is None:
                continue
            try:
                resource.cleanup()
            except Exception as e:
                logger.error(f"Error during cleanup: {e}")
def main():
    """Configure logging, build the application and launch it.

    The UI is launched with a public share link, debug mode, and bound to all
    network interfaces. Any exception is logged and then re-raised.
    """
    launch_options = {
        "share": True,
        "debug": True,
        "server_name": "0.0.0.0",
    }
    try:
        # Set up logging
        logging.basicConfig(
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            level=logging.INFO,
        )
        # Create and launch interface
        app = GradioInterface()
        app.launch(**launch_options)
    except Exception as exc:
        logger.error(f"Application error: {exc}")
        raise


if __name__ == "__main__":
    main()