Anupam251272 commited on
Commit
e5e0d10
·
verified ·
1 Parent(s): d467493

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -350
app.py CHANGED
@@ -1,398 +1,159 @@
1
  import gradio as gr
2
  import torch
3
- import logging
4
- import yaml
5
- import os
6
- import json
7
- import jwt
8
- import redis
9
- import sqlite3
10
- from datetime import datetime, timedelta
11
- from pathlib import Path
12
  from transformers import AutoTokenizer, T5ForConditionalGeneration
13
  import networkx as nx
14
  import matplotlib.pyplot as plt
15
  import numpy as np
16
  from PIL import Image
17
  import io
18
- import traceback
19
- from typing import Tuple, Optional, Dict, List, Union
20
- import colorama
21
- from colorama import Fore, Style
22
- from abc import ABC, abstractmethod
23
- from dataclasses import dataclass
24
- import plotly.graph_objects as go
25
- import hashlib
26
- import asyncio
27
- import aiohttp
28
- from fastapi import FastAPI, HTTPException, Depends, status
29
- from fastapi.security import OAuth2PasswordBearer
30
- from pydantic import BaseModel, EmailStr
31
- import uvicorn
32
 
33
- # Advanced Configuration Models
34
- @dataclass
35
- class ModelConfig:
36
- name: str
37
- max_length: int
38
- num_beams: int
39
- temperature: float
40
- top_k: int
41
- top_p: float
42
-
43
- @dataclass
44
- class StyleConfig:
45
- node_color: str
46
- edge_color: str
47
- node_size: int
48
- font_size: int
49
- layout: str
50
-
51
- @dataclass
52
- class OutputConfig:
53
- width: int
54
- height: int
55
- dpi: int
56
- format: str
57
- quality: int
58
-
59
- # Abstract Base Classes for Extensibility
60
- class DiagramStrategy(ABC):
61
- @abstractmethod
62
- def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
63
- pass
64
-
65
- class NetworkDiagram(DiagramStrategy):
66
- def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
67
- G = nx.DiGraph()
68
- for i in range(len(components)-1):
69
- G.add_edge(components[i], components[i+1])
70
-
71
- plt.figure(figsize=(12, 8))
72
 
73
- if style.layout == "spring":
74
- pos = nx.spring_layout(G)
75
- elif style.layout == "circular":
76
- pos = nx.circular_layout(G)
77
- else:
78
- pos = nx.kamada_kawai_layout(G)
79
-
80
- nx.draw_networkx_nodes(G, pos,
81
- node_color=style.node_color,
82
- node_size=style.node_size)
83
- nx.draw_networkx_edges(G, pos,
84
- edge_color=style.edge_color,
85
- arrows=True)
86
- nx.draw_networkx_labels(G, pos,
87
- font_size=style.font_size)
88
-
89
- buf = io.BytesIO()
90
- plt.savefig(buf, format='png', dpi=300)
91
- plt.close()
92
- buf.seek(0)
93
- return Image.open(buf)
94
-
95
- class PlotlyDiagram(DiagramStrategy):
96
- def create_diagram(self, components: List[str], style: StyleConfig) -> Image.Image:
97
- G = nx.DiGraph()
98
- for i in range(len(components)-1):
99
- G.add_edge(components[i], components[i+1])
100
-
101
- pos = nx.spring_layout(G)
102
 
103
- edge_x = []
104
- edge_y = []
105
- for edge in G.edges():
106
- x0, y0 = pos[edge[0]]
107
- x1, y1 = pos[edge[1]]
108
- edge_x.extend([x0, x1, None])
109
- edge_y.extend([y0, y1, None])
110
-
111
- node_x = [pos[node][0] for node in G.nodes()]
112
- node_y = [pos[node][1] for node in G.nodes()]
113
-
114
- fig = go.Figure()
115
- fig.add_trace(go.Scatter(x=edge_x, y=edge_y,
116
- line=dict(width=0.5, color=style.edge_color),
117
- hoverinfo='none',
118
- mode='lines'))
119
 
120
- fig.add_trace(go.Scatter(x=node_x, y=node_y,
121
- mode='markers+text',
122
- marker=dict(size=style.node_size/100,
123
- color=style.node_color),
124
- text=list(G.nodes()),
125
- textposition="bottom center"))
126
-
127
- fig.update_layout(showlegend=False,
128
- hovermode='closest',
129
- margin=dict(b=0,l=0,r=0,t=0))
130
-
131
- img_bytes = fig.to_image(format="png") # Requires kaleido
132
- return Image.open(io.BytesIO(img_bytes))
 
 
 
 
 
 
 
 
 
 
133
 
134
- # Database Manager
135
- class DatabaseManager:
136
- def __init__(self, db_path: str = "diagrams.db"):
137
- self.conn = sqlite3.connect(db_path)
138
- self.create_tables()
 
 
 
139
 
140
- def create_tables(self):
141
- cursor = self.conn.cursor()
142
- cursor.execute('''
143
- CREATE TABLE IF NOT EXISTS users (
144
- id INTEGER PRIMARY KEY,
145
- username TEXT UNIQUE,
146
- email TEXT UNIQUE,
147
- password_hash TEXT,
148
- created_at TIMESTAMP
149
  )
150
- ''')
151
-
152
- cursor.execute('''
153
- CREATE TABLE IF NOT EXISTS diagrams (
154
- id INTEGER PRIMARY KEY,
155
- user_id INTEGER,
156
- title TEXT,
157
- description TEXT,
158
- created_at TIMESTAMP,
159
- image_path TEXT,
160
- FOREIGN KEY (user_id) REFERENCES users (id)
161
- )
162
- ''')
163
- self.conn.commit()
164
-
165
- # Cache Manager using Redis
166
- class CacheManager:
167
- def __init__(self, redis_url: str = "redis://localhost"):
168
- self.redis_client = redis.from_url(redis_url)
169
 
170
- def get_cached_diagram(self, key: str) -> Optional[bytes]:
171
- return self.redis_client.get(key)
172
-
173
- def cache_diagram(self, key: str, diagram: bytes, expire: int = 3600):
174
- self.redis_client.set(key, diagram, ex=expire)
175
-
176
- # Advanced Diagram Generator
177
- class AdvancedDiagramGenerator:
178
- def __init__(self, config_path: str = "config.yaml"):
179
- # Initialize logging first
180
- self.setup_logging()
181
-
182
- # Load configuration
183
- self.load_config(config_path)
184
-
185
- # Setup components (tokenizer, model, etc.)
186
- self.setup_components()
187
-
188
- # Initialize diagram strategies
189
- self.strategies = {
190
- "network": NetworkDiagram(),
191
- "plotly": PlotlyDiagram()
192
- }
193
 
194
- def load_config(self, config_path: str):
 
195
  try:
196
- if not os.path.exists(config_path):
197
- self.logger.warning(f"Config file not found at {config_path}. Using default configuration.")
198
- # Default configuration
199
- config_data = {
200
- 'model': {
201
- 'name': 't5-small',
202
- 'max_length': 512,
203
- 'num_beams': 4,
204
- 'temperature': 1.0,
205
- 'top_k': 50,
206
- 'top_p': 0.9
207
- },
208
- 'styles': {
209
- 'network': {
210
- 'node_color': '#1f77b4',
211
- 'edge_color': '#7f7f7f',
212
- 'node_size': 3000,
213
- 'font_size': 12,
214
- 'layout': 'spring'
215
- }
216
- },
217
- 'output': {
218
- 'width': 1200,
219
- 'height': 800,
220
- 'dpi': 300,
221
- 'format': 'png',
222
- 'quality': 95
223
- }
224
- }
225
  else:
226
- with open(config_path) as f:
227
- config_data = yaml.safe_load(f)
 
228
 
229
- # Ensure all required sections exist with defaults
230
- if 'model' not in config_data:
231
- config_data['model'] = {}
232
-
233
- # Set default values for model configuration
234
- model_config_data = config_data['model']
235
- model_config_data.setdefault('name', 't5-small')
236
- model_config_data.setdefault('max_length', 512)
237
- model_config_data.setdefault('num_beams', 4)
238
- model_config_data.setdefault('temperature', 1.0)
239
- model_config_data.setdefault('top_k', 50)
240
- model_config_data.setdefault('top_p', 0.9)
241
 
242
- # Create ModelConfig instance
243
- self.model_config = ModelConfig(**model_config_data)
244
-
245
- # Handle styles configuration
246
- if 'styles' not in config_data:
247
- config_data['styles'] = {
248
- 'network': {
249
- 'node_color': '#1f77b4',
250
- 'edge_color': '#7f7f7f',
251
- 'node_size': 3000,
252
- 'font_size': 12,
253
- 'layout': 'spring'
254
- }
255
- }
256
 
257
- # Create StyleConfig instances
258
- self.style_configs = {}
259
- for style_name, style_data in config_data['styles'].items():
260
- style_data.setdefault('node_color', '#1f77b4')
261
- style_data.setdefault('edge_color', '#7f7f7f')
262
- style_data.setdefault('node_size', 3000)
263
- style_data.setdefault('font_size', 12)
264
- style_data.setdefault('layout', 'spring')
265
- self.style_configs[style_name] = StyleConfig(**style_data)
266
-
267
- # Handle output configuration
268
- if 'output' not in config_data:
269
- config_data['output'] = {}
270
 
271
- output_config_data = config_data['output']
272
- output_config_data.setdefault('width', 1200)
273
- output_config_data.setdefault('height', 800)
274
- output_config_data.setdefault('dpi', 300)
275
- output_config_data.setdefault('format', 'png')
276
- output_config_data.setdefault('quality', 95)
 
 
 
277
 
278
- # Create OutputConfig instance
279
- self.output_config = OutputConfig(**output_config_data)
280
-
281
- self.config = config_data
282
- self.logger.info("Configuration loaded successfully")
283
-
284
- except Exception as e:
285
- self.logger.error(f"Error loading configuration: {str(e)}")
286
- raise RuntimeError(f"Failed to load configuration: {str(e)}")
287
-
288
- def setup_components(self):
289
- # Initialize tokenizer and model
290
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_config.name)
291
- self.model = T5ForConditionalGeneration.from_pretrained(self.model_config.name)
292
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
293
- self.model.to(self.device)
294
-
295
- def setup_logging(self):
296
- logging.basicConfig(level=logging.INFO)
297
- self.logger = logging.getLogger(__name__)
298
-
299
- async def extract_components(self, text: str) -> List[str]:
300
- inputs = self.tokenizer(
301
- f"convert to diagram: {text}",
302
- return_tensors="pt",
303
- max_length=self.model_config.max_length,
304
- truncation=True
305
- ).to(self.device)
306
-
307
- with torch.no_grad():
308
- outputs = self.model.generate(
309
- inputs.input_ids,
310
- max_length=150,
311
- num_beams=self.model_config.num_beams,
312
- temperature=self.model_config.temperature,
313
- top_k=self.model_config.top_k,
314
- top_p=self.model_config.top_p
315
- )
316
-
317
- decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
318
- components = [comp.strip() for comp in decoded.replace('->', ',').split(',')]
319
- return [comp for comp in components if comp]
320
 
321
- def save_to_database(self, user_id: int, description: str, diagram: Image.Image):
322
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
323
- image_path = f"diagrams/user_{user_id}/{timestamp}.png"
324
- os.makedirs(os.path.dirname(image_path), exist_ok=True)
325
- diagram.save(image_path)
326
-
327
- cursor = self.db.conn.cursor()
328
- cursor.execute('''
329
- INSERT INTO diagrams (user_id, description, created_at, image_path)
330
- VALUES (?, ?, ?, ?)
331
- ''', (user_id, description, datetime.now(), image_path))
332
- self.db.conn.commit()
333
-
334
- async def generate_diagram(self, text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
335
- try:
336
- components = await self.extract_components(text)
337
- diagram = self.strategies[strategy].create_diagram(components, self.style_configs[style])
338
- return diagram, "Diagram generated successfully!"
339
  except Exception as e:
340
- self.logger.error(f"Error generating diagram: {str(e)}")
341
- return None, f"Error: {str(e)}"
342
 
343
- # FastAPI Integration
344
- app = FastAPI()
345
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
346
-
347
- class DiagramRequest(BaseModel):
348
- text: str
349
- style: str = "network"
350
- strategy: str = "network"
351
-
352
- # Initialize the generator
353
- generator = AdvancedDiagramGenerator()
354
-
355
- # Gradio Interface
356
  def create_gradio_interface():
357
- def generate(text: str, style: str, strategy: str) -> Tuple[Optional[Image.Image], str]:
358
- return asyncio.run(generator.generate_diagram(text, style, strategy))
359
-
360
  iface = gr.Interface(
361
- fn=generate,
362
  inputs=[
363
  gr.Textbox(
364
  label="Enter your diagram description",
365
- placeholder="e.g., 'Create a flowchart for software development lifecycle'",
366
  lines=3
367
  ),
368
  gr.Dropdown(
369
- choices=list(generator.style_configs.keys()),
370
  label="Diagram Style",
371
- value="network"
372
- ),
373
- gr.Dropdown(
374
- choices=list(generator.strategies.keys()),
375
- label="Visualization Strategy",
376
- value="network"
377
  )
378
  ],
379
  outputs=[
380
  gr.Image(label="Generated Diagram", type="pil"),
381
  gr.Textbox(label="Status")
382
  ],
383
- title="Advanced Enterprise Diagram Generator",
384
  description="""
385
- Enterprise-grade diagram generation tool with advanced features:
386
- - Multiple visualization strategies
387
- - Caching system
388
- - Database storage
389
- - Multiple output formats
390
- - Custom styling options
391
- """,
392
- theme=gr.themes.Glass()
393
  )
394
  return iface
395
 
396
  if __name__ == "__main__":
397
  iface = create_gradio_interface()
398
- iface.launch(share=True)
 
1
  import gradio as gr
2
  import torch
 
 
 
 
 
 
 
 
 
3
  from transformers import AutoTokenizer, T5ForConditionalGeneration
4
  import networkx as nx
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  from PIL import Image
8
  import io
9
+ from sklearn.feature_extraction.text import TfidfVectorizer
10
+ from scipy.spatial import distance
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ class DiagramGenerator:
13
+ def __init__(self):
14
+ # Initialize device
15
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # Load model
18
+ self.model_name = "t5-small"
19
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
20
+ self.model = T5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Initialize vectorizer
23
+ self.vectorizer = TfidfVectorizer(stop_words='english')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # Style configurations
26
+ self.styles = {
27
+ "flowchart": {
28
+ "node_color": "lightblue",
29
+ "edge_color": "gray",
30
+ "node_size": 3000
31
+ },
32
+ "mindmap": {
33
+ "node_color": "lightgreen",
34
+ "edge_color": "darkgreen",
35
+ "node_size": 2500
36
+ },
37
+ "sequence": {
38
+ "node_color": "lightyellow",
39
+ "edge_color": "orange",
40
+ "node_size": 3500
41
+ },
42
+ "kga": {
43
+ "node_color": "lightcoral",
44
+ "edge_color": "darkred",
45
+ "node_size": 3000
46
+ }
47
+ }
48
 
49
+ def extract_components(self, text: str) -> list:
50
+ """Extract components from text using T5 model."""
51
+ inputs = self.tokenizer(
52
+ text,
53
+ max_length=512,
54
+ truncation=True,
55
+ return_tensors="pt"
56
+ ).to(self.device)
57
 
58
+ outputs = self.model.generate(
59
+ inputs['input_ids'],
60
+ num_beams=4,
61
+ max_length=512
 
 
 
 
 
62
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
65
+ return [comp.strip() for comp in decoded_output.split(",")]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ def create_diagram(self, text: str, style: str = "flowchart"):
68
+ """Create diagram from text with specified style."""
69
  try:
70
+ # Extract components
71
+ components = self.extract_components(text)
72
+ if not components:
73
+ return None, "No components extracted from text."
74
+
75
+ # Create figure
76
+ plt.figure(figsize=(12, 8))
77
+ G = nx.DiGraph()
78
+
79
+ if style == "kga":
80
+ # Create KGA diagram
81
+ tfidf_matrix = self.vectorizer.fit_transform(components)
82
+ similarity_matrix = 1 - distance.squareform(
83
+ distance.pdist(tfidf_matrix.toarray(), metric='cosine')
84
+ )
85
+
86
+ # Add edges based on similarity
87
+ for i in range(len(components)):
88
+ for j in range(i + 1, len(components)):
89
+ if similarity_matrix[i][j] > 0.5:
90
+ G.add_edge(components[i], components[j])
91
+ G.add_edge(components[j], components[i])
 
 
 
 
 
 
 
92
  else:
93
+ # Create sequential diagram
94
+ for i in range(len(components)-1):
95
+ G.add_edge(components[i], components[i+1])
96
 
97
+ # Draw diagram
98
+ pos = nx.spring_layout(G)
99
+ style_config = self.styles[style]
 
 
 
 
 
 
 
 
 
100
 
101
+ nx.draw_networkx_nodes(
102
+ G, pos,
103
+ node_color=style_config['node_color'],
104
+ node_size=style_config['node_size']
105
+ )
 
 
 
 
 
 
 
 
 
106
 
107
+ nx.draw_networkx_edges(
108
+ G, pos,
109
+ edge_color=style_config['edge_color'],
110
+ arrows=True if style != "kga" else False
111
+ )
 
 
 
 
 
 
 
 
112
 
113
+ nx.draw_networkx_labels(G, pos)
114
+ plt.title(f"{style.capitalize()} Diagram")
115
+ plt.axis('off')
116
+
117
+ # Save to buffer
118
+ buf = io.BytesIO()
119
+ plt.savefig(buf, format='png', bbox_inches='tight', dpi=100)
120
+ plt.close()
121
+ buf.seek(0)
122
 
123
+ return Image.open(buf), "Diagram generated successfully!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  except Exception as e:
126
+ return None, f"Error generating diagram: {str(e)}"
 
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def create_gradio_interface():
129
+ generator = DiagramGenerator()
130
+
 
131
  iface = gr.Interface(
132
+ fn=generator.create_diagram,
133
  inputs=[
134
  gr.Textbox(
135
  label="Enter your diagram description",
136
+ placeholder="e.g., 'Create a knowledge graph for artificial intelligence concepts'",
137
  lines=3
138
  ),
139
  gr.Dropdown(
140
+ choices=list(generator.styles.keys()),
141
  label="Diagram Style",
142
+ value="flowchart"
 
 
 
 
 
143
  )
144
  ],
145
  outputs=[
146
  gr.Image(label="Generated Diagram", type="pil"),
147
  gr.Textbox(label="Status")
148
  ],
149
+ title="AI-Powered Diagram Generator",
150
  description="""
151
+ Create various types of diagrams from text descriptions.
152
+ Supports flowcharts, mindmaps, sequence diagrams, and knowledge graphs.
153
+ """
 
 
 
 
 
154
  )
155
  return iface
156
 
157
  if __name__ == "__main__":
158
  iface = create_gradio_interface()
159
+ iface.launch()