Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, T5ForConditionalGeneration | |
import networkx as nx | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from PIL import Image | |
import io | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from scipy.spatial import distance | |
class DiagramGenerator:
    """Render a node-link diagram from a free-text description.

    The text is passed through a ``t5-small`` model; the decoded output is
    split on commas and each non-empty piece becomes a graph node. Nodes are
    chained in order for every style except ``"kga"``, where nodes are linked
    when the TF-IDF cosine similarity of their labels exceeds 0.5.
    """

    def __init__(self):
        """Load the tokenizer/model and define the per-style drawing options."""
        # Initialize device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load model
        self.model_name = "t5-small"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device)
        # Initialize vectorizer
        self.vectorizer = TfidfVectorizer(stop_words='english')
        # Style configurations
        self.styles = {
            "flowchart": {
                "node_color": "lightblue",
                "edge_color": "gray",
                "node_size": 3000,
            },
            "mindmap": {
                "node_color": "lightgreen",
                "edge_color": "darkgreen",
                "node_size": 2500,
            },
            "sequence": {
                "node_color": "lightyellow",
                "edge_color": "orange",
                "node_size": 3500,
            },
            "kga": {
                "node_color": "lightcoral",
                "edge_color": "darkred",
                "node_size": 3000,
            },
        }

    @staticmethod
    def _split_components(decoded_output: str) -> list:
        """Split comma-separated text into stripped, non-empty labels."""
        # "".split(",") yields [""]; dropping blanks lets callers detect
        # "nothing extracted" with a plain truthiness check.
        stripped = (part.strip() for part in decoded_output.split(","))
        return [part for part in stripped if part]

    def extract_components(self, text: str) -> list:
        """Run the T5 model on ``text`` and return its output as labels.

        Args:
            text: Input description; tokenized with truncation at 512 tokens.

        Returns:
            The decoded model output split on commas, with surrounding
            whitespace removed and empty pieces dropped. May be empty.
        """
        inputs = self.tokenizer(
            text,
            max_length=512,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)
        # Forward the whole encoding so the attention mask travels with the ids.
        outputs = self.model.generate(
            **inputs,
            num_beams=4,
            max_length=512
        )
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self._split_components(decoded_output)

    def _build_graph(self, components: list, style: str):
        """Build the directed graph for ``components`` according to ``style``."""
        graph = nx.DiGraph()
        # Register every component up front so isolated ones are still drawn
        # (a single component, or a "kga" node with no similar neighbour).
        graph.add_nodes_from(components)

        if style == "kga":
            # Pairwise similarity needs at least two labels.
            if len(components) > 1:
                tfidf_matrix = self.vectorizer.fit_transform(components)
                similarity_matrix = 1 - distance.squareform(
                    distance.pdist(tfidf_matrix.toarray(), metric='cosine')
                )
                # Add edges in both directions for sufficiently similar pairs.
                for i in range(len(components)):
                    for j in range(i + 1, len(components)):
                        if similarity_matrix[i][j] > 0.5:
                            graph.add_edge(components[i], components[j])
                            graph.add_edge(components[j], components[i])
        else:
            # Sequential diagram: chain each component to the next one.
            graph.add_edges_from(zip(components, components[1:]))
        return graph

    def create_diagram(self, text: str, style: str = "flowchart"):
        """Create a diagram image from ``text`` in the given ``style``.

        Args:
            text: Description to extract components from.
            style: A key of ``self.styles``.

        Returns:
            ``(image, status)`` where ``image`` is a PIL image on success and
            ``None`` on failure, and ``status`` is a human-readable message.
            Errors are reported through ``status`` rather than raised.
        """
        # Validate cheap inputs before running the model or creating a figure.
        style_config = self.styles.get(style)
        if style_config is None:
            return None, f"Error generating diagram: unknown style '{style}'"
        if not text or not text.strip():
            return None, "No components extracted from text."

        fig = None
        try:
            # Extract components
            components = self.extract_components(text)
            if not components:
                return None, "No components extracted from text."

            graph = self._build_graph(components, style)

            # Draw on an explicit figure/axes instead of pyplot's implicit
            # "current figure", which is global state.
            fig = plt.figure(figsize=(12, 8))
            ax = fig.add_subplot(111)
            # Fixed seed: the same graph always gets the same layout.
            pos = nx.spring_layout(graph, seed=42)
            nx.draw_networkx_nodes(
                graph, pos, ax=ax,
                node_color=style_config['node_color'],
                node_size=style_config['node_size']
            )
            directed = style != "kga"
            edge_options = {
                "edge_color": style_config['edge_color'],
                "arrows": directed,
            }
            if directed:
                # Tell the edge drawer how big the nodes are so arrowheads
                # stop at the node boundary instead of hiding underneath it.
                edge_options["node_size"] = style_config['node_size']
            nx.draw_networkx_edges(graph, pos, ax=ax, **edge_options)
            nx.draw_networkx_labels(graph, pos, ax=ax)
            ax.set_title(f"{style.capitalize()} Diagram")
            ax.axis('off')

            # Save to buffer
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
            buf.seek(0)
            return Image.open(buf), "Diagram generated successfully!"
        except Exception as e:
            return None, f"Error generating diagram: {str(e)}"
        finally:
            # Close the figure on every path; otherwise failed requests leak it.
            if fig is not None:
                plt.close(fig)
def create_gradio_interface():
    """Build the Gradio interface around a new ``DiagramGenerator``.

    Returns:
        The configured ``gr.Interface``; it is not launched here.
    """
    diagram_generator = DiagramGenerator()

    # Input widgets: the free-text description and the style selector.
    description_box = gr.Textbox(
        label="Enter your diagram description",
        placeholder="e.g., 'Create a knowledge graph for artificial intelligence concepts'",
        lines=3,
    )
    style_picker = gr.Dropdown(
        choices=list(diagram_generator.styles),
        label="Diagram Style",
        value="flowchart",
    )

    # Output widgets mirror create_diagram's (image, status) return pair.
    diagram_view = gr.Image(label="Generated Diagram", type="pil")
    status_box = gr.Textbox(label="Status")

    return gr.Interface(
        fn=diagram_generator.create_diagram,
        inputs=[description_box, style_picker],
        outputs=[diagram_view, status_box],
        title="AI-Powered Diagram Generator",
        description=(
            "\n"
            "Create various types of diagrams from text descriptions.\n"
            "Supports flowcharts, mindmaps, sequence diagrams, and knowledge graphs.\n"
        ),
    )
if __name__ == "__main__":
    # Script entry point: build the interface, then start serving it.
    iface = create_gradio_interface()
    iface.launch()