ZahirJS commited on
Commit
06e522b
·
verified ·
1 Parent(s): 2eb272a

Update class_diagram_generator.py

Browse files
Files changed (1) hide show
  1. class_diagram_generator.py +37 -59
class_diagram_generator.py CHANGED
@@ -199,18 +199,10 @@ def generate_class_diagram(json_input: str, output_format: str) -> str:
199
  if 'classes' not in data:
200
  raise ValueError("Missing required field: classes")
201
 
202
- dot = graphviz.Digraph(
203
- name='ClassDiagram',
204
- format='png',
205
- graph_attr={
206
- 'rankdir': 'TB',
207
- 'splines': 'ortho',
208
- 'bgcolor': 'white',
209
- 'pad': '0.5',
210
- 'nodesep': '1.0',
211
- 'ranksep': '1.5'
212
- }
213
- )
214
 
215
  classes = data.get('classes', [])
216
  relationships = data.get('relationships', [])
@@ -222,44 +214,40 @@ def generate_class_diagram(json_input: str, output_format: str) -> str:
222
  methods = cls.get('methods', [])
223
 
224
  if not class_name:
225
- raise ValueError(f"Invalid class: {cls}")
226
 
227
- class_parts = []
228
 
229
- header = class_name
230
  if class_type == 'abstract':
231
- header = f"<<abstract>>\\n{class_name}"
232
  elif class_type == 'interface':
233
- header = f"<<interface>>\\n{class_name}"
234
  elif class_type == 'enum':
235
- header = f"<<enumeration>>\\n{class_name}"
236
-
237
- class_parts.append(header)
238
 
239
  if attributes:
240
- attr_section = ""
241
  for attr in attributes:
242
  visibility = attr.get('visibility', '+')
243
  name = attr.get('name', '')
244
  attr_type = attr.get('type', '')
245
  is_static = attr.get('static', False)
246
 
247
- attr_line = f"{visibility} {name}"
248
  if attr_type:
249
- attr_line += f" : {attr_type}"
250
  if is_static:
251
- attr_line += " [static]"
252
 
253
- if attr_section:
254
- attr_section += "\\l"
255
- attr_section += attr_line
256
 
257
- if attr_section:
258
- attr_section += "\\l"
259
- class_parts.append(attr_section)
260
 
261
  if methods:
262
- method_section = ""
263
  for method in methods:
264
  visibility = method.get('visibility', '+')
265
  name = method.get('name', '')
@@ -268,49 +256,39 @@ def generate_class_diagram(json_input: str, output_format: str) -> str:
268
  is_static = method.get('static', False)
269
  is_abstract = method.get('abstract', False)
270
 
271
- method_line = f"{visibility} {name}("
272
  if parameters:
273
- param_strs = []
274
  for param in parameters:
275
  param_name = param.get('name', '')
276
  param_type = param.get('type', '')
277
- param_strs.append(f"{param_name}: {param_type}")
278
- method_line += ", ".join(param_strs)
279
- method_line += f") : {return_type}"
280
 
281
  if is_static:
282
- method_line += " [static]"
283
  if is_abstract:
284
- method_line += " [abstract]"
285
 
286
- if method_section:
287
- method_section += "\\l"
288
- method_section += method_line
289
 
290
- if method_section:
291
- method_section += "\\l"
292
- class_parts.append(method_section)
293
 
294
- class_label = "|".join(class_parts)
295
 
296
  if class_type == 'interface':
297
- style = 'filled,dashed'
298
  fillcolor = '#f5f5f5'
 
299
  elif class_type == 'abstract':
300
- style = 'filled'
301
  fillcolor = '#eeeeee'
302
- else:
303
  style = 'filled'
 
304
  fillcolor = 'white'
 
305
 
306
- dot.node(
307
- class_name,
308
- class_label,
309
- shape='record',
310
- style=style,
311
- fillcolor=fillcolor,
312
- color='black'
313
- )
314
 
315
  for relationship in relationships:
316
  from_class = relationship.get('from')
@@ -320,10 +298,10 @@ def generate_class_diagram(json_input: str, output_format: str) -> str:
320
  multiplicity_from = relationship.get('multiplicity_from', '')
321
  multiplicity_to = relationship.get('multiplicity_to', '')
322
 
323
- if not all([from_class, to_class]):
324
- raise ValueError(f"Invalid relationship: {relationship}")
325
 
326
- edge_attrs = {'color': 'black'}
327
 
328
  if label:
329
  edge_attrs['label'] = label
 
199
  if 'classes' not in data:
200
  raise ValueError("Missing required field: classes")
201
 
202
+ dot = graphviz.Digraph(comment='Class Diagram')
203
+ dot.attr(rankdir='TB', bgcolor='white', pad='0.5')
204
+ dot.attr('node', shape='box', style='filled', fillcolor='white', color='black', fontname='Arial', fontsize='10')
205
+ dot.attr('edge', color='black', fontname='Arial', fontsize='9')
 
 
 
 
 
 
 
 
206
 
207
  classes = data.get('classes', [])
208
  relationships = data.get('relationships', [])
 
214
  methods = cls.get('methods', [])
215
 
216
  if not class_name:
217
+ continue
218
 
219
+ label_parts = []
220
 
 
221
  if class_type == 'abstract':
222
+ label_parts.append(f"&lt;&lt;abstract&gt;&gt;\\n{class_name}")
223
  elif class_type == 'interface':
224
+ label_parts.append(f"&lt;&lt;interface&gt;&gt;\\n{class_name}")
225
  elif class_type == 'enum':
226
+ label_parts.append(f"&lt;&lt;enumeration&gt;&gt;\\n{class_name}")
227
+ else:
228
+ label_parts.append(class_name)
229
 
230
  if attributes:
231
+ attr_lines = []
232
  for attr in attributes:
233
  visibility = attr.get('visibility', '+')
234
  name = attr.get('name', '')
235
  attr_type = attr.get('type', '')
236
  is_static = attr.get('static', False)
237
 
238
+ line = f"{visibility} {name}"
239
  if attr_type:
240
+ line += f" : {attr_type}"
241
  if is_static:
242
+ line += " [static]"
243
 
244
+ attr_lines.append(line)
 
 
245
 
246
+ if attr_lines:
247
+ label_parts.append("\\n".join(attr_lines))
 
248
 
249
  if methods:
250
+ method_lines = []
251
  for method in methods:
252
  visibility = method.get('visibility', '+')
253
  name = method.get('name', '')
 
256
  is_static = method.get('static', False)
257
  is_abstract = method.get('abstract', False)
258
 
259
+ line = f"{visibility} {name}("
260
  if parameters:
261
+ params = []
262
  for param in parameters:
263
  param_name = param.get('name', '')
264
  param_type = param.get('type', '')
265
+ params.append(f"{param_name}: {param_type}")
266
+ line += ", ".join(params)
267
+ line += f") : {return_type}"
268
 
269
  if is_static:
270
+ line += " [static]"
271
  if is_abstract:
272
+ line += " [abstract]"
273
 
274
+ method_lines.append(line)
 
 
275
 
276
+ if method_lines:
277
+ label_parts.append("\\n".join(method_lines))
 
278
 
279
+ label = "\\n\\n".join(label_parts)
280
 
281
  if class_type == 'interface':
 
282
  fillcolor = '#f5f5f5'
283
+ style = 'filled,dashed'
284
  elif class_type == 'abstract':
 
285
  fillcolor = '#eeeeee'
 
286
  style = 'filled'
287
+ else:
288
  fillcolor = 'white'
289
+ style = 'filled'
290
 
291
+ dot.node(class_name, label, fillcolor=fillcolor, style=style)
 
 
 
 
 
 
 
292
 
293
  for relationship in relationships:
294
  from_class = relationship.get('from')
 
298
  multiplicity_from = relationship.get('multiplicity_from', '')
299
  multiplicity_to = relationship.get('multiplicity_to', '')
300
 
301
+ if not from_class or not to_class:
302
+ continue
303
 
304
+ edge_attrs = {}
305
 
306
  if label:
307
  edge_attrs['label'] = label