|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, T5ForConditionalGeneration |
|
import networkx as nx |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from PIL import Image |
|
import io |
|
from sklearn.feature_extraction.text import TfidfVectorizer |
|
from scipy.spatial import distance |
|
|
|
class DiagramGenerator:
    """Turn a free-text description into a rendered node-link diagram.

    Components are extracted from the text with a T5 model, linked into a
    directed graph according to the chosen style, and drawn with
    networkx/matplotlib into a PIL image.
    """

    def __init__(self):
        # Run the model on the GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Text-to-text model used to pull components out of the description.
        self.model_name = "t5-small"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device)

        # Used by the "kga" style to link components whose wording is similar.
        self.vectorizer = TfidfVectorizer(stop_words='english')

        # Per-style drawing parameters; the keys are the supported style names.
        self.styles = {
            "flowchart": {
                "node_color": "lightblue",
                "edge_color": "gray",
                "node_size": 3000
            },
            "mindmap": {
                "node_color": "lightgreen",
                "edge_color": "darkgreen",
                "node_size": 2500
            },
            "sequence": {
                "node_color": "lightyellow",
                "edge_color": "orange",
                "node_size": 3500
            },
            "kga": {
                "node_color": "lightcoral",
                "edge_color": "darkred",
                "node_size": 3000
            }
        }

    def extract_components(self, text: str) -> list:
        """Extract component names from *text* using the T5 model.

        The generated text is split on commas. Blank entries are dropped, so
        an empty list means nothing usable was produced. Empty or
        whitespace-only input returns ``[]`` without running the model.
        """
        if not text or not text.strip():
            return []

        inputs = self.tokenizer(
            text,
            max_length=512,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

        # Pass the attention mask along with the input ids.
        outputs = self.model.generate(
            **inputs,
            num_beams=4,
            max_length=512
        )

        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self._split_components(decoded_output)

    @staticmethod
    def _split_components(decoded: str) -> list:
        """Split comma-separated model output into stripped, non-empty names."""
        return [part.strip() for part in decoded.split(",") if part.strip()]

    def create_diagram(self, text: str, style: str = "flowchart"):
        """Create a diagram image from *text* in the given *style*.

        Returns:
            tuple: ``(image, status)``. ``image`` is a PIL image, or ``None``
            on failure. ``status`` is a human-readable message for the UI.
            This method never raises; errors are reported through ``status``.
        """
        fig = None
        try:
            # Validate the style before doing any expensive model work.
            style_config = self.styles.get(style)
            if style_config is None:
                return None, (
                    f"Unknown diagram style: {style!r}. "
                    f"Choose one of: {', '.join(self.styles)}."
                )

            components = self.extract_components(text)
            if not components:
                return None, "No components extracted from text."

            graph = self._build_graph(components, style)

            # Use an explicit figure so it can always be closed below.
            fig, ax = plt.subplots(figsize=(12, 8))
            self._draw_graph(graph, ax, style, style_config)

            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
            buf.seek(0)
            return Image.open(buf), "Diagram generated successfully!"

        except Exception as e:
            return None, f"Error generating diagram: {str(e)}"

        finally:
            # Release the figure even when rendering failed part-way.
            if fig is not None:
                plt.close(fig)

    def _build_graph(self, components: list, style: str):
        """Build a directed graph linking *components* according to *style*."""
        graph = nx.DiGraph()
        # Add every component up front so isolated ones (a lone component, or
        # a "kga" node with no similar neighbour) still appear in the diagram.
        graph.add_nodes_from(components)
        if style == "kga":
            graph.add_edges_from(self._similarity_edges(components))
        else:
            # Chain styles: link each component to the next one in order.
            graph.add_edges_from(zip(components, components[1:]))
        return graph

    def _similarity_edges(self, components: list, threshold: float = 0.5) -> list:
        """Return edges in both directions between components whose TF-IDF cosine similarity exceeds *threshold*."""
        if len(components) < 2:
            return []
        try:
            tfidf_matrix = self.vectorizer.fit_transform(components)
        except ValueError:
            # Raised when no component has a non-stop-word term (empty
            # vocabulary). There is nothing to compare, so leave nodes unlinked.
            return []

        similarity_matrix = 1 - distance.squareform(
            distance.pdist(tfidf_matrix.toarray(), metric='cosine')
        )

        edges = []
        for i, first in enumerate(components):
            for j in range(i + 1, len(components)):
                if similarity_matrix[i][j] > threshold:
                    # Similarity is symmetric, so link both directions.
                    second = components[j]
                    edges.append((first, second))
                    edges.append((second, first))
        return edges

    @staticmethod
    def _draw_graph(graph, ax, style: str, style_config: dict) -> None:
        """Draw *graph* onto *ax* using the colours and sizes in *style_config*."""
        pos = nx.spring_layout(graph)
        nx.draw_networkx_nodes(
            graph, pos,
            ax=ax,
            node_color=style_config['node_color'],
            node_size=style_config['node_size']
        )
        # node_size decides where edges stop. Without it, arrowheads are placed
        # for the default 300-pt nodes and end up hidden under the larger ones.
        nx.draw_networkx_edges(
            graph, pos,
            ax=ax,
            edge_color=style_config['edge_color'],
            node_size=style_config['node_size'],
            arrows=style != "kga"
        )
        nx.draw_networkx_labels(graph, pos, ax=ax)
        ax.set_title(f"{style.capitalize()} Diagram")
        ax.axis('off')
|
|
|
def create_gradio_interface():
    """Build the Gradio UI around a single shared DiagramGenerator.

    Returns:
        gr.Interface: call ``launch()`` to serve it. It takes a text
        description and a style name, and shows the rendered diagram plus a
        status message.
    """
    diagram_generator = DiagramGenerator()

    # Input widgets, in the order create_diagram expects its arguments.
    description_box = gr.Textbox(
        label="Enter your diagram description",
        placeholder="e.g., 'Create a knowledge graph for artificial intelligence concepts'",
        lines=3
    )
    style_picker = gr.Dropdown(
        choices=list(diagram_generator.styles.keys()),
        label="Diagram Style",
        value="flowchart"
    )

    # Output widgets, matching create_diagram's (image, status) return value.
    diagram_view = gr.Image(label="Generated Diagram", type="pil")
    status_view = gr.Textbox(label="Status")

    return gr.Interface(
        fn=diagram_generator.create_diagram,
        inputs=[description_box, style_picker],
        outputs=[diagram_view, status_view],
        title="AI-Powered Diagram Generator",
        description="""

        Create various types of diagrams from text descriptions.

        Supports flowcharts, mindmaps, sequence diagrams, and knowledge graphs.

        """
    )
|
|
|
if __name__ == "__main__":
    # Script entry point: build the UI and start the Gradio server.
    app = create_gradio_interface()
    app.launch()