ZahirJS commited on
Commit
5f8bd16
·
verified ·
1 Parent(s): 9ed7b2c

Create class_diagram_generator.py

Browse files
Files changed (1) hide show
  1. class_diagram_generator.py +171 -0
class_diagram_generator.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import graphviz
2
+ import json
3
+ from tempfile import NamedTemporaryFile
4
+ import os
5
+
6
+ def generate_class_diagram(json_input: str, output_format: str) -> str:
7
+ try:
8
+ if not json_input.strip():
9
+ return "Error: Empty input"
10
+
11
+ data = json.loads(json_input)
12
+
13
+ if 'classes' not in data:
14
+ raise ValueError("Missing required field: classes")
15
+
16
+ dot = graphviz.Digraph(
17
+ name='ClassDiagram',
18
+ format='png',
19
+ graph_attr={
20
+ 'rankdir': 'TB',
21
+ 'splines': 'ortho',
22
+ 'bgcolor': 'white',
23
+ 'pad': '0.5',
24
+ 'nodesep': '1.5',
25
+ 'ranksep': '2.0'
26
+ }
27
+ )
28
+
29
+ base_color = '#19191a'
30
+ lightening_factor = 0.15
31
+
32
+ classes = data.get('classes', [])
33
+ relationships = data.get('relationships', [])
34
+
35
+ for i, cls in enumerate(classes):
36
+ class_name = cls.get('name')
37
+ class_type = cls.get('type', 'class')
38
+ attributes = cls.get('attributes', [])
39
+ methods = cls.get('methods', [])
40
+
41
+ if not class_name:
42
+ raise ValueError(f"Invalid class: {cls}")
43
+
44
+ current_depth = i % 6
45
+
46
+ if not isinstance(base_color, str) or not base_color.startswith('#') or len(base_color) != 7:
47
+ base_color_safe = '#19191a'
48
+ else:
49
+ base_color_safe = base_color
50
+
51
+ base_r = int(base_color_safe[1:3], 16)
52
+ base_g = int(base_color_safe[3:5], 16)
53
+ base_b = int(base_color_safe[5:7], 16)
54
+
55
+ current_r = base_r + int((255 - base_r) * current_depth * lightening_factor)
56
+ current_g = base_g + int((255 - base_g) * current_depth * lightening_factor)
57
+ current_b = base_b + int((255 - base_b) * current_depth * lightening_factor)
58
+
59
+ current_r = min(255, current_r)
60
+ current_g = min(255, current_g)
61
+ current_b = min(255, current_b)
62
+
63
+ node_color = f'#{current_r:02x}{current_g:02x}{current_b:02x}'
64
+ font_color = 'white' if current_depth * lightening_factor < 0.6 else 'black'
65
+
66
+ class_label = ""
67
+
68
+ if class_type == 'abstract':
69
+ class_label += "<<abstract>>\\n"
70
+ elif class_type == 'interface':
71
+ class_label += "<<interface>>\\n"
72
+ elif class_type == 'enum':
73
+ class_label += "<<enumeration>>\\n"
74
+
75
+ class_label += f"{class_name}\\l"
76
+
77
+ if attributes:
78
+ class_label += "\\l"
79
+ for attr in attributes:
80
+ visibility = attr.get('visibility', '+')
81
+ name = attr.get('name', '')
82
+ attr_type = attr.get('type', '')
83
+ is_static = attr.get('static', False)
84
+
85
+ attr_line = f"{visibility} "
86
+ if is_static:
87
+ attr_line += f"<<static>> "
88
+ attr_line += f"{name}"
89
+ if attr_type:
90
+ attr_line += f" : {attr_type}"
91
+ class_label += f"{attr_line}\\l"
92
+
93
+ if methods:
94
+ class_label += "\\l"
95
+ for method in methods:
96
+ visibility = method.get('visibility', '+')
97
+ name = method.get('name', '')
98
+ parameters = method.get('parameters', [])
99
+ return_type = method.get('return_type', 'void')
100
+ is_static = method.get('static', False)
101
+ is_abstract = method.get('abstract', False)
102
+
103
+ method_line = f"{visibility} "
104
+ if is_static:
105
+ method_line += f"<<static>> "
106
+ if is_abstract:
107
+ method_line += f"<<abstract>> "
108
+
109
+ method_line += f"{name}("
110
+ if parameters:
111
+ param_strs = []
112
+ for param in parameters:
113
+ param_name = param.get('name', '')
114
+ param_type = param.get('type', '')
115
+ param_strs.append(f"{param_name}: {param_type}")
116
+ method_line += ", ".join(param_strs)
117
+ method_line += f") : {return_type}"
118
+ class_label += f"{method_line}\\l"
119
+
120
+ if class_type == 'interface':
121
+ style = 'filled,dashed'
122
+ else:
123
+ style = 'filled'
124
+
125
+ dot.node(
126
+ class_name,
127
+ class_label,
128
+ shape='record',
129
+ style=style,
130
+ fillcolor=node_color,
131
+ fontcolor=font_color,
132
+ fontsize='10',
133
+ fontname='Helvetica'
134
+ )
135
+
136
+ for relationship in relationships:
137
+ from_class = relationship.get('from')
138
+ to_class = relationship.get('to')
139
+ rel_type = relationship.get('type', 'association')
140
+ label = relationship.get('label', '')
141
+ multiplicity_from = relationship.get('multiplicity_from', '')
142
+ multiplicity_to = relationship.get('multiplicity_to', '')
143
+
144
+ if not all([from_class, to_class]):
145
+ raise ValueError(f"Invalid relationship: {relationship}")
146
+
147
+ edge_label = label
148
+ if multiplicity_from or multiplicity_to:
149
+ edge_label += f"\\n{multiplicity_from} --- {multiplicity_to}"
150
+
151
+ if rel_type == 'inheritance':
152
+ dot.edge(from_class, to_class, arrowhead='empty', color='#4a4a4a', label=edge_label, fontsize='9')
153
+ elif rel_type == 'composition':
154
+ dot.edge(from_class, to_class, arrowhead='normal', arrowtail='diamond', dir='both', color='#4a4a4a', label=edge_label, fontsize='9')
155
+ elif rel_type == 'aggregation':
156
+ dot.edge(from_class, to_class, arrowhead='normal', arrowtail='odiamond', dir='both', color='#4a4a4a', label=edge_label, fontsize='9')
157
+ elif rel_type == 'realization':
158
+ dot.edge(from_class, to_class, arrowhead='empty', style='dashed', color='#4a4a4a', label=edge_label, fontsize='9')
159
+ elif rel_type == 'dependency':
160
+ dot.edge(from_class, to_class, arrowhead='normal', style='dashed', color='#4a4a4a', label=edge_label, fontsize='9')
161
+ else:
162
+ dot.edge(from_class, to_class, arrowhead='normal', color='#4a4a4a', label=edge_label, fontsize='9')
163
+
164
+ with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp:
165
+ dot.render(tmp.name, format=output_format, cleanup=True)
166
+ return f"{tmp.name}.{output_format}"
167
+
168
+ except json.JSONDecodeError:
169
+ return "Error: Invalid JSON format"
170
+ except Exception as e:
171
+ return f"Error: {str(e)}"