File size: 5,361 Bytes
29ae6c8
 
 
 
 
 
 
 
e5e0d10
 
29ae6c8
e5e0d10
 
 
 
29ae6c8
e5e0d10
 
 
 
29ae6c8
e5e0d10
 
29ae6c8
e5e0d10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29ae6c8
e5e0d10
 
 
 
 
 
 
 
29ae6c8
e5e0d10
 
 
 
29ae6c8
 
e5e0d10
 
29ae6c8
e5e0d10
 
29ae6c8
e5e0d10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29ae6c8
e5e0d10
 
 
29ae6c8
e5e0d10
 
 
29ae6c8
e5e0d10
 
 
 
 
29ae6c8
e5e0d10
 
 
 
 
29ae6c8
e5e0d10
 
 
 
 
 
 
 
 
29ae6c8
e5e0d10
29ae6c8
 
e5e0d10
29ae6c8
 
e5e0d10
 
29ae6c8
e5e0d10
29ae6c8
 
 
e5e0d10
29ae6c8
 
 
e5e0d10
29ae6c8
e5e0d10
29ae6c8
 
 
 
 
 
e5e0d10
29ae6c8
e5e0d10
 
 
29ae6c8
 
 
 
 
e5e0d10
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.spatial import distance

class DiagramGenerator:
    """Generate node/edge diagrams from free-text descriptions.

    The input text is run through a T5 model; the decoded output is split on
    commas into component labels, which become graph nodes. Edges are either
    sequential (each component points to the next) or, for the ``"kga"``
    style, added between components whose TF-IDF cosine similarity exceeds
    ``SIMILARITY_THRESHOLD``. The graph is drawn with networkx/matplotlib and
    returned as a PIL image.
    """

    # Minimum TF-IDF cosine similarity for two components to be linked in the
    # "kga" style. Override on an instance or subclass to tune linking.
    SIMILARITY_THRESHOLD = 0.5

    def __init__(self):
        # Run on the GPU when one is available, otherwise on the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load tokenizer and model.
        self.model_name = "t5-small"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name).to(self.device)

        # TF-IDF vectorizer (English stop words removed) used to compare
        # component labels in the "kga" style.
        self.vectorizer = TfidfVectorizer(stop_words='english')

        # Per-style drawing options: node colour, edge colour, node size.
        self.styles = {
            "flowchart": {
                "node_color": "lightblue",
                "edge_color": "gray",
                "node_size": 3000
            },
            "mindmap": {
                "node_color": "lightgreen",
                "edge_color": "darkgreen",
                "node_size": 2500
            },
            "sequence": {
                "node_color": "lightyellow",
                "edge_color": "orange",
                "node_size": 3500
            },
            "kga": {
                "node_color": "lightcoral",
                "edge_color": "darkred",
                "node_size": 3000
            }
        }

    def extract_components(self, text: str) -> list:
        """Extract diagram components from *text* using the T5 model.

        The input is truncated to 512 tokens, decoded with 4-beam search, and
        the decoded string is split on commas.

        Returns:
            The whitespace-stripped, non-empty comma-separated pieces of the
            model output, in order. Empty when the model produces nothing
            usable, so callers can detect that case with ``if not components``.
        """
        inputs = self.tokenizer(
            text,
            max_length=512,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.model.generate(
            inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            num_beams=4,
            max_length=512
        )

        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # str.split always yields at least one piece, so blanks must be dropped;
        # otherwise "" or "a, ,b" would produce empty-string nodes.
        pieces = (comp.strip() for comp in decoded_output.split(","))
        return [comp for comp in pieces if comp]

    def create_diagram(self, text: str, style: str = "flowchart"):
        """Create a diagram image from *text* in the given *style*.

        Args:
            text: Free-text description to turn into a diagram.
            style: A key of ``self.styles`` ("flowchart", "mindmap",
                "sequence" or "kga").

        Returns:
            ``(image, status)``: a PIL image and a success message, or
            ``(None, message)`` explaining why no diagram was produced.
            Failures are reported through the message rather than raised.
        """
        try:
            # Validate the style before paying for model inference.
            if style not in self.styles:
                return None, (
                    f"Unknown diagram style: {style!r}. "
                    f"Choose one of: {', '.join(self.styles)}."
                )

            components = self.extract_components(text)
            if not components:
                return None, "No components extracted from text."

            graph = self._build_graph(components, style)
            return self._render(graph, style), "Diagram generated successfully!"

        except Exception as e:
            # Top-level boundary for the UI: surface any failure as a status.
            return None, f"Error generating diagram: {str(e)}"

    def _build_graph(self, components: list, style: str):
        """Return a ``networkx.DiGraph`` linking *components* according to *style*."""
        graph = nx.DiGraph()
        # Add every component as a node up front so components without edges
        # (a single component, or a "kga" label similar to nothing) are drawn.
        graph.add_nodes_from(components)

        if style == "kga":
            # Link similar components in both directions (drawn without arrows).
            for first, second in self._similar_pairs(components):
                graph.add_edge(first, second)
                graph.add_edge(second, first)
        else:
            # Sequential styles: each component points to the next one.
            graph.add_edges_from(zip(components, components[1:]))
        return graph

    def _similar_pairs(self, components: list) -> list:
        """Return label pairs whose TF-IDF cosine similarity exceeds the threshold."""
        from itertools import combinations

        # Deduplicate while keeping order: identical labels are trivially
        # similar and would otherwise produce self-loop edges.
        labels = list(dict.fromkeys(components))
        if len(labels) < 2:
            return []

        try:
            tfidf_matrix = self.vectorizer.fit_transform(labels)
        except ValueError:
            # TfidfVectorizer raises ValueError ("empty vocabulary") when every
            # label consists only of stop words; there is nothing to compare.
            return []

        similarity = 1 - distance.squareform(
            distance.pdist(tfidf_matrix.toarray(), metric='cosine')
        )
        return [
            (labels[i], labels[j])
            for i, j in combinations(range(len(labels)), 2)
            if similarity[i][j] > self.SIMILARITY_THRESHOLD
        ]

    def _render(self, graph, style: str):
        """Draw *graph* with the options for *style* and return it as a PIL image."""
        style_config = self.styles[style]
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            pos = nx.spring_layout(graph)

            nx.draw_networkx_nodes(
                graph, pos, ax=ax,
                node_color=style_config['node_color'],
                node_size=style_config['node_size']
            )
            nx.draw_networkx_edges(
                graph, pos, ax=ax,
                edge_color=style_config['edge_color'],
                # Pass the real node size so arrowheads stop at the node border
                # instead of being hidden under the large node markers.
                node_size=style_config['node_size'],
                arrows=style != "kga"
            )
            nx.draw_networkx_labels(graph, pos, ax=ax)
            ax.set_title(f"{style.capitalize()} Diagram")
            ax.axis('off')

            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
        finally:
            # Close the figure even when drawing fails so figures don't pile up
            # across calls.
            plt.close(fig)

        buf.seek(0)
        return Image.open(buf)

def create_gradio_interface():
    """Build the Gradio UI around a single ``DiagramGenerator`` instance.

    Returns:
        A ``gr.Interface`` taking a description textbox and a style dropdown,
        and producing the rendered diagram image plus a status message.
    """
    generator = DiagramGenerator()

    # Input widgets: free-text description and diagram style selector.
    description_box = gr.Textbox(
        label="Enter your diagram description",
        placeholder="e.g., 'Create a knowledge graph for artificial intelligence concepts'",
        lines=3
    )
    style_picker = gr.Dropdown(
        choices=list(generator.styles),
        label="Diagram Style",
        value="flowchart"
    )

    # Output widgets: the generated image and a status line.
    image_out = gr.Image(label="Generated Diagram", type="pil")
    status_out = gr.Textbox(label="Status")

    description = """
        Create various types of diagrams from text descriptions.
        Supports flowcharts, mindmaps, sequence diagrams, and knowledge graphs.
        """

    return gr.Interface(
        fn=generator.create_diagram,
        inputs=[description_box, style_picker],
        outputs=[image_out, status_out],
        title="AI-Powered Diagram Generator",
        description=description,
    )

if __name__ == "__main__":
    # Script entry point: build the Gradio interface and launch it.
    iface = create_gradio_interface()
    iface.launch()