Anupam251272's picture
Create app.py
29ae6c8 verified
raw
history blame
13.8 kB
import gradio as gr
import torch
import logging
import yaml
import os
import json
import jwt
import redis
import sqlite3
from datetime import datetime, timedelta
from pathlib import Path
from transformers import AutoTokenizer, T5ForConditionalGeneration
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io
import traceback
from typing import Tuple, Optional, Dict, List, Union
import colorama
from colorama import Fore, Style
from abc import ABC, abstractmethod
from dataclasses import dataclass
import plotly.graph_objects as go
import hashlib
import asyncio
import aiohttp
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, EmailStr
import uvicorn
# Advanced Configuration Models
@dataclass
class ModelConfig:
    """Settings for the text-to-components model (the ``model`` config section)."""

    name: str  # model id passed to AutoTokenizer / T5ForConditionalGeneration.from_pretrained
    max_length: int  # token limit used when truncating the input prompt
    num_beams: int  # beam count forwarded to model.generate
    temperature: float  # forwarded to model.generate
    top_k: int  # forwarded to model.generate
    top_p: float  # forwarded to model.generate
@dataclass
class StyleConfig:
    """Visual settings for one named style (an entry under the ``styles`` config section)."""

    node_color: str  # colour used for diagram nodes
    edge_color: str  # colour used for diagram edges
    node_size: int  # matplotlib node size; the Plotly renderer divides it by 100 for marker size
    font_size: int  # node label font size used by the networkx/matplotlib renderer
    layout: str  # layout name: "spring", "circular"; other values fall back to Kamada-Kawai in NetworkDiagram
@dataclass
class OutputConfig:
    """Output settings loaded from the ``output`` config section.

    NOTE(review): neither diagram strategy in this file reads these values
    (NetworkDiagram hard-codes dpi=300); confirm whether they are meant to be wired in.
    """

    width: int  # presumably pixels — TODO confirm
    height: int  # presumably pixels — TODO confirm
    dpi: int  # dots per inch for rendered output
    format: str  # image format name, default "png"
    quality: int  # presumably a lossy-encoder quality level (default 95) — TODO confirm
# Abstract Base Classes for Extensibility
class DiagramStrategy(ABC):
    """Interface for renderers that turn an ordered component list into an image.

    Subclasses are registered by name in ``AdvancedDiagramGenerator.strategies``.
    """

    @abstractmethod
    def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
        """Render *components* (in order) using *style* and return a PIL image."""
class NetworkDiagram(DiagramStrategy):
    """Render the component chain with networkx + matplotlib."""

    def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
        """Draw ``components[0] -> components[1] -> ...`` and return it as a PNG image.

        Args:
            components: ordered node names; consecutive names are joined by an edge.
            style: colours/sizes; ``style.layout`` selects "spring" or "circular",
                and any other value falls back to the Kamada-Kawai layout.

        Returns:
            The rendered diagram as a PIL image (PNG, 300 dpi).
        """
        G = nx.DiGraph()
        # Add nodes explicitly: a single component has no edges and would
        # otherwise produce an empty picture.
        G.add_nodes_from(components)
        G.add_edges_from(zip(components, components[1:]))

        if style.layout == "spring":
            pos = nx.spring_layout(G)
        elif style.layout == "circular":
            pos = nx.circular_layout(G)
        else:
            pos = nx.kamada_kawai_layout(G)

        fig = plt.figure(figsize=(12, 8))
        buf = io.BytesIO()
        try:
            nx.draw_networkx_nodes(G, pos,
                                   node_color=style.node_color,
                                   node_size=style.node_size)
            nx.draw_networkx_edges(G, pos,
                                   edge_color=style.edge_color,
                                   arrows=True)
            nx.draw_networkx_labels(G, pos,
                                    font_size=style.font_size)
            fig.savefig(buf, format='png', dpi=300)
        finally:
            # Release the figure even when drawing fails, so repeated errors
            # do not accumulate open matplotlib figures.
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)
class PlotlyDiagram(DiagramStrategy):
    """Render the component chain with Plotly and export it as a PNG image."""

    def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
        """Draw ``components[0] -> components[1] -> ...`` as a Plotly scatter figure.

        Node positions follow ``style.layout``: "spring", "circular", or the
        Kamada-Kawai layout for any other value. Edges are drawn as plain line
        segments (no arrowheads).

        NOTE: ``fig.to_image`` needs the ``kaleido`` package installed.
        """
        G = nx.DiGraph()
        # Register every component as a node so a single-component
        # description (which has no edges) still shows up.
        G.add_nodes_from(components)
        G.add_edges_from(zip(components, components[1:]))

        if style.layout == "spring":
            pos = nx.spring_layout(G)
        elif style.layout == "circular":
            pos = nx.circular_layout(G)
        else:
            pos = nx.kamada_kawai_layout(G)

        edge_x = []
        edge_y = []
        for src, dst in G.edges():
            x0, y0 = pos[src]
            x1, y1 = pos[dst]
            # A None entry breaks the polyline so each edge is its own segment.
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        nodes = list(G.nodes())
        node_x = [pos[node][0] for node in nodes]
        node_y = [pos[node][1] for node in nodes]

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=edge_x, y=edge_y,
                                 line=dict(width=0.5, color=style.edge_color),
                                 hoverinfo='none',
                                 mode='lines'))
        fig.add_trace(go.Scatter(x=node_x, y=node_y,
                                 mode='markers+text',
                                 # node_size is tuned for matplotlib; scale down for Plotly markers.
                                 marker=dict(size=style.node_size/100,
                                             color=style.node_color),
                                 text=nodes,
                                 textposition="bottom center"))
        fig.update_layout(showlegend=False,
                          hovermode='closest',
                          margin=dict(b=0, l=0, r=0, t=0))
        img_bytes = fig.to_image(format="png")  # Requires kaleido
        return Image.open(io.BytesIO(img_bytes))
# Database Manager
class DatabaseManager:
    """Own a SQLite connection and make sure the app's tables exist."""

    def __init__(self, db_path: str = "diagrams.db"):
        """Connect to *db_path* and create the ``users``/``diagrams`` tables if missing."""
        self.conn = sqlite3.connect(db_path)
        self.create_tables()

    def create_tables(self):
        """Create the schema; ``IF NOT EXISTS`` makes repeated calls harmless."""
        # Table name -> column/constraint definitions, in creation order
        # (users must exist before diagrams references it).
        schema = {
            "users": [
                "id INTEGER PRIMARY KEY",
                "username TEXT UNIQUE",
                "email TEXT UNIQUE",
                "password_hash TEXT",
                "created_at TIMESTAMP",
            ],
            "diagrams": [
                "id INTEGER PRIMARY KEY",
                "user_id INTEGER",
                "title TEXT",
                "description TEXT",
                "created_at TIMESTAMP",
                "image_path TEXT",
                "FOREIGN KEY (user_id) REFERENCES users (id)",
            ],
        }
        cur = self.conn.cursor()
        for table_name, definitions in schema.items():
            ddl = (
                "\nCREATE TABLE IF NOT EXISTS " + table_name + " (\n"
                + ",\n".join(definitions)
                + "\n)\n"
            )
            cur.execute(ddl)
        self.conn.commit()
# Cache Manager using Redis
class CacheManager:
    """Thin wrapper around a Redis client for storing rendered diagram bytes."""

    def __init__(self, redis_url: str = "redis://localhost"):
        """Build a Redis client for *redis_url*."""
        self.redis_client = redis.from_url(redis_url)

    def get_cached_diagram(self, key: str) -> Optional[bytes]:
        """Return the value stored under *key* (None when Redis has no such key)."""
        cached = self.redis_client.get(key)
        return cached

    def cache_diagram(self, key: str, diagram: bytes, expire: int = 3600):
        """Store *diagram* under *key*, expiring after *expire* seconds."""
        client = self.redis_client
        client.set(key, diagram, ex=expire)
# Advanced Diagram Generator
class AdvancedDiagramGenerator:
    """Turn a free-text description into a rendered diagram.

    Pipeline: a T5 model rewrites the description into a string of component
    names separated by ``->`` or ``,``; :meth:`extract_components` splits it
    into an ordered list; a registered diagram strategy renders consecutive
    components as a directed chain.
    """

    # Built-in defaults; any key missing from config.yaml falls back to these.
    _DEFAULT_MODEL = {
        'name': 't5-small',
        'max_length': 512,
        'num_beams': 4,
        'temperature': 1.0,
        'top_k': 50,
        'top_p': 0.9,
    }
    _DEFAULT_STYLE = {
        'node_color': '#1f77b4',
        'edge_color': '#7f7f7f',
        'node_size': 3000,
        'font_size': 12,
        'layout': 'spring',
    }
    _DEFAULT_OUTPUT = {
        'width': 1200,
        'height': 800,
        'dpi': 300,
        'format': 'png',
        'quality': 95,
    }

    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and model, and register the diagram strategies.

        Args:
            config_path: YAML configuration file; if it does not exist the
                built-in defaults are used.

        Raises:
            RuntimeError: if the configuration cannot be loaded.
        """
        # Logging first: load_config reports through self.logger.
        self.setup_logging()
        self.load_config(config_path)
        # Loads the tokenizer and model named in the config.
        self.setup_components()
        self.strategies = {
            "network": NetworkDiagram(),
            "plotly": PlotlyDiagram()
        }
        # Opened on first use by save_to_database, so constructing the
        # generator does not create diagrams.db as a side effect.
        self.db = None

    @staticmethod
    def _with_defaults(section, defaults: Dict) -> Dict:
        """Return a new dict of *defaults* overridden by the keys in *section*.

        A missing/null section yields a copy of the defaults.

        Raises:
            TypeError: if *section* is neither None nor a mapping.
        """
        if section is None:
            return dict(defaults)
        if not isinstance(section, dict):
            raise TypeError(f"expected a mapping, got {type(section).__name__}")
        return {**defaults, **section}

    def load_config(self, config_path: str):
        """Load *config_path* into typed config objects.

        Sets ``self.model_config``, ``self.style_configs`` (style name ->
        StyleConfig), ``self.output_config`` and ``self.config`` (the merged
        raw dict). Missing file, empty file, missing/null sections and missing
        keys all fall back to the built-in defaults; if no styles are defined,
        a single default "network" style is provided.

        Raises:
            RuntimeError: if the file cannot be read or parsed, or a section
                has a wrong type or keys the config dataclasses do not accept.
        """
        try:
            if not os.path.exists(config_path):
                self.logger.warning(f"Config file not found at {config_path}. Using default configuration.")
                config_data = {}
            else:
                with open(config_path, encoding="utf-8") as f:
                    # An empty YAML document parses to None: treat as "no overrides".
                    config_data = yaml.safe_load(f) or {}
            if not isinstance(config_data, dict):
                raise TypeError(f"top-level config must be a mapping, got {type(config_data).__name__}")

            config_data['model'] = self._with_defaults(config_data.get('model'), self._DEFAULT_MODEL)
            self.model_config = ModelConfig(**config_data['model'])

            # Absent, null or empty styles -> one default "network" style, so
            # the UI always has at least one selectable style.
            styles = config_data.get('styles') or {'network': None}
            config_data['styles'] = {
                style_name: self._with_defaults(style_data, self._DEFAULT_STYLE)
                for style_name, style_data in styles.items()
            }
            self.style_configs = {
                style_name: StyleConfig(**style_data)
                for style_name, style_data in config_data['styles'].items()
            }

            config_data['output'] = self._with_defaults(config_data.get('output'), self._DEFAULT_OUTPUT)
            self.output_config = OutputConfig(**config_data['output'])

            self.config = config_data
            self.logger.info("Configuration loaded successfully")
        except Exception as e:
            self.logger.error(f"Error loading configuration: {str(e)}")
            raise RuntimeError(f"Failed to load configuration: {str(e)}") from e

    def setup_components(self):
        """Load the tokenizer and model from the config and move the model to GPU if available."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_config.name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_config.name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def setup_logging(self):
        """Configure root logging at INFO and create this module's logger."""
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    async def extract_components(self, text: str) -> List[str]:
        """Ask the model to rewrite *text* as a component list and split it.

        The decoded output is split on both ``->`` and ``,``; blank pieces are
        dropped, so the result may be empty.
        """
        inputs = self.tokenizer(
            f"convert to diagram: {text}",
            return_tensors="pt",
            max_length=self.model_config.max_length,
            truncation=True
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.get("attention_mask"),
                max_length=150,
                num_beams=self.model_config.num_beams,
                # NOTE(review): generate() is not called with do_sample=True, so these
                # sampling parameters presumably have no effect under beam search — confirm.
                temperature=self.model_config.temperature,
                top_k=self.model_config.top_k,
                top_p=self.model_config.top_p
            )
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        components = [comp.strip() for comp in decoded.replace('->', ',').split(',')]
        return [comp for comp in components if comp]

    def save_to_database(self, user_id: int, description: str, diagram: Image.Image):
        """Save *diagram* under ``diagrams/user_<user_id>/`` and record it in the database.

        The database connection is opened on first use.
        """
        if getattr(self, "db", None) is None:
            self.db = DatabaseManager()
        # One timestamp for both the file name and the row; microseconds keep
        # two saves within the same second from overwriting each other.
        now = datetime.now()
        image_path = f"diagrams/user_{user_id}/{now.strftime('%Y%m%d_%H%M%S_%f')}.png"
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        diagram.save(image_path)
        cursor = self.db.conn.cursor()
        # Store the same ISO text the (deprecated since 3.12) default sqlite3
        # datetime adapter would have produced.
        cursor.execute(
            "INSERT INTO diagrams (user_id, description, created_at, image_path) VALUES (?, ?, ?, ?)",
            (user_id, description, now.isoformat(" "), image_path),
        )
        self.db.conn.commit()

    async def generate_diagram(self, text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
        """Run the text -> components -> image pipeline.

        Args:
            text: free-text description of the diagram.
            style: key into ``self.style_configs``.
            strategy: key into ``self.strategies``.

        Returns:
            ``(image, "Diagram generated successfully!")`` on success, or
            ``(None, "Error: ...")`` with a user-facing message on failure.
            Errors are reported through the return value, not raised.
        """
        if strategy not in self.strategies:
            return None, f"Error: unknown strategy {strategy!r}. Available: {', '.join(self.strategies)}"
        if style not in self.style_configs:
            return None, f"Error: unknown style {style!r}. Available: {', '.join(self.style_configs)}"
        if not text or not text.strip():
            return None, "Error: please enter a diagram description."
        try:
            components = await self.extract_components(text)
            if not components:
                return None, "Error: the model did not return any diagram components."
            diagram = self.strategies[strategy].create_diagram(components, self.style_configs[style])
            return diagram, "Diagram generated successfully!"
        except Exception as e:
            # logger.exception keeps the traceback for debugging.
            self.logger.exception(f"Error generating diagram: {str(e)}")
            return None, f"Error: {str(e)}"
# FastAPI Integration
# NOTE(review): no routes are registered on `app` in the code shown here, and
# neither `oauth2_scheme` nor `DiagramRequest` is referenced by any endpoint.
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


# Request body schema for a diagram-generation call.
class DiagramRequest(BaseModel):
    text: str  # free-text diagram description
    style: str = "network"  # intended as a key into generator.style_configs
    strategy: str = "network"  # intended as a key into generator.strategies


# Initialize the generator
# Runs at import time: loads the config and the T5 model before the UI is built.
generator = AdvancedDiagramGenerator()
# Gradio Interface
def create_gradio_interface():
    """Build (but do not launch) the Gradio UI around the module-level ``generator``.

    Dropdown choices come from ``generator.style_configs`` and
    ``generator.strategies``; "network" is preselected when it exists,
    otherwise the first available choice.

    Returns:
        gr.Interface: the configured interface.
    """
    def generate(text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
        # Deliberately synchronous: Gradio runs sync handlers in a worker
        # thread, so asyncio.run gets its own loop there and the model call
        # does not block Gradio's event loop.
        return asyncio.run(generator.generate_diagram(text, style, strategy))

    style_choices = list(generator.style_configs.keys())
    strategy_choices = list(generator.strategies.keys())
    # A config.yaml may define styles under other names, so only preselect
    # "network" when it is actually one of the choices.
    default_style = "network" if "network" in style_choices else next(iter(style_choices), None)
    default_strategy = "network" if "network" in strategy_choices else next(iter(strategy_choices), None)

    iface = gr.Interface(
        fn=generate,
        inputs=[
            gr.Textbox(
                label="Enter your diagram description",
                placeholder="e.g., 'Create a flowchart for software development lifecycle'",
                lines=3
            ),
            gr.Dropdown(
                choices=style_choices,
                label="Diagram Style",
                value=default_style
            ),
            gr.Dropdown(
                choices=strategy_choices,
                label="Visualization Strategy",
                value=default_strategy
            )
        ],
        outputs=[
            gr.Image(label="Generated Diagram", type="pil"),
            gr.Textbox(label="Status")
        ],
        title="Advanced Enterprise Diagram Generator",
        description="""
Enterprise-grade diagram generation tool with advanced features:
- Multiple visualization strategies
- Caching system
- Database storage
- Multiple output formats
- Custom styling options
""",
        theme=gr.themes.Glass()
    )
    return iface
if __name__ == "__main__":
    # Script entry point: build the UI and serve it (share=True requests a
    # public Gradio share link in addition to the local server).
    demo = create_gradio_interface()
    demo.launch(share=True)