Graphify / entity_relationship_generator.py
ZahirJS's picture
Update entity_relationship_generator.py
c655400 verified
raw
history blame
12.8 kB
import graphviz
import json
from tempfile import NamedTemporaryFile
import os
# Palette: a light gray base that is lightened further for "deeper" diagram elements.
_BASE_COLOR = '#BEBEBE'
_LIGHTENING_FACTOR = 0.08
# Every fill in this palette is light, so labels are always drawn in black.
_FONT_COLOR = 'black'
_BORDER_COLOR = '#404040'
_EDGE_COLOR = '#4a4a4a'


def _gradient_color(depth, base_hex_color=_BASE_COLOR, lightening_factor=_LIGHTENING_FACTOR):
    """Return *base_hex_color* lightened toward white in proportion to *depth*.

    Each RGB channel moves ``depth * lightening_factor`` of its remaining
    distance to 255 (truncated to an int and capped at 255). A base color that
    is not a 7-character ``'#rrggbb'`` string falls back to ``'#BEBEBE'``.
    """
    if not (isinstance(base_hex_color, str)
            and base_hex_color.startswith('#')
            and len(base_hex_color) == 7):
        base_hex_color = '#BEBEBE'
    channels = [int(base_hex_color[i:i + 2], 16) for i in (1, 3, 5)]
    lightened = [min(255, c + int((255 - c) * depth * lightening_factor)) for c in channels]
    return '#' + ''.join(f'{c:02x}' for c in lightened)


# Fill colors per element kind; a larger depth gives a lighter fill.
_ENTITY_FILL = _gradient_color(1)
_ISA_FILL = _gradient_color(1.2)
_RELATIONSHIP_FILL = _gradient_color(1.5)
_ATTRIBUTE_FILL = _gradient_color(2)
_RELATIONSHIP_ATTRIBUTE_FILL = _gradient_color(2.5)

# Entity-attribute rendering by "type": (label template, style, extra node attributes).
# Composite attributes are drawn like regular ones; their sub-attributes are not rendered.
_ATTRIBUTE_STYLES = {
    'primary_key': ('{} (PK)', 'filled,rounded', {}),
    'partial_key': ('{} (Partial)', 'filled,rounded,dashed', {}),
    'multivalued': ('{}', 'filled,rounded', {'penwidth': '3'}),
    'derived': ('/{}/', 'filled,rounded,dashed', {}),
    'composite': ('{}', 'filled,rounded', {}),
    'regular': ('{}', 'filled,rounded', {}),
}


def _add_entity_attribute(dot, attr_id, attr):
    """Add one entity-attribute ellipse styled according to its ``type``."""
    attr_type = attr.get('type', 'regular')
    # Unknown (or non-string) types render as regular attributes.
    key = attr_type if isinstance(attr_type, str) else 'regular'
    label_template, style, extra = _ATTRIBUTE_STYLES.get(key, _ATTRIBUTE_STYLES['regular'])
    dot.node(
        attr_id,
        label_template.format(attr.get('name', '')),
        shape='ellipse',
        style=style,
        fillcolor=_ATTRIBUTE_FILL,
        fontcolor=_FONT_COLOR,
        color=_BORDER_COLOR,
        width='1.2',
        height='0.6',
        fontsize='10',
        **extra,
    )


def _add_entity(dot, entity):
    """Add an entity box and its attribute ellipses; entities without a name are skipped."""
    entity_name = entity.get('name')
    if not entity_name:
        return
    dot.node(
        entity_name,
        entity_name,
        shape='box',
        style='filled,rounded',
        fillcolor=_ENTITY_FILL,
        fontcolor=_FONT_COLOR,
        color=_BORDER_COLOR,
        # Weak entities are distinguished by a heavier border.
        penwidth='3' if entity.get('type', 'strong') == 'weak' else '1',
        width='1.8',
        height='0.8',
        fontsize='12',
    )
    for i, attr in enumerate(entity.get('attributes', [])):
        attr_id = f'{entity_name}_attr_{i}'
        _add_entity_attribute(dot, attr_id, attr)
        dot.edge(entity_name, attr_id, color=_EDGE_COLOR, len='1.5')


def _add_isa(dot, index, relationship):
    """Add an ISA triangle linking ``parent`` to each of ``children`` (skipped if either is missing)."""
    parent = relationship.get('parent')
    children = relationship.get('children', [])
    if not (parent and children):
        return
    # Index-based id: two ISA relationships with the same name must not merge.
    isa_id = f'__isa_{index}'
    dot.node(
        isa_id,
        'ISA',
        shape='triangle',
        style='filled,rounded',
        fillcolor=_ISA_FILL,
        fontcolor=_FONT_COLOR,
        color=_BORDER_COLOR,
        penwidth='2',
        width='1.0',
        height='0.8',
        fontsize='10',
    )
    dot.edge(parent, isa_id, color=_EDGE_COLOR, len='2.0')
    for child in children:
        dot.edge(isa_id, child, color=_EDGE_COLOR, len='2.0')


def _add_relationship(dot, index, relationship):
    """Add one relationship (diamond + attributes + cardinality edges, or an ISA triangle)."""
    rel_name = relationship.get('name')
    if not rel_name:
        return
    rel_type = relationship.get('type', 'regular')
    if rel_type == 'isa':
        _add_isa(dot, index, relationship)
        return
    entities_involved = relationship.get('entities', [])
    if len(entities_involved) < 2:
        return

    # The node id is index-based and the name is only the label, so a relationship
    # named like an entity (or like another relationship, e.g. two "has") gets its
    # own diamond instead of silently merging with the existing node.
    rel_id = f'__rel_{index}'
    dot.node(
        rel_id,
        rel_name,
        shape='diamond',
        style='filled,rounded',
        fillcolor=_RELATIONSHIP_FILL,
        fontcolor=_FONT_COLOR,
        color=_BORDER_COLOR,
        # Identifying relationships are distinguished by a heavier border.
        penwidth='3' if rel_type == 'identifying' else '1',
        width='1.8',
        height='1.0',
        fontsize='11',
    )
    for j, attr in enumerate(relationship.get('attributes', [])):
        attr_id = f'{rel_id}_attr_{j}'
        dot.node(
            attr_id,
            attr.get('name', ''),
            shape='ellipse',
            style='filled,rounded',
            fillcolor=_RELATIONSHIP_ATTRIBUTE_FILL,
            fontcolor=_FONT_COLOR,
            color=_BORDER_COLOR,
            width='1.0',
            height='0.5',
            fontsize='9',
        )
        dot.edge(rel_id, attr_id, color=_EDGE_COLOR, len='1.0')

    cardinalities = relationship.get('cardinalities', {})
    for entity in entities_involved:
        dot.edge(
            entity,
            rel_id,
            label=f' {cardinalities.get(entity, "1")} ',
            color=_EDGE_COLOR,
            len='2.5',
            fontcolor=_EDGE_COLOR,
            fontsize='10',
        )


def _render_to_tempfile(dot, output_format):
    """Render *dot* as *output_format* into the temp directory and return the rendered path.

    A uniquely named temp file reserves the base path. ``render`` overwrites it
    with the DOT source, writes the image to ``<base>.<output_format>`` and then
    deletes the source (``cleanup=True``).
    """
    with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp:
        base_path = tmp.name
    try:
        return dot.render(base_path, format=output_format, cleanup=True)
    except Exception:
        # cleanup=True only removes the source after a successful render. On failure
        # (unknown format, missing Graphviz binary, ...) remove it here so failed
        # requests do not leak files into the temp directory.
        try:
            os.remove(base_path)
        except OSError:
            pass
        raise


def generate_entity_relationship_diagram(json_input: str, output_format: str) -> str:
    """Render an entity-relationship diagram from JSON and return the output file path.

    ``json_input`` must be a JSON object. Keys read by this function::

        {
          "entities": [                                      # required
            {"name": ..., "type": "strong" | "weak",
             "attributes": [{"name": ..., "type": "regular" | "primary_key" |
                             "partial_key" | "multivalued" | "derived" | "composite"}]}
          ],
          "relationships": [                                 # optional
            {"name": ..., "type": "regular" | "identifying" | "isa",
             "entities": [...], "cardinalities": {entity_name: label},
             "attributes": [{"name": ...}],
             "parent": ..., "children": [...]}               # used by "isa" only
          ]
        }

    Entities and relationships without a ``name`` are skipped. Non-ISA relationships
    that connect fewer than two entities are skipped, and so are ISA relationships
    that lack a parent or children.

    Args:
        json_input: JSON text describing the diagram.
        output_format: Graphviz output format (e.g. ``'png'``, ``'svg'``).

    Returns:
        Path of the rendered file in the system temp directory, or a message
        starting with ``"Error: "``. Ordinary failures are reported through the
        return value rather than raised.
    """
    try:
        if json_input is None or not json_input.strip():
            return "Error: Empty input"
        data = json.loads(json_input)
        if not isinstance(data, dict):
            raise ValueError("Top-level JSON value must be an object")
        if 'entities' not in data:
            raise ValueError("Missing required field: entities")
        entities = data['entities']
        # A missing or null "relationships" field simply means there are none.
        relationships = data.get('relationships') or []
        if not isinstance(entities, list):
            raise ValueError("'entities' must be a list")
        if not isinstance(relationships, list):
            raise ValueError("'relationships' must be a list")

        dot = graphviz.Graph(comment='ER Diagram', engine='neato')
        dot.attr(
            bgcolor='white',
            pad='1.5',
            overlap='false',
            splines='true',
            sep='+25',
            esep='+15',
        )
        dot.attr('node', fontname='Arial', fontsize='10', color=_BORDER_COLOR)
        dot.attr('edge', fontname='Arial', fontsize='9', color=_EDGE_COLOR)

        # Entities first, so relationship edges attach to already-styled entity nodes.
        for entity in entities:
            _add_entity(dot, entity)
        for index, relationship in enumerate(relationships):
            _add_relationship(dot, index, relationship)

        return _render_to_tempfile(dot, output_format)
    except json.JSONDecodeError:
        return "Error: Invalid JSON format"
    except Exception as e:
        return f"Error: {str(e)}"