|
import gradio as gr |
|
import torch |
|
import logging |
|
import yaml |
|
import os |
|
import json |
|
import jwt |
|
import redis |
|
import sqlite3 |
|
from datetime import datetime, timedelta |
|
from pathlib import Path |
|
from transformers import AutoTokenizer, T5ForConditionalGeneration |
|
import networkx as nx |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from PIL import Image |
|
import io |
|
import traceback |
|
from typing import Tuple, Optional, Dict, List, Union |
|
import colorama |
|
from colorama import Fore, Style |
|
from abc import ABC, abstractmethod |
|
from dataclasses import dataclass |
|
import plotly.graph_objects as go |
|
import hashlib |
|
import asyncio |
|
import aiohttp |
|
from fastapi import FastAPI, HTTPException, Depends, status |
|
from fastapi.security import OAuth2PasswordBearer |
|
from pydantic import BaseModel, EmailStr |
|
import uvicorn |
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class ModelConfig:
    """Text-generation settings read from the ``model`` section of the config."""

    name: str  # identifier handed to AutoTokenizer / T5ForConditionalGeneration.from_pretrained
    max_length: int  # truncation limit applied when tokenizing the input prompt
    num_beams: int  # beam count forwarded to model.generate
    temperature: float  # forwarded to model.generate
    top_k: int  # forwarded to model.generate
    top_p: float  # forwarded to model.generate
|
|
|
@dataclass(init=True, repr=True, eq=True)
class StyleConfig:
    """Visual settings for one named diagram style (an entry under ``styles``)."""

    node_color: str  # colour for node markers
    edge_color: str  # colour for edges
    node_size: int  # networkx node size; PlotlyDiagram divides it by 100 for marker size
    font_size: int  # label font size used by NetworkDiagram
    layout: str  # NetworkDiagram: "spring", "circular", anything else -> Kamada-Kawai
|
|
|
@dataclass(init=True, repr=True, eq=True)
class OutputConfig:
    """Output settings read from the ``output`` section of the config.

    NOTE(review): the renderers in this file hard-code their own figure size and
    dpi rather than reading these values; presumably width/height are pixels.
    """

    width: int
    height: int
    dpi: int
    format: str  # image format name, e.g. "png" in the built-in defaults
    quality: int
|
|
|
|
|
class DiagramStrategy(ABC):
    """Abstract renderer: turns an ordered list of component names into an image.

    Concrete subclasses implement ``create_diagram``; the ABC machinery prevents
    instantiating this base class directly.
    """

    @abstractmethod
    def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
        """Render *components* using *style* and return the resulting PIL image."""
        ...
|
|
|
class NetworkDiagram(DiagramStrategy):
    """Render the components as a directed chain using networkx and matplotlib."""

    def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
        """Draw ``components[0] -> components[1] -> ...`` and return it as an image.

        Args:
            components: Node labels in chain order. A repeated label reuses the
                same node, so repeats can create cycles.
            style: Colours, sizes and layout name. ``"spring"`` and ``"circular"``
                select those networkx layouts; any other value uses Kamada-Kawai.

        Returns:
            A PIL image decoded from a 12x8 inch, 300 dpi PNG render.
        """
        graph = nx.DiGraph()
        # Register every component as a node before adding edges: a single
        # component yields no edges and would otherwise not be drawn at all.
        graph.add_nodes_from(components)
        graph.add_edges_from(zip(components, components[1:]))

        if style.layout == "spring":
            pos = nx.spring_layout(graph)
        elif style.layout == "circular":
            pos = nx.circular_layout(graph)
        else:
            pos = nx.kamada_kawai_layout(graph)

        fig = plt.figure(figsize=(12, 8))
        try:
            ax = fig.add_subplot(111)
            nx.draw_networkx_nodes(graph, pos, ax=ax,
                                   node_color=style.node_color,
                                   node_size=style.node_size)
            nx.draw_networkx_edges(graph, pos, ax=ax,
                                   edge_color=style.edge_color,
                                   arrows=True)
            nx.draw_networkx_labels(graph, pos, ax=ax,
                                    font_size=style.font_size)

            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=300)
        finally:
            # Release the figure even when drawing fails, so pyplot does not keep
            # accumulating open figures across requests.
            plt.close(fig)

        buf.seek(0)
        return Image.open(buf)
|
|
|
class PlotlyDiagram(DiagramStrategy):
    """Render the components as a directed chain with Plotly, exported to PNG."""

    def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
        """Draw ``components[0] -> components[1] -> ...`` with Plotly.

        Args:
            components: Node labels in chain order; a repeated label reuses the same node.
            style: Colours, node size (divided by 100 for the Plotly marker size) and
                layout name. The layout follows NetworkDiagram: ``"spring"``,
                ``"circular"``, anything else -> Kamada-Kawai.

        Returns:
            A PIL image decoded from ``fig.to_image(format="png")``. NOTE(review):
            Plotly static export needs an image-export backend (e.g. kaleido).
        """
        graph = nx.DiGraph()
        # Add nodes explicitly so a lone component (no edges) still appears.
        graph.add_nodes_from(components)
        graph.add_edges_from(zip(components, components[1:]))

        # Honour style.layout the same way NetworkDiagram does (previously this
        # renderer always used the spring layout regardless of the style).
        if style.layout == "spring":
            pos = nx.spring_layout(graph)
        elif style.layout == "circular":
            pos = nx.circular_layout(graph)
        else:
            pos = nx.kamada_kawai_layout(graph)

        # Plotly draws each edge as a segment; None breaks the line between segments.
        edge_x = []
        edge_y = []
        for src, dst in graph.edges():
            x0, y0 = pos[src]
            x1, y1 = pos[dst]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        nodes = list(graph.nodes())
        node_x = [pos[node][0] for node in nodes]
        node_y = [pos[node][1] for node in nodes]

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=edge_x, y=edge_y,
                                 line=dict(width=0.5, color=style.edge_color),
                                 hoverinfo='none',
                                 mode='lines'))
        fig.add_trace(go.Scatter(x=node_x, y=node_y,
                                 mode='markers+text',
                                 marker=dict(size=style.node_size / 100,
                                             color=style.node_color),
                                 text=nodes,
                                 textposition="bottom center"))
        fig.update_layout(showlegend=False,
                          hovermode='closest',
                          margin=dict(b=0, l=0, r=0, t=0))

        img_bytes = fig.to_image(format="png")
        return Image.open(io.BytesIO(img_bytes))
|
|
|
|
|
class DatabaseManager:
    """Hold a SQLite connection and ensure the ``users``/``diagrams`` tables exist."""

    def __init__(self, db_path: str = "diagrams.db"):
        """Open (or create) the SQLite database at *db_path* and create the schema."""
        connection = sqlite3.connect(db_path)
        self.conn = connection
        self.create_tables()

    def create_tables(self):
        """Create the ``users`` and ``diagrams`` tables if missing, then commit."""
        users_ddl = '''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY,
                username TEXT UNIQUE,
                email TEXT UNIQUE,
                password_hash TEXT,
                created_at TIMESTAMP
            )
        '''
        # diagrams.user_id declares a foreign key to users.id.
        diagrams_ddl = '''
            CREATE TABLE IF NOT EXISTS diagrams (
                id INTEGER PRIMARY KEY,
                user_id INTEGER,
                title TEXT,
                description TEXT,
                created_at TIMESTAMP,
                image_path TEXT,
                FOREIGN KEY (user_id) REFERENCES users (id)
            )
        '''
        cur = self.conn.cursor()
        for ddl in (users_ddl, diagrams_ddl):
            cur.execute(ddl)
        self.conn.commit()
|
|
|
|
|
class CacheManager:
    """Minimal Redis-backed store for rendered diagram bytes."""

    def __init__(self, redis_url: str = "redis://localhost"):
        """Build a Redis client from *redis_url*."""
        self.redis_client = redis.from_url(redis_url)

    def get_cached_diagram(self, key: str) -> Optional[bytes]:
        """Return the value stored under *key* (redis-py yields ``None`` for a missing key)."""
        client = self.redis_client
        return client.get(name=key)

    def cache_diagram(self, key: str, diagram: bytes, expire: int = 3600):
        """Store *diagram* under *key* with a time-to-live of *expire* seconds."""
        client = self.redis_client
        client.set(name=key, value=diagram, ex=expire)
|
|
|
|
|
class AdvancedDiagramGenerator:
    """Generate diagrams from free-text descriptions.

    A T5 model rewrites the description into a ``->``/comma separated list of
    components, which one of the registered ``DiagramStrategy`` objects renders
    using a named ``StyleConfig``.

    Attributes set during construction:
        logger: ``logging.getLogger(__name__)``.
        model_config, style_configs, output_config, config: parsed settings.
        tokenizer, model, device: the loaded Hugging Face components.
        strategies: strategy name -> ``DiagramStrategy`` instance.
        db: ``DatabaseManager``, opened lazily by ``save_to_database``.
    """

    # Fallbacks for any key missing from the YAML config (and for a whole section
    # when the file or the section is absent). Only ever copied, never mutated.
    _DEFAULT_MODEL = {
        'name': 't5-small',
        'max_length': 512,
        'num_beams': 4,
        'temperature': 1.0,
        'top_k': 50,
        'top_p': 0.9,
    }
    _DEFAULT_STYLE = {
        'node_color': '#1f77b4',
        'edge_color': '#7f7f7f',
        'node_size': 3000,
        'font_size': 12,
        'layout': 'spring',
    }
    _DEFAULT_OUTPUT = {
        'width': 1200,
        'height': 800,
        'dpi': 300,
        'format': 'png',
        'quality': 95,
    }

    def __init__(self, config_path: str = "config.yaml"):
        # Logging first: load_config reports through self.logger.
        self.setup_logging()
        self.load_config(config_path)
        # Needs self.model_config, so it must run after load_config.
        self.setup_components()
        self.strategies = {
            "network": NetworkDiagram(),
            "plotly": PlotlyDiagram(),
        }
        # Opened on first save rather than here: this module builds a generator at
        # import time, and an eager connection would create diagrams.db as a side effect.
        self.db = None

    def load_config(self, config_path: str):
        """Load the YAML config at *config_path*, filling gaps with defaults.

        A missing file, an empty file, or a missing/empty section all fall back to
        the class-level defaults; keys present in the file override defaults.

        Sets ``self.model_config``, ``self.style_configs`` (name -> StyleConfig),
        ``self.output_config`` and ``self.config`` (the merged raw dict).

        Raises:
            RuntimeError: if the file cannot be read/parsed, or a section has a wrong
                shape or keys the matching dataclass does not accept.
        """
        try:
            if not os.path.exists(config_path):
                self.logger.warning(f"Config file not found at {config_path}. Using default configuration.")
                config_data = {}
            else:
                with open(config_path, encoding="utf-8") as f:
                    # An empty YAML document parses to None.
                    config_data = yaml.safe_load(f) or {}

            model_data = {**self._DEFAULT_MODEL, **(config_data.get('model') or {})}

            # With no styles configured, provide a single default "network" style.
            raw_styles = config_data.get('styles') or {'network': {}}
            styles_data = {
                style_name: {**self._DEFAULT_STYLE, **(style_data or {})}
                for style_name, style_data in raw_styles.items()
            }

            output_data = {**self._DEFAULT_OUTPUT, **(config_data.get('output') or {})}

            self.model_config = ModelConfig(**model_data)
            self.style_configs = {
                style_name: StyleConfig(**style_data)
                for style_name, style_data in styles_data.items()
            }
            self.output_config = OutputConfig(**output_data)

            # Keep the merged (defaults-filled) view as the raw config.
            config_data['model'] = model_data
            config_data['styles'] = styles_data
            config_data['output'] = output_data
            self.config = config_data
            self.logger.info("Configuration loaded successfully")

        except Exception as e:
            self.logger.error(f"Error loading configuration: {str(e)}")
            raise RuntimeError(f"Failed to load configuration: {str(e)}") from e

    def setup_components(self):
        """Load the tokenizer and T5 model named by ``model_config.name`` and move the
        model to CUDA when available, otherwise CPU."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_config.name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_config.name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def setup_logging(self):
        """Call ``logging.basicConfig(level=INFO)`` and bind this module's logger."""
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    async def extract_components(self, text: str) -> List[str]:
        """Ask the model to rewrite *text* as diagram components.

        The decoded output is split on ``->`` and ``,``, stripped, and blank pieces
        are dropped.

        NOTE(review): despite being ``async`` this tokenizes and generates
        synchronously (no ``await``), so it blocks the event loop while it runs.
        """
        inputs = self.tokenizer(
            f"convert to diagram: {text}",
            return_tensors="pt",
            max_length=self.model_config.max_length,
            truncation=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                # Pass the mask explicitly instead of letting generate infer it.
                attention_mask=inputs.get("attention_mask"),
                max_length=150,
                num_beams=self.model_config.num_beams,
                # NOTE(review): do_sample is not set, so with the usual default
                # generation config these sampling knobs are presumably ignored.
                temperature=self.model_config.temperature,
                top_k=self.model_config.top_k,
                top_p=self.model_config.top_p,
            )

        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        components = [comp.strip() for comp in decoded.replace('->', ',').split(',')]
        return [comp for comp in components if comp]

    def save_to_database(self, user_id: int, description: str, diagram: Image.Image):
        """Write *diagram* under ``diagrams/user_<user_id>/`` and record it in the
        ``diagrams`` table (user id, description, timestamp, image path).

        The ``DatabaseManager`` is created on first use with its default path.
        """
        if getattr(self, "db", None) is None:
            self.db = DatabaseManager()

        now = datetime.now()
        # Microseconds keep two saves within the same second from overwriting
        # each other's image file.
        image_path = f"diagrams/user_{user_id}/{now.strftime('%Y%m%d_%H%M%S_%f')}.png"
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        diagram.save(image_path)

        # The connection context manager commits on success and rolls back on error.
        with self.db.conn:
            self.db.conn.execute(
                '''
                INSERT INTO diagrams (user_id, description, created_at, image_path)
                VALUES (?, ?, ?, ?)
                ''',
                # isoformat(" ") is the text sqlite3's (deprecated) default datetime
                # adapter stored, so the column contents are unchanged.
                (user_id, description, now.isoformat(" "), image_path),
            )

    async def generate_diagram(self, text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
        """Extract components from *text* and render them.

        Returns:
            ``(image, "Diagram generated successfully!")`` on success, otherwise
            ``(None, "Error: ...")``; errors are logged rather than raised.
        """
        if strategy not in self.strategies:
            return None, f"Error: unknown strategy {strategy!r}"
        if style not in self.style_configs:
            return None, f"Error: unknown style {style!r}"

        try:
            components = await self.extract_components(text)
            if not components:
                return None, "Error: no diagram components could be extracted from the text"
            diagram = self.strategies[strategy].create_diagram(components, self.style_configs[style])
        except Exception as e:
            self.logger.exception("Error generating diagram: %s", e)
            return None, f"Error: {str(e)}"

        return diagram, "Diagram generated successfully!"
|
|
|
|
|
# FastAPI application object.
# NOTE(review): no routes are registered on it in the visible code, and nothing
# visible serves it (uvicorn is imported but never invoked here).
app = FastAPI()

# OAuth2 bearer-token dependency whose token endpoint is the relative URL "token".
# NOTE(review): not used by any visible route, and no "token" route is defined here.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
class DiagramRequest(BaseModel):
    """Request body: a description plus style and strategy names.

    NOTE(review): not referenced by any endpoint in the visible code.
    """

    text: str  # free-text diagram description
    style: str = "network"  # presumably a key of AdvancedDiagramGenerator.style_configs
    strategy: str = "network"  # presumably a key of AdvancedDiagramGenerator.strategies
|
|
|
|
|
# Module-level generator shared by the Gradio interface.
# NOTE(review): constructing it at import time reads config.yaml and loads the
# tokenizer/model via from_pretrained, so importing this module is expensive.
generator = AdvancedDiagramGenerator()
|
|
|
|
|
def create_gradio_interface():
    """Build (but do not launch) the Gradio UI around the module-level ``generator``.

    The style and strategy dropdowns list ``generator.style_configs`` and
    ``generator.strategies``. Each preselects ``"network"`` when that key exists,
    otherwise the first available choice.

    Returns:
        The configured ``gr.Interface``.
    """

    def generate(text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
        # Bridge Gradio's synchronous callback to the async generator API.
        # NOTE(review): asyncio.run fails if an event loop is already running in
        # this thread; this assumes Gradio calls sync handlers from a worker thread.
        return asyncio.run(generator.generate_diagram(text, style, strategy))

    def _initial_choice(choices: List[str]) -> Optional[str]:
        # A config without a "network" style would otherwise preselect a value
        # that is not among the dropdown's choices.
        if "network" in choices:
            return "network"
        return choices[0] if choices else None

    style_choices = list(generator.style_configs.keys())
    strategy_choices = list(generator.strategies.keys())

    # NOTE(review): generate_diagram does not use CacheManager or DatabaseManager,
    # so the caching/database bullets in the description are not backed by this path.
    iface = gr.Interface(
        fn=generate,
        inputs=[
            gr.Textbox(
                label="Enter your diagram description",
                placeholder="e.g., 'Create a flowchart for software development lifecycle'",
                lines=3
            ),
            gr.Dropdown(
                choices=style_choices,
                label="Diagram Style",
                value=_initial_choice(style_choices)
            ),
            gr.Dropdown(
                choices=strategy_choices,
                label="Visualization Strategy",
                value=_initial_choice(strategy_choices)
            )
        ],
        outputs=[
            gr.Image(label="Generated Diagram", type="pil"),
            gr.Textbox(label="Status")
        ],
        title="Advanced Enterprise Diagram Generator",
        description="""
Enterprise-grade diagram generation tool with advanced features:
- Multiple visualization strategies
- Caching system
- Database storage
- Multiple output formats
- Custom styling options
""",
        theme=gr.themes.Glass()
    )
    return iface
|
|
|
if __name__ == "__main__":
    # Build the UI and serve it; share=True additionally requests a public
    # Gradio share link, exposing the app beyond localhost.
    create_gradio_interface().launch(share=True)