ZahirJS commited on
Commit
7474f3d
·
verified ·
1 Parent(s): 06e522b

Update entity_relationship_generator.py

Browse files
Files changed (1) hide show
  1. entity_relationship_generator.py +21 -64
entity_relationship_generator.py CHANGED
@@ -205,18 +205,10 @@ def generate_entity_relationship_diagram(json_input: str, output_format: str) ->
205
  if 'entities' not in data:
206
  raise ValueError("Missing required field: entities")
207
 
208
- dot = graphviz.Graph(
209
- name='ERDiagram',
210
- format='png',
211
- graph_attr={
212
- 'rankdir': 'TB',
213
- 'splines': 'ortho',
214
- 'bgcolor': 'white',
215
- 'pad': '0.5',
216
- 'nodesep': '1.2',
217
- 'ranksep': '1.8'
218
- }
219
- )
220
 
221
  entities = data.get('entities', [])
222
  relationships = data.get('relationships', [])
@@ -227,12 +219,12 @@ def generate_entity_relationship_diagram(json_input: str, output_format: str) ->
227
  attributes = entity.get('attributes', [])
228
 
229
  if not entity_name:
230
- raise ValueError(f"Invalid entity: {entity}")
231
 
232
- entity_parts = [entity_name]
233
 
234
  if attributes:
235
- attr_section = ""
236
  for attr in attributes:
237
  attr_name = attr.get('name', '')
238
  attr_type = attr.get('type', 'regular')
@@ -256,36 +248,23 @@ def generate_entity_relationship_diagram(json_input: str, output_format: str) ->
256
  if is_composite:
257
  attr_display = f"( {attr_display} )"
258
 
259
- if attr_section:
260
- attr_section += "\\l"
261
- attr_section += attr_display
262
 
263
- if attr_section:
264
- attr_section += "\\l"
265
- entity_parts.append(attr_section)
266
 
267
- entity_label = "|".join(entity_parts)
268
 
269
  if entity_type == 'weak':
270
- shape = 'record'
271
  style = 'filled'
272
  fillcolor = '#e8e8e8'
273
  penwidth = '2'
274
  else:
275
- shape = 'record'
276
  style = 'filled'
277
  fillcolor = '#d0d0d0'
278
  penwidth = '1'
279
 
280
- dot.node(
281
- entity_name,
282
- entity_label,
283
- shape=shape,
284
- style=style,
285
- fillcolor=fillcolor,
286
- color='black',
287
- penwidth=penwidth
288
- )
289
 
290
  for relationship in relationships:
291
  rel_name = relationship.get('name')
@@ -294,44 +273,22 @@ def generate_entity_relationship_diagram(json_input: str, output_format: str) ->
294
  cardinalities = relationship.get('cardinalities', {})
295
 
296
  if not rel_name or len(entities_involved) < 2:
297
- raise ValueError(f"Invalid relationship: {relationship}")
298
 
299
  if rel_type == 'identifying':
300
- rel_shape = 'diamond'
301
- rel_style = 'filled'
302
- rel_color = '#c0c0c0'
303
- rel_penwidth = '2'
304
- elif rel_type == 'weak':
305
- rel_shape = 'diamond'
306
- rel_style = 'filled'
307
- rel_color = '#e0e0e0'
308
- rel_penwidth = '1'
309
  else:
310
- rel_shape = 'diamond'
311
- rel_style = 'filled'
312
- rel_color = '#c0c0c0'
313
- rel_penwidth = '1'
314
 
315
- dot.node(
316
- rel_name,
317
- rel_name,
318
- shape=rel_shape,
319
- style=rel_style,
320
- fillcolor=rel_color,
321
- color='black',
322
- fontcolor='black',
323
- penwidth=rel_penwidth
324
- )
325
 
326
  for entity in entities_involved:
327
  cardinality = cardinalities.get(entity, '1')
328
-
329
- dot.edge(
330
- entity,
331
- rel_name,
332
- label=cardinality,
333
- color='black'
334
- )
335
 
336
  with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp:
337
  dot.render(tmp.name, format=output_format, cleanup=True)
 
205
  if 'entities' not in data:
206
  raise ValueError("Missing required field: entities")
207
 
208
+ dot = graphviz.Graph(comment='ER Diagram', engine='dot')
209
+ dot.attr(rankdir='TB', bgcolor='white', pad='0.5')
210
+ dot.attr('node', fontname='Arial', fontsize='10', color='black')
211
+ dot.attr('edge', fontname='Arial', fontsize='9', color='black')
 
 
 
 
 
 
 
 
212
 
213
  entities = data.get('entities', [])
214
  relationships = data.get('relationships', [])
 
219
  attributes = entity.get('attributes', [])
220
 
221
  if not entity_name:
222
+ continue
223
 
224
+ label_parts = [entity_name]
225
 
226
  if attributes:
227
+ attr_lines = []
228
  for attr in attributes:
229
  attr_name = attr.get('name', '')
230
  attr_type = attr.get('type', 'regular')
 
248
  if is_composite:
249
  attr_display = f"( {attr_display} )"
250
 
251
+ attr_lines.append(attr_display)
 
 
252
 
253
+ if attr_lines:
254
+ label_parts.extend(attr_lines)
 
255
 
256
+ label = "\\n".join(label_parts)
257
 
258
  if entity_type == 'weak':
 
259
  style = 'filled'
260
  fillcolor = '#e8e8e8'
261
  penwidth = '2'
262
  else:
 
263
  style = 'filled'
264
  fillcolor = '#d0d0d0'
265
  penwidth = '1'
266
 
267
+ dot.node(entity_name, label, shape='box', style=style, fillcolor=fillcolor, penwidth=penwidth)
 
 
 
 
 
 
 
 
268
 
269
  for relationship in relationships:
270
  rel_name = relationship.get('name')
 
273
  cardinalities = relationship.get('cardinalities', {})
274
 
275
  if not rel_name or len(entities_involved) < 2:
276
+ continue
277
 
278
  if rel_type == 'identifying':
279
+ style = 'filled'
280
+ fillcolor = '#c0c0c0'
281
+ penwidth = '2'
 
 
 
 
 
 
282
  else:
283
+ style = 'filled'
284
+ fillcolor = '#c0c0c0'
285
+ penwidth = '1'
 
286
 
287
+ dot.node(rel_name, rel_name, shape='diamond', style=style, fillcolor=fillcolor, penwidth=penwidth)
 
 
 
 
 
 
 
 
 
288
 
289
  for entity in entities_involved:
290
  cardinality = cardinalities.get(entity, '1')
291
+ dot.edge(entity, rel_name, label=cardinality)
 
 
 
 
 
 
292
 
293
  with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp:
294
  dot.render(tmp.name, format=output_format, cleanup=True)