Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import logging | |
import yaml | |
import os | |
import json | |
import jwt | |
import redis | |
import sqlite3 | |
from datetime import datetime, timedelta | |
from pathlib import Path | |
from transformers import AutoTokenizer, T5ForConditionalGeneration | |
import networkx as nx | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from PIL import Image | |
import io | |
import traceback | |
from typing import Tuple, Optional, Dict, List, Union | |
import colorama | |
from colorama import Fore, Style | |
from abc import ABC, abstractmethod | |
from dataclasses import dataclass | |
import plotly.graph_objects as go | |
import hashlib | |
import asyncio | |
import aiohttp | |
from fastapi import FastAPI, HTTPException, Depends, status | |
from fastapi.security import OAuth2PasswordBearer | |
from pydantic import BaseModel, EmailStr | |
import uvicorn | |
# Advanced Configuration Models | |
@dataclass
class ModelConfig:
    """Text-generation settings for the component-extraction model.

    Instances are built from the ``model`` section of the YAML config with
    ``ModelConfig(**data)``, which needs the dataclass-generated ``__init__``;
    without ``@dataclass`` that call raises ``TypeError``.
    """

    name: str  # model id or local path handed to ``from_pretrained``
    max_length: int  # truncation length for the tokenized prompt
    num_beams: int  # beam count passed to ``generate()``
    temperature: float  # passed to ``generate()``
    top_k: int  # passed to ``generate()``
    top_p: float  # passed to ``generate()``
@dataclass
class StyleConfig:
    """Visual settings for one named diagram style (an entry under ``styles``).

    Built with ``StyleConfig(**style_data)``, which requires ``@dataclass``.
    """

    node_color: str  # node colour, e.g. '#1f77b4'
    edge_color: str  # edge colour, e.g. '#7f7f7f'
    node_size: int  # matplotlib node size; the Plotly strategy divides it by 100
    font_size: int  # label font size
    layout: str  # NetworkDiagram: "spring", "circular", anything else -> Kamada-Kawai
@dataclass
class OutputConfig:
    """Output-image settings from the ``output`` config section.

    Built with ``OutputConfig(**output_config_data)``, which requires ``@dataclass``.
    """

    width: int
    height: int
    dpi: int
    format: str  # image format name, e.g. 'png'
    quality: int
    # NOTE(review): none of these fields is read by the rendering code in this
    # file (NetworkDiagram hard-codes dpi=300); wire them up or drop them.
# Abstract Base Classes for Extensibility | |
class DiagramStrategy(ABC):
    """Interface for turning an ordered list of components into an image.

    Concrete strategies must override :meth:`create_diagram`; the base class
    itself cannot be instantiated.
    """

    @abstractmethod
    def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
        """Render *components* with *style* and return the result as a PIL image."""
class NetworkDiagram(DiagramStrategy):
    """Draw components as a directed chain with networkx + matplotlib."""

    def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
        """Render ``components`` as a chain of directed edges and return a PNG image.

        Each consecutive pair ``components[i] -> components[i + 1]`` becomes an
        edge. Repeated names share one node, so repeats create cycles.

        Args:
            components: Ordered node labels; a single label yields a lone node.
            style: Colours, sizes and layout (``"spring"``, ``"circular"``,
                anything else uses the Kamada-Kawai layout).

        Returns:
            The rendered diagram (PNG at 300 dpi) as a PIL image.
        """
        graph = nx.DiGraph()
        # Add nodes explicitly: with only add_edge, a one-item list drew nothing.
        graph.add_nodes_from(components)
        graph.add_edges_from(zip(components, components[1:]))

        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            if style.layout == "spring":
                pos = nx.spring_layout(graph)
            elif style.layout == "circular":
                pos = nx.circular_layout(graph)
            else:
                pos = nx.kamada_kawai_layout(graph)
            nx.draw_networkx_nodes(graph, pos, ax=ax,
                                   node_color=style.node_color,
                                   node_size=style.node_size)
            nx.draw_networkx_edges(graph, pos, ax=ax,
                                   edge_color=style.edge_color,
                                   arrows=True)
            nx.draw_networkx_labels(graph, pos, ax=ax,
                                    font_size=style.font_size)
            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=300)
        finally:
            # Release the figure even if layout/drawing fails; pyplot keeps
            # every open figure alive until it is closed.
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)
class PlotlyDiagram(DiagramStrategy):
    """Draw components as a Plotly scatter graph exported to PNG."""

    def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
        """Render ``components`` as a chain and return it as a PNG image.

        Uses the same node/edge construction and layout selection as
        NetworkDiagram (``"spring"``, ``"circular"``, otherwise Kamada-Kawai).
        ``Figure.to_image`` needs the ``kaleido`` package installed.

        Args:
            components: Ordered node labels; a single label yields a lone node.
            style: Colours, node size (divided by 100 for the marker size) and layout.

        Returns:
            The rendered diagram as a PIL image.
        """
        graph = nx.DiGraph()
        # Add nodes explicitly so a single component is still drawn.
        graph.add_nodes_from(components)
        graph.add_edges_from(zip(components, components[1:]))

        # Honour style.layout instead of always using spring.
        if style.layout == "spring":
            pos = nx.spring_layout(graph)
        elif style.layout == "circular":
            pos = nx.circular_layout(graph)
        else:
            pos = nx.kamada_kawai_layout(graph)

        edge_x, edge_y = [], []
        for src, dst in graph.edges():
            x0, y0 = pos[src]
            x1, y1 = pos[dst]
            # A None entry breaks the polyline so every edge is its own segment.
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        nodes = list(graph.nodes())
        node_x = [pos[node][0] for node in nodes]
        node_y = [pos[node][1] for node in nodes]

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=edge_x, y=edge_y,
                                 line=dict(width=0.5, color=style.edge_color),
                                 hoverinfo='none',
                                 mode='lines'))
        fig.add_trace(go.Scatter(x=node_x, y=node_y,
                                 mode='markers+text',
                                 marker=dict(size=style.node_size/100,
                                             color=style.node_color),
                                 text=nodes,
                                 textposition="bottom center"))
        fig.update_layout(showlegend=False,
                          hovermode='closest',
                          margin=dict(b=0, l=0, r=0, t=0))
        img_bytes = fig.to_image(format="png")  # Requires kaleido
        return Image.open(io.BytesIO(img_bytes))
# Database Manager | |
class DatabaseManager:
    """SQLite storage for users and their saved diagrams.

    Constructing the manager opens (or creates) the database file at
    ``db_path`` and makes sure both tables exist.
    """

    def __init__(self, db_path: str = "diagrams.db"):
        self.conn = sqlite3.connect(db_path)
        self.create_tables()

    def create_tables(self):
        """Create the ``users`` and ``diagrams`` tables when missing, then commit.

        Both statements use ``IF NOT EXISTS``, so calling this again is harmless.
        """
        users_ddl = '''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY,
                username TEXT UNIQUE,
                email TEXT UNIQUE,
                password_hash TEXT,
                created_at TIMESTAMP
            )
        '''
        diagrams_ddl = '''
            CREATE TABLE IF NOT EXISTS diagrams (
                id INTEGER PRIMARY KEY,
                user_id INTEGER,
                title TEXT,
                description TEXT,
                created_at TIMESTAMP,
                image_path TEXT,
                FOREIGN KEY (user_id) REFERENCES users (id)
            )
        '''
        cur = self.conn.cursor()
        for ddl in (users_ddl, diagrams_ddl):
            cur.execute(ddl)
        self.conn.commit()
# Cache Manager using Redis | |
class CacheManager:
    """Small Redis-backed byte cache for rendered diagrams."""

    def __init__(self, redis_url: str = "redis://localhost"):
        # Build the client from a connection URL.
        self.redis_client = redis.from_url(redis_url)

    def get_cached_diagram(self, key: str) -> Optional[bytes]:
        """Return whatever the Redis client returns for *key* (``None`` when unset)."""
        client = self.redis_client
        cached = client.get(key)
        return cached

    def cache_diagram(self, key: str, diagram: bytes, expire: int = 3600):
        """Store *diagram* under *key*, expiring after *expire* seconds."""
        ttl_seconds = expire
        self.redis_client.set(key, diagram, ex=ttl_seconds)
# Advanced Diagram Generator | |
class AdvancedDiagramGenerator:
    """Turn free-text descriptions into rendered diagrams.

    A seq2seq model (``t5-small`` unless configured otherwise) rewrites the
    description into component names separated by ``,`` or ``->``. A registered
    :class:`DiagramStrategy` (``"network"`` or ``"plotly"``) then draws them as
    a chain.
    """

    # Fallbacks for any config section or key the YAML file omits. They are
    # always spread into fresh dicts, so loading never mutates them.
    _DEFAULT_MODEL = {
        'name': 't5-small',
        'max_length': 512,
        'num_beams': 4,
        'temperature': 1.0,
        'top_k': 50,
        'top_p': 0.9,
    }
    _DEFAULT_STYLE = {
        'node_color': '#1f77b4',
        'edge_color': '#7f7f7f',
        'node_size': 3000,
        'font_size': 12,
        'layout': 'spring',
    }
    _DEFAULT_OUTPUT = {
        'width': 1200,
        'height': 800,
        'dpi': 300,
        'format': 'png',
        'quality': 95,
    }

    def __init__(self, config_path: str = "config.yaml"):
        """Set up logging and config, load the model and register strategies.

        Args:
            config_path: YAML config file; built-in defaults are used if it is missing.

        Raises:
            RuntimeError: If the configuration cannot be loaded.
        """
        # Logging first: load_config reports through self.logger.
        self.setup_logging()
        self.load_config(config_path)
        # Loads tokenizer + model weights and moves the model to GPU when available.
        self.setup_components()
        # Opened lazily by save_to_database so start-up does not create diagrams.db.
        self.db = None
        self.strategies = {
            "network": NetworkDiagram(),
            "plotly": PlotlyDiagram()
        }

    def load_config(self, config_path: str) -> None:
        """Read the YAML config and build the typed config objects.

        Missing files, sections and keys fall back to the class defaults. An
        absent or empty ``styles`` section yields a single ``"network"`` style.
        The method sets ``model_config``, ``style_configs`` and ``output_config``.
        It also sets ``config``, the raw mapping with every section fully defaulted.

        Raises:
            RuntimeError: If the file cannot be read or parsed, its top level is
                not a mapping, or a config class rejects the resulting keywords.
        """
        try:
            if os.path.exists(config_path):
                with open(config_path, encoding="utf-8") as f:
                    # safe_load only builds plain data; an empty document gives None.
                    config_data = yaml.safe_load(f) or {}
            else:
                self.logger.warning(f"Config file not found at {config_path}. Using default configuration.")
                config_data = {}
            if not isinstance(config_data, dict):
                raise TypeError(
                    f"expected a mapping at the top level, got {type(config_data).__name__}"
                )

            # `or {}` also covers sections written as an explicit null in YAML.
            model_data = {**self._DEFAULT_MODEL, **(config_data.get('model') or {})}
            raw_styles = config_data.get('styles') or {'network': {}}
            style_data = {
                style_name: {**self._DEFAULT_STYLE, **(values or {})}
                for style_name, values in raw_styles.items()
            }
            output_data = {**self._DEFAULT_OUTPUT, **(config_data.get('output') or {})}

            self.model_config = ModelConfig(**model_data)
            self.style_configs = {
                style_name: StyleConfig(**values)
                for style_name, values in style_data.items()
            }
            self.output_config = OutputConfig(**output_data)

            # Keep the fully defaulted sections in the raw config as well.
            config_data.update(model=model_data, styles=style_data, output=output_data)
            self.config = config_data
            self.logger.info("Configuration loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading configuration: {str(e)}")
            raise RuntimeError(f"Failed to load configuration: {str(e)}") from e

    def setup_components(self):
        """Load the tokenizer and model named in the config and pick a device."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_config.name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_config.name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def setup_logging(self):
        """Configure root logging at INFO and create this module's logger."""
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    async def extract_components(self, text: str) -> List[str]:
        """Ask the model to rewrite *text* as a list of diagram components.

        The decoded output is split on ``,`` and ``->``; blank pieces are dropped.
        Tokenization and ``generate`` run synchronously, so this coroutine
        blocks its event loop while inference runs.
        """
        inputs = self.tokenizer(
            f"convert to diagram: {text}",
            return_tensors="pt",
            max_length=self.model_config.max_length,
            truncation=True
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                # Pass the mask explicitly instead of letting generate() infer it.
                attention_mask=inputs.attention_mask,
                max_length=150,
                num_beams=self.model_config.num_beams,
                # NOTE(review): without do_sample=True, transformers ignores (and
                # warns about) these sampling parameters during beam search.
                temperature=self.model_config.temperature,
                top_k=self.model_config.top_k,
                top_p=self.model_config.top_p
            )
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        pieces = (piece.strip() for piece in decoded.replace('->', ',').split(','))
        return [piece for piece in pieces if piece]

    def save_to_database(self, user_id: int, description: str, diagram: Image.Image):
        """Save *diagram* under ``diagrams/user_<id>/`` and record it in ``diagrams``.

        The :class:`DatabaseManager` (default ``diagrams.db``) is opened on first use.
        """
        if getattr(self, "db", None) is None:
            self.db = DatabaseManager()
        created_at = datetime.now()
        # Microseconds stop two saves in the same second from overwriting one file.
        image_path = f"diagrams/user_{user_id}/{created_at.strftime('%Y%m%d_%H%M%S_%f')}.png"
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        diagram.save(image_path)
        self.db.conn.execute(
            '''
            INSERT INTO diagrams (user_id, description, created_at, image_path)
            VALUES (?, ?, ?, ?)
            ''',
            # isoformat(" ") is the text sqlite3's default datetime adapter stored
            # (that adapter is deprecated since Python 3.12), so rows stay uniform.
            (user_id, description, created_at.isoformat(" "), image_path),
        )
        self.db.conn.commit()

    async def generate_diagram(self, text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
        """Extract components from *text* and render them with *strategy*/*style*.

        Returns:
            ``(image, status)``. On any failure ``image`` is ``None`` and
            ``status`` starts with ``"Error:"``. Errors are logged and reported
            rather than raised, so the UI can show them.
        """
        if not text or not text.strip():
            return None, "Error: please enter a diagram description."
        if style not in self.style_configs:
            return None, f"Error: unknown style {style!r}; choose one of: {', '.join(self.style_configs)}."
        if strategy not in self.strategies:
            return None, f"Error: unknown strategy {strategy!r}; choose one of: {', '.join(self.strategies)}."
        try:
            components = await self.extract_components(text)
            if not components:
                return None, "Error: no diagram components could be extracted from the description."
            diagram = self.strategies[strategy].create_diagram(components, self.style_configs[style])
            return diagram, "Diagram generated successfully!"
        except Exception as e:
            # UI boundary: keep the traceback in the log, return the message.
            self.logger.exception(f"Error generating diagram: {str(e)}")
            return None, f"Error: {str(e)}"
# FastAPI Integration
class DiagramRequest(BaseModel):
    """Request body: the description plus the style and strategy names to use."""
    text: str
    style: str = "network"
    strategy: str = "network"


# FastAPI application and OAuth2 bearer-token scheme (token URL: "token").
# NOTE(review): no routes are registered on `app` in the code visible here.
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Module-level generator: constructing it loads the config and model weights.
generator = AdvancedDiagramGenerator()
# Gradio Interface | |
def create_gradio_interface():
    """Build the Gradio UI for the module-level ``generator``.

    Returns:
        A ``gr.Interface`` whose inputs are a description text box plus style
        and strategy dropdowns. Its outputs are the generated image and a
        status text box.
    """

    def generate(text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
        # asyncio.run creates a fresh event loop per call; it raises if the
        # calling thread already has a running loop.
        return asyncio.run(generator.generate_diagram(text, style, strategy))

    def default_choice(choices: List[str], preferred: str = "network") -> Optional[str]:
        # config.yaml may define styles without "network"; a default that is
        # not among the choices is not a valid selection.
        if preferred in choices:
            return preferred
        return choices[0] if choices else None

    style_choices = list(generator.style_configs)
    strategy_choices = list(generator.strategies)

    # NOTE(review): the description advertises caching, database storage and
    # multiple output formats, but `generate` uses none of them.
    iface = gr.Interface(
        fn=generate,
        inputs=[
            gr.Textbox(
                label="Enter your diagram description",
                placeholder="e.g., 'Create a flowchart for software development lifecycle'",
                lines=3
            ),
            gr.Dropdown(
                choices=style_choices,
                label="Diagram Style",
                value=default_choice(style_choices)
            ),
            gr.Dropdown(
                choices=strategy_choices,
                label="Visualization Strategy",
                value=default_choice(strategy_choices)
            )
        ],
        outputs=[
            gr.Image(label="Generated Diagram", type="pil"),
            gr.Textbox(label="Status")
        ],
        title="Advanced Enterprise Diagram Generator",
        description="""
    Enterprise-grade diagram generation tool with advanced features:
    - Multiple visualization strategies
    - Caching system
    - Database storage
    - Multiple output formats
    - Custom styling options
    """,
        theme=gr.themes.Glass()
    )
    return iface
if __name__ == "__main__":
    # Script entry point: build the UI and serve it; share=True asks Gradio
    # for a public share link.
    demo = create_gradio_interface()
    demo.launch(share=True)