ZahirJS commited on
Commit
0e101d5
·
verified ·
1 Parent(s): c9dc53b

Update entity_relationship_generator.py

Browse files
Files changed (1) hide show
  1. entity_relationship_generator.py +50 -68
entity_relationship_generator.py CHANGED
@@ -213,18 +213,30 @@ def generate_entity_relationship_diagram(json_input: str, output_format: str) ->
213
  'splines': 'ortho',
214
  'bgcolor': 'white',
215
  'pad': '0.5',
216
- 'nodesep': '2.0',
217
- 'ranksep': '2.5'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  }
219
  )
220
 
221
- base_color = '#19191a'
222
- lightening_factor = 0.15
223
-
224
  entities = data.get('entities', [])
225
  relationships = data.get('relationships', [])
226
 
227
- for i, entity in enumerate(entities):
228
  entity_name = entity.get('name')
229
  entity_type = entity.get('type', 'strong')
230
  attributes = entity.get('attributes', [])
@@ -232,49 +244,28 @@ def generate_entity_relationship_diagram(json_input: str, output_format: str) ->
232
  if not entity_name:
233
  raise ValueError(f"Invalid entity: {entity}")
234
 
235
- current_depth = i % 6
236
-
237
- if not isinstance(base_color, str) or not base_color.startswith('#') or len(base_color) != 7:
238
- base_color_safe = '#19191a'
239
- else:
240
- base_color_safe = base_color
241
-
242
- base_r = int(base_color_safe[1:3], 16)
243
- base_g = int(base_color_safe[3:5], 16)
244
- base_b = int(base_color_safe[5:7], 16)
245
-
246
- current_r = base_r + int((255 - base_r) * current_depth * lightening_factor)
247
- current_g = base_g + int((255 - base_g) * current_depth * lightening_factor)
248
- current_b = base_b + int((255 - base_b) * current_depth * lightening_factor)
249
-
250
- current_r = min(255, current_r)
251
- current_g = min(255, current_g)
252
- current_b = min(255, current_b)
253
-
254
- node_color = f'#{current_r:02x}{current_g:02x}{current_b:02x}'
255
- font_color = 'white' if current_depth * lightening_factor < 0.6 else 'black'
256
-
257
- entity_label = f"{entity_name}\\n"
258
 
259
  if attributes:
 
260
  primary_keys = []
261
  foreign_keys = []
262
  regular_attrs = []
263
 
264
  for attr in attributes:
265
  attr_name = attr.get('name', '')
266
- attr_type = attr.get('type', 'key')
267
  is_multivalued = attr.get('multivalued', False)
268
  is_derived = attr.get('derived', False)
269
  is_composite = attr.get('composite', False)
270
 
271
  if attr_type == 'primary_key':
272
  if is_multivalued:
273
- primary_keys.append(f"{{{{ {attr_name} }}}}")
274
  else:
275
- primary_keys.append(f"[PK] {attr_name}")
276
  elif attr_type == 'foreign_key':
277
- foreign_keys.append(f"[FK] {attr_name}")
278
  else:
279
  attr_display = attr_name
280
  if is_derived:
@@ -285,29 +276,27 @@ def generate_entity_relationship_diagram(json_input: str, output_format: str) ->
285
  attr_display = f"( {attr_display} )"
286
  regular_attrs.append(attr_display)
287
 
288
- if primary_keys:
289
- entity_label += "\\n".join(primary_keys) + "\\n"
290
- if foreign_keys:
291
- entity_label += "\\n".join(foreign_keys) + "\\n"
292
- if regular_attrs:
293
- entity_label += "\\n".join(regular_attrs)
294
 
295
  if entity_type == 'weak':
296
- shape = 'doubleoctagon'
297
  style = 'filled'
 
 
298
  else:
299
- shape = 'box'
300
- style = 'filled,rounded'
 
 
301
 
302
  dot.node(
303
  entity_name,
304
- entity_label,
305
  shape=shape,
306
  style=style,
307
- fillcolor=node_color,
308
- fontcolor=font_color,
309
- fontsize='10',
310
- fontname='Helvetica'
311
  )
312
 
313
  for relationship in relationships:
@@ -319,48 +308,41 @@ def generate_entity_relationship_diagram(json_input: str, output_format: str) ->
319
  if not rel_name or len(entities_involved) < 2:
320
  raise ValueError(f"Invalid relationship: {relationship}")
321
 
322
- rel_node_color = '#e6f3ff'
323
-
324
  if rel_type == 'identifying':
325
  rel_shape = 'diamond'
326
- rel_style = 'filled,bold'
327
- rel_color = '#4a90e2'
 
328
  elif rel_type == 'weak':
329
  rel_shape = 'diamond'
330
- rel_style = 'filled,dashed'
331
- rel_color = '#a0a0a0'
 
332
  else:
333
  rel_shape = 'diamond'
334
  rel_style = 'filled'
335
- rel_color = '#4a90e2'
 
336
 
337
  dot.node(
338
  rel_name,
339
- rel_name,
340
  shape=rel_shape,
341
  style=rel_style,
342
  fillcolor=rel_color,
343
- fontcolor='white',
344
- fontsize='10',
345
- fontname='Helvetica'
346
  )
347
 
348
  for entity in entities_involved:
349
  cardinality = cardinalities.get(entity, '1')
350
 
351
- edge_label = cardinality
352
- if cardinality in ['1:1', '1:N', 'M:N', '1', 'N', 'M']:
353
- pass
354
- else:
355
- edge_label = cardinality
356
-
357
  dot.edge(
358
  entity,
359
  rel_name,
360
- label=edge_label,
361
- color='#4a4a4a',
362
- fontsize='9',
363
- fontcolor='#4a4a4a'
364
  )
365
 
366
  with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp:
 
213
  'splines': 'ortho',
214
  'bgcolor': 'white',
215
  'pad': '0.5',
216
+ 'nodesep': '1.2',
217
+ 'ranksep': '1.8',
218
+ 'fontname': 'Arial',
219
+ 'dpi': '300',
220
+ 'overlap': 'false'
221
+ },
222
+ node_attr={
223
+ 'fontname': 'Arial',
224
+ 'fontsize': '10',
225
+ 'color': 'black',
226
+ 'penwidth': '1.5'
227
+ },
228
+ edge_attr={
229
+ 'fontname': 'Arial',
230
+ 'fontsize': '9',
231
+ 'color': 'black',
232
+ 'penwidth': '1'
233
  }
234
  )
235
 
 
 
 
236
  entities = data.get('entities', [])
237
  relationships = data.get('relationships', [])
238
 
239
+ for entity in entities:
240
  entity_name = entity.get('name')
241
  entity_type = entity.get('type', 'strong')
242
  attributes = entity.get('attributes', [])
 
244
  if not entity_name:
245
  raise ValueError(f"Invalid entity: {entity}")
246
 
247
+ entity_label = f"<B>{entity_name}</B>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
  if attributes:
250
+ entity_label += "|"
251
  primary_keys = []
252
  foreign_keys = []
253
  regular_attrs = []
254
 
255
  for attr in attributes:
256
  attr_name = attr.get('name', '')
257
+ attr_type = attr.get('type', 'regular')
258
  is_multivalued = attr.get('multivalued', False)
259
  is_derived = attr.get('derived', False)
260
  is_composite = attr.get('composite', False)
261
 
262
  if attr_type == 'primary_key':
263
  if is_multivalued:
264
+ primary_keys.append(f"<U>{{{{ {attr_name} }}}}</U>")
265
  else:
266
+ primary_keys.append(f"<U>{attr_name}</U>")
267
  elif attr_type == 'foreign_key':
268
+ foreign_keys.append(f"<I>{attr_name}</I> (FK)")
269
  else:
270
  attr_display = attr_name
271
  if is_derived:
 
276
  attr_display = f"( {attr_display} )"
277
  regular_attrs.append(attr_display)
278
 
279
+ all_attrs = primary_keys + foreign_keys + regular_attrs
280
+ entity_label += "\\l".join(all_attrs) + "\\l"
 
 
 
 
281
 
282
  if entity_type == 'weak':
283
+ shape = 'record'
284
  style = 'filled'
285
+ fillcolor = '#f0f0f0'
286
+ penwidth = '3'
287
  else:
288
+ shape = 'record'
289
+ style = 'filled'
290
+ fillcolor = 'white'
291
+ penwidth = '1.5'
292
 
293
  dot.node(
294
  entity_name,
295
+ f"<{entity_label}>",
296
  shape=shape,
297
  style=style,
298
+ fillcolor=fillcolor,
299
+ penwidth=penwidth
 
 
300
  )
301
 
302
  for relationship in relationships:
 
308
  if not rel_name or len(entities_involved) < 2:
309
  raise ValueError(f"Invalid relationship: {relationship}")
310
 
 
 
311
  if rel_type == 'identifying':
312
  rel_shape = 'diamond'
313
+ rel_style = 'filled'
314
+ rel_color = '#d0d0d0'
315
+ rel_penwidth = '3'
316
  elif rel_type == 'weak':
317
  rel_shape = 'diamond'
318
+ rel_style = 'filled'
319
+ rel_color = '#e8e8e8'
320
+ rel_penwidth = '2'
321
  else:
322
  rel_shape = 'diamond'
323
  rel_style = 'filled'
324
+ rel_color = '#d0d0d0'
325
+ rel_penwidth = '1.5'
326
 
327
  dot.node(
328
  rel_name,
329
+ f"<B>{rel_name}</B>",
330
  shape=rel_shape,
331
  style=rel_style,
332
  fillcolor=rel_color,
333
+ fontcolor='black',
334
+ penwidth=rel_penwidth
 
335
  )
336
 
337
  for entity in entities_involved:
338
  cardinality = cardinalities.get(entity, '1')
339
 
 
 
 
 
 
 
340
  dot.edge(
341
  entity,
342
  rel_name,
343
+ label=cardinality,
344
+ labelfontsize='10',
345
+ labelfontcolor='black'
 
346
  )
347
 
348
  with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp: