Anupam251272 commited on
Commit
29ae6c8
·
verified ·
1 Parent(s): 3e91042

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +398 -0
app.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import logging
4
+ import yaml
5
+ import os
6
+ import json
7
+ import jwt
8
+ import redis
9
+ import sqlite3
10
+ from datetime import datetime, timedelta
11
+ from pathlib import Path
12
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
13
+ import networkx as nx
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ from PIL import Image
17
+ import io
18
+ import traceback
19
+ from typing import Tuple, Optional, Dict, List, Union
20
+ import colorama
21
+ from colorama import Fore, Style
22
+ from abc import ABC, abstractmethod
23
+ from dataclasses import dataclass
24
+ import plotly.graph_objects as go
25
+ import hashlib
26
+ import asyncio
27
+ import aiohttp
28
+ from fastapi import FastAPI, HTTPException, Depends, status
29
+ from fastapi.security import OAuth2PasswordBearer
30
+ from pydantic import BaseModel, EmailStr
31
+ import uvicorn
32
+
33
+ # Advanced Configuration Models
34
+ @dataclass
35
+ class ModelConfig:
36
+ name: str
37
+ max_length: int
38
+ num_beams: int
39
+ temperature: float
40
+ top_k: int
41
+ top_p: float
42
+
43
+ @dataclass
44
+ class StyleConfig:
45
+ node_color: str
46
+ edge_color: str
47
+ node_size: int
48
+ font_size: int
49
+ layout: str
50
+
51
+ @dataclass
52
+ class OutputConfig:
53
+ width: int
54
+ height: int
55
+ dpi: int
56
+ format: str
57
+ quality: int
58
+
59
+ # Abstract Base Classes for Extensibility
60
+ class DiagramStrategy(ABC):
61
+ @abstractmethod
62
+ def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
63
+ pass
64
+
65
+ class NetworkDiagram(DiagramStrategy):
66
+ def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
67
+ G = nx.DiGraph()
68
+ for i in range(len(components)-1):
69
+ G.add_edge(components[i], components[i+1])
70
+
71
+ plt.figure(figsize=(12, 8))
72
+
73
+ if style.layout == "spring":
74
+ pos = nx.spring_layout(G)
75
+ elif style.layout == "circular":
76
+ pos = nx.circular_layout(G)
77
+ else:
78
+ pos = nx.kamada_kawai_layout(G)
79
+
80
+ nx.draw_networkx_nodes(G, pos,
81
+ node_color=style.node_color,
82
+ node_size=style.node_size)
83
+ nx.draw_networkx_edges(G, pos,
84
+ edge_color=style.edge_color,
85
+ arrows=True)
86
+ nx.draw_networkx_labels(G, pos,
87
+ font_size=style.font_size)
88
+
89
+ buf = io.BytesIO()
90
+ plt.savefig(buf, format='png', dpi=300)
91
+ plt.close()
92
+ buf.seek(0)
93
+ return Image.open(buf)
94
+
95
+ class PlotlyDiagram(DiagramStrategy):
96
+ def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
97
+ G = nx.DiGraph()
98
+ for i in range(len(components)-1):
99
+ G.add_edge(components[i], components[i+1])
100
+
101
+ pos = nx.spring_layout(G)
102
+
103
+ edge_x = []
104
+ edge_y = []
105
+ for edge in G.edges():
106
+ x0, y0 = pos[edge[0]]
107
+ x1, y1 = pos[edge[1]]
108
+ edge_x.extend([x0, x1, None])
109
+ edge_y.extend([y0, y1, None])
110
+
111
+ node_x = [pos[node][0] for node in G.nodes()]
112
+ node_y = [pos[node][1] for node in G.nodes()]
113
+
114
+ fig = go.Figure()
115
+ fig.add_trace(go.Scatter(x=edge_x, y=edge_y,
116
+ line=dict(width=0.5, color=style.edge_color),
117
+ hoverinfo='none',
118
+ mode='lines'))
119
+
120
+ fig.add_trace(go.Scatter(x=node_x, y=node_y,
121
+ mode='markers+text',
122
+ marker=dict(size=style.node_size/100,
123
+ color=style.node_color),
124
+ text=list(G.nodes()),
125
+ textposition="bottom center"))
126
+
127
+ fig.update_layout(showlegend=False,
128
+ hovermode='closest',
129
+ margin=dict(b=0,l=0,r=0,t=0))
130
+
131
+ img_bytes = fig.to_image(format="png") # Requires kaleido
132
+ return Image.open(io.BytesIO(img_bytes))
133
+
134
+ # Database Manager
135
+ class DatabaseManager:
136
+ def __init__(self, db_path: str = "diagrams.db"):
137
+ self.conn = sqlite3.connect(db_path)
138
+ self.create_tables()
139
+
140
+ def create_tables(self):
141
+ cursor = self.conn.cursor()
142
+ cursor.execute('''
143
+ CREATE TABLE IF NOT EXISTS users (
144
+ id INTEGER PRIMARY KEY,
145
+ username TEXT UNIQUE,
146
+ email TEXT UNIQUE,
147
+ password_hash TEXT,
148
+ created_at TIMESTAMP
149
+ )
150
+ ''')
151
+
152
+ cursor.execute('''
153
+ CREATE TABLE IF NOT EXISTS diagrams (
154
+ id INTEGER PRIMARY KEY,
155
+ user_id INTEGER,
156
+ title TEXT,
157
+ description TEXT,
158
+ created_at TIMESTAMP,
159
+ image_path TEXT,
160
+ FOREIGN KEY (user_id) REFERENCES users (id)
161
+ )
162
+ ''')
163
+ self.conn.commit()
164
+
165
+ # Cache Manager using Redis
166
+ class CacheManager:
167
+ def __init__(self, redis_url: str = "redis://localhost"):
168
+ self.redis_client = redis.from_url(redis_url)
169
+
170
+ def get_cached_diagram(self, key: str) -> Optional[bytes]:
171
+ return self.redis_client.get(key)
172
+
173
+ def cache_diagram(self, key: str, diagram: bytes, expire: int = 3600):
174
+ self.redis_client.set(key, diagram, ex=expire)
175
+
176
+ # Advanced Diagram Generator
177
+ class AdvancedDiagramGenerator:
178
+ def __init__(self, config_path: str = "config.yaml"):
179
+ # Initialize logging first
180
+ self.setup_logging()
181
+
182
+ # Load configuration
183
+ self.load_config(config_path)
184
+
185
+ # Setup components (tokenizer, model, etc.)
186
+ self.setup_components()
187
+
188
+ # Initialize diagram strategies
189
+ self.strategies = {
190
+ "network": NetworkDiagram(),
191
+ "plotly": PlotlyDiagram()
192
+ }
193
+
194
+ def load_config(self, config_path: str):
195
+ try:
196
+ if not os.path.exists(config_path):
197
+ self.logger.warning(f"Config file not found at {config_path}. Using default configuration.")
198
+ # Default configuration
199
+ config_data = {
200
+ 'model': {
201
+ 'name': 't5-small',
202
+ 'max_length': 512,
203
+ 'num_beams': 4,
204
+ 'temperature': 1.0,
205
+ 'top_k': 50,
206
+ 'top_p': 0.9
207
+ },
208
+ 'styles': {
209
+ 'network': {
210
+ 'node_color': '#1f77b4',
211
+ 'edge_color': '#7f7f7f',
212
+ 'node_size': 3000,
213
+ 'font_size': 12,
214
+ 'layout': 'spring'
215
+ }
216
+ },
217
+ 'output': {
218
+ 'width': 1200,
219
+ 'height': 800,
220
+ 'dpi': 300,
221
+ 'format': 'png',
222
+ 'quality': 95
223
+ }
224
+ }
225
+ else:
226
+ with open(config_path) as f:
227
+ config_data = yaml.safe_load(f)
228
+
229
+ # Ensure all required sections exist with defaults
230
+ if 'model' not in config_data:
231
+ config_data['model'] = {}
232
+
233
+ # Set default values for model configuration
234
+ model_config_data = config_data['model']
235
+ model_config_data.setdefault('name', 't5-small')
236
+ model_config_data.setdefault('max_length', 512)
237
+ model_config_data.setdefault('num_beams', 4)
238
+ model_config_data.setdefault('temperature', 1.0)
239
+ model_config_data.setdefault('top_k', 50)
240
+ model_config_data.setdefault('top_p', 0.9)
241
+
242
+ # Create ModelConfig instance
243
+ self.model_config = ModelConfig(**model_config_data)
244
+
245
+ # Handle styles configuration
246
+ if 'styles' not in config_data:
247
+ config_data['styles'] = {
248
+ 'network': {
249
+ 'node_color': '#1f77b4',
250
+ 'edge_color': '#7f7f7f',
251
+ 'node_size': 3000,
252
+ 'font_size': 12,
253
+ 'layout': 'spring'
254
+ }
255
+ }
256
+
257
+ # Create StyleConfig instances
258
+ self.style_configs = {}
259
+ for style_name, style_data in config_data['styles'].items():
260
+ style_data.setdefault('node_color', '#1f77b4')
261
+ style_data.setdefault('edge_color', '#7f7f7f')
262
+ style_data.setdefault('node_size', 3000)
263
+ style_data.setdefault('font_size', 12)
264
+ style_data.setdefault('layout', 'spring')
265
+ self.style_configs[style_name] = StyleConfig(**style_data)
266
+
267
+ # Handle output configuration
268
+ if 'output' not in config_data:
269
+ config_data['output'] = {}
270
+
271
+ output_config_data = config_data['output']
272
+ output_config_data.setdefault('width', 1200)
273
+ output_config_data.setdefault('height', 800)
274
+ output_config_data.setdefault('dpi', 300)
275
+ output_config_data.setdefault('format', 'png')
276
+ output_config_data.setdefault('quality', 95)
277
+
278
+ # Create OutputConfig instance
279
+ self.output_config = OutputConfig(**output_config_data)
280
+
281
+ self.config = config_data
282
+ self.logger.info("Configuration loaded successfully")
283
+
284
+ except Exception as e:
285
+ self.logger.error(f"Error loading configuration: {str(e)}")
286
+ raise RuntimeError(f"Failed to load configuration: {str(e)}")
287
+
288
+ def setup_components(self):
289
+ # Initialize tokenizer and model
290
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_config.name)
291
+ self.model = T5ForConditionalGeneration.from_pretrained(self.model_config.name)
292
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
293
+ self.model.to(self.device)
294
+
295
+ def setup_logging(self):
296
+ logging.basicConfig(level=logging.INFO)
297
+ self.logger = logging.getLogger(__name__)
298
+
299
+ async def extract_components(self, text: str) -> List[str]:
300
+ inputs = self.tokenizer(
301
+ f"convert to diagram: {text}",
302
+ return_tensors="pt",
303
+ max_length=self.model_config.max_length,
304
+ truncation=True
305
+ ).to(self.device)
306
+
307
+ with torch.no_grad():
308
+ outputs = self.model.generate(
309
+ inputs.input_ids,
310
+ max_length=150,
311
+ num_beams=self.model_config.num_beams,
312
+ temperature=self.model_config.temperature,
313
+ top_k=self.model_config.top_k,
314
+ top_p=self.model_config.top_p
315
+ )
316
+
317
+ decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
318
+ components = [comp.strip() for comp in decoded.replace('->', ',').split(',')]
319
+ return [comp for comp in components if comp]
320
+
321
+ def save_to_database(self, user_id: int, description: str, diagram: Image.Image):
322
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
323
+ image_path = f"diagrams/user_{user_id}/{timestamp}.png"
324
+ os.makedirs(os.path.dirname(image_path), exist_ok=True)
325
+ diagram.save(image_path)
326
+
327
+ cursor = self.db.conn.cursor()
328
+ cursor.execute('''
329
+ INSERT INTO diagrams (user_id, description, created_at, image_path)
330
+ VALUES (?, ?, ?, ?)
331
+ ''', (user_id, description, datetime.now(), image_path))
332
+ self.db.conn.commit()
333
+
334
+ async def generate_diagram(self, text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
335
+ try:
336
+ components = await self.extract_components(text)
337
+ diagram = self.strategies[strategy].create_diagram(components, self.style_configs[style])
338
+ return diagram, "Diagram generated successfully!"
339
+ except Exception as e:
340
+ self.logger.error(f"Error generating diagram: {str(e)}")
341
+ return None, f"Error: {str(e)}"
342
+
343
+ # FastAPI Integration
344
+ app = FastAPI()
345
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
346
+
347
+ class DiagramRequest(BaseModel):
348
+ text: str
349
+ style: str = "network"
350
+ strategy: str = "network"
351
+
352
+ # Initialize the generator
353
+ generator = AdvancedDiagramGenerator()
354
+
355
+ # Gradio Interface
356
+ def create_gradio_interface():
357
+ def generate(text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
358
+ return asyncio.run(generator.generate_diagram(text, style, strategy))
359
+
360
+ iface = gr.Interface(
361
+ fn=generate,
362
+ inputs=[
363
+ gr.Textbox(
364
+ label="Enter your diagram description",
365
+ placeholder="e.g., 'Create a flowchart for software development lifecycle'",
366
+ lines=3
367
+ ),
368
+ gr.Dropdown(
369
+ choices=list(generator.style_configs.keys()),
370
+ label="Diagram Style",
371
+ value="network"
372
+ ),
373
+ gr.Dropdown(
374
+ choices=list(generator.strategies.keys()),
375
+ label="Visualization Strategy",
376
+ value="network"
377
+ )
378
+ ],
379
+ outputs=[
380
+ gr.Image(label="Generated Diagram", type="pil"),
381
+ gr.Textbox(label="Status")
382
+ ],
383
+ title="Advanced Enterprise Diagram Generator",
384
+ description="""
385
+ Enterprise-grade diagram generation tool with advanced features:
386
+ - Multiple visualization strategies
387
+ - Caching system
388
+ - Database storage
389
+ - Multiple output formats
390
+ - Custom styling options
391
+ """,
392
+ theme=gr.themes.Glass()
393
+ )
394
+ return iface
395
+
396
+ if __name__ == "__main__":
397
+ iface = create_gradio_interface()
398
+ iface.launch(share=True)