Graphify / entity_relationship_generator.py
ZahirJS's picture
Update entity_relationship_generator.py
e6eff2e verified
import graphviz
import json
from tempfile import NamedTemporaryFile
import os
# Shared palette for every ER diagram node/edge.
_ER_ENTITY_COLOR = '#BEBEBE'
_ER_ATTRIBUTE_COLOR = '#B8D4F1'
_ER_RELATIONSHIP_COLOR = '#FFF9C4'
_ER_ISA_COLOR = '#A8E6CF'
_ER_FONT_COLOR = 'black'
_ER_BORDER_COLOR = '#404040'
_ER_EDGE_COLOR = '#4a4a4a'

# Attribute "type" -> (label template, node style, extra node attributes).
# Types not listed here (e.g. "regular", "composite") use the default entry.
_ER_ATTRIBUTE_STYLES = {
    'primary_key': ('{} (PK)', 'filled,rounded', {}),
    'partial_key': ('{} (Partial)', 'filled,rounded,dashed', {}),
    'multivalued': ('{}', 'filled,rounded', {'penwidth': '3'}),
    'derived': ('/{}/', 'filled,rounded,dashed', {}),
}
_ER_DEFAULT_ATTRIBUTE_STYLE = ('{}', 'filled,rounded', {})


def _er_add_entity(dot, entity):
    """Add one entity box plus its attribute ellipses to *dot*; skip unnamed entities."""
    entity_name = entity.get('name')
    if not entity_name:
        return
    # Weak entities are drawn with a heavier border than strong ones.
    is_weak = entity.get('type', 'strong') == 'weak'
    dot.node(
        entity_name,
        entity_name,
        shape='box',
        style='filled,rounded',
        fillcolor=_ER_ENTITY_COLOR,
        fontcolor=_ER_FONT_COLOR,
        color=_ER_BORDER_COLOR,
        penwidth='3' if is_weak else '1',
        width='1.8',
        height='0.8',
        fontsize='12',
    )
    for index, attr in enumerate(entity.get('attributes', [])):
        template, style, extra = _ER_ATTRIBUTE_STYLES.get(
            attr.get('type', 'regular'), _ER_DEFAULT_ATTRIBUTE_STYLE
        )
        # IDs are namespaced by owner so equal attribute names on different
        # entities remain distinct nodes.
        attr_id = f"{entity_name}_attr_{index}"
        dot.node(
            attr_id,
            template.format(attr.get('name', '')),
            shape='ellipse',
            style=style,
            fillcolor=_ER_ATTRIBUTE_COLOR,
            fontcolor=_ER_FONT_COLOR,
            color=_ER_BORDER_COLOR,
            width='1.2',
            height='0.6',
            fontsize='10',
            **extra,
        )
        dot.edge(entity_name, attr_id, color=_ER_EDGE_COLOR, len='1.5')


def _er_add_isa(dot, rel_name, relationship):
    """Add an ISA triangle linking a parent entity to its children (needs both)."""
    parent = relationship.get('parent')
    children = relationship.get('children', [])
    if not (parent and children):
        return
    isa_id = f"isa_{rel_name}"
    dot.node(
        isa_id,
        'ISA',
        shape='triangle',
        style='filled,rounded',
        fillcolor=_ER_ISA_COLOR,
        fontcolor=_ER_FONT_COLOR,
        color=_ER_BORDER_COLOR,
        penwidth='2',
        width='1.0',
        height='0.8',
        fontsize='10',
    )
    dot.edge(parent, isa_id, color=_ER_EDGE_COLOR, len='2.0')
    for child in children:
        dot.edge(isa_id, child, color=_ER_EDGE_COLOR, len='2.0')


def _er_add_relationship(dot, relationship):
    """Add one relationship (ISA, regular or identifying) to *dot*.

    Unnamed relationships are skipped, as are non-ISA relationships that
    involve fewer than two entities.
    """
    rel_name = relationship.get('name')
    if not rel_name:
        return
    rel_type = relationship.get('type', 'regular')
    if rel_type == 'isa':
        _er_add_isa(dot, rel_name, relationship)
        return
    entities_involved = relationship.get('entities', [])
    if len(entities_involved) < 2:
        return
    cardinalities = relationship.get('cardinalities', {})
    # Identifying relationships get a heavier border than regular ones.
    dot.node(
        rel_name,
        rel_name,
        shape='diamond',
        style='filled,rounded',
        fillcolor=_ER_RELATIONSHIP_COLOR,
        fontcolor=_ER_FONT_COLOR,
        color=_ER_BORDER_COLOR,
        penwidth='3' if rel_type == 'identifying' else '1',
        width='1.8',
        height='1.0',
        fontsize='11',
    )
    for index, attr in enumerate(relationship.get('attributes', [])):
        attr_id = f"{rel_name}_attr_{index}"
        dot.node(
            attr_id,
            attr.get('name', ''),
            shape='ellipse',
            style='filled,rounded',
            fillcolor=_ER_ATTRIBUTE_COLOR,
            fontcolor=_ER_FONT_COLOR,
            color=_ER_BORDER_COLOR,
            width='1.0',
            height='0.5',
            fontsize='9',
        )
        dot.edge(rel_name, attr_id, color=_ER_EDGE_COLOR, len='1.0')
    for entity in entities_involved:
        # Entities missing from "cardinalities" default to "1".
        cardinality = cardinalities.get(entity, '1')
        dot.edge(
            entity,
            rel_name,
            label=f' {cardinality} ',
            color=_ER_EDGE_COLOR,
            len='2.5',
            fontcolor=_ER_EDGE_COLOR,
            fontsize='10',
        )


def generate_entity_relationship_diagram(json_input: str, output_format: str) -> str:
    """
    Generates an Entity Relationship (ER) diagram from JSON input.

    Args:
        json_input (str): A JSON string describing the ER diagram structure.
            It must follow the Expected JSON Format Example below.
            Entity "type" is "strong" (default) or "weak".
            Attribute "type" is "primary_key", "partial_key", "multivalued",
            "derived", "composite" or "regular" (default).
            Relationship "type" is "regular" (default), "identifying" or "isa".
            ISA relationships use "parent"/"children"; the others use
            "entities" (at least two), optional "cardinalities" (default "1"),
            and optional "attributes".
        output_format (str): The output format for the generated diagram.
            Supported formats: "png" or "svg"

    Expected JSON Format Example:
    {
        "entities": [
            {
                "name": "Person",
                "type": "strong",
                "attributes": [
                    {"name": "person_id", "type": "primary_key"},
                    {"name": "first_name", "type": "regular"},
                    {"name": "last_name", "type": "regular"},
                    {"name": "birth_date", "type": "regular"},
                    {"name": "age", "type": "derived"},
                    {"name": "phone_numbers", "type": "multivalued"},
                    {"name": "full_address", "type": "composite"}
                ]
            },
            {
                "name": "Student",
                "type": "strong",
                "attributes": [
                    {"name": "student_number", "type": "regular"},
                    {"name": "enrollment_date", "type": "regular"},
                    {"name": "gpa", "type": "derived"}
                ]
            },
            {
                "name": "UndergraduateStudent",
                "type": "strong",
                "attributes": [
                    {"name": "major", "type": "regular"},
                    {"name": "expected_graduation", "type": "regular"},
                    {"name": "credits_completed", "type": "regular"}
                ]
            },
            {
                "name": "GraduateStudent",
                "type": "strong",
                "attributes": [
                    {"name": "thesis_topic", "type": "regular"},
                    {"name": "advisor_id", "type": "regular"},
                    {"name": "degree_type", "type": "regular"}
                ]
            },
            {
                "name": "Faculty",
                "type": "strong",
                "attributes": [
                    {"name": "employee_number", "type": "regular"},
                    {"name": "hire_date", "type": "regular"},
                    {"name": "office_number", "type": "regular"},
                    {"name": "years_of_service", "type": "derived"}
                ]
            },
            {
                "name": "Professor",
                "type": "strong",
                "attributes": [
                    {"name": "rank", "type": "regular"},
                    {"name": "tenure_status", "type": "regular"},
                    {"name": "research_areas", "type": "multivalued"}
                ]
            },
            {
                "name": "Lecturer",
                "type": "strong",
                "attributes": [
                    {"name": "contract_type", "type": "regular"},
                    {"name": "courses_per_semester", "type": "regular"}
                ]
            },
            {
                "name": "Staff",
                "type": "strong",
                "attributes": [
                    {"name": "position_title", "type": "regular"},
                    {"name": "department_assigned", "type": "regular"},
                    {"name": "salary_grade", "type": "regular"}
                ]
            },
            {
                "name": "AdministrativeStaff",
                "type": "strong",
                "attributes": [
                    {"name": "access_level", "type": "regular"},
                    {"name": "responsibilities", "type": "multivalued"}
                ]
            },
            {
                "name": "TechnicalStaff",
                "type": "strong",
                "attributes": [
                    {"name": "certifications", "type": "multivalued"},
                    {"name": "equipment_assigned", "type": "multivalued"}
                ]
            },
            {
                "name": "Vehicle",
                "type": "strong",
                "attributes": [
                    {"name": "vehicle_id", "type": "primary_key"},
                    {"name": "license_plate", "type": "regular"},
                    {"name": "year", "type": "regular"},
                    {"name": "current_value", "type": "derived"}
                ]
            },
            {
                "name": "Car",
                "type": "strong",
                "attributes": [
                    {"name": "doors", "type": "regular"},
                    {"name": "fuel_type", "type": "regular"}
                ]
            },
            {
                "name": "Bus",
                "type": "strong",
                "attributes": [
                    {"name": "capacity", "type": "regular"},
                    {"name": "route_assigned", "type": "regular"}
                ]
            },
            {
                "name": "MaintenanceVehicle",
                "type": "strong",
                "attributes": [
                    {"name": "equipment_type", "type": "regular"},
                    {"name": "specialized_tools", "type": "multivalued"}
                ]
            },
            {
                "name": "Course",
                "type": "strong",
                "attributes": [
                    {"name": "course_id", "type": "primary_key"},
                    {"name": "course_name", "type": "regular"},
                    {"name": "credits", "type": "regular"}
                ]
            },
            {
                "name": "Department",
                "type": "strong",
                "attributes": [
                    {"name": "dept_id", "type": "primary_key"},
                    {"name": "dept_name", "type": "regular"},
                    {"name": "budget", "type": "regular"}
                ]
            }
        ],
        "relationships": [
            {
                "name": "PersonISA",
                "type": "isa",
                "parent": "Person",
                "children": ["Student", "Faculty", "Staff"]
            },
            {
                "name": "StudentISA",
                "type": "isa",
                "parent": "Student",
                "children": ["UndergraduateStudent", "GraduateStudent"]
            },
            {
                "name": "FacultyISA",
                "type": "isa",
                "parent": "Faculty",
                "children": ["Professor", "Lecturer"]
            },
            {
                "name": "StaffISA",
                "type": "isa",
                "parent": "Staff",
                "children": ["AdministrativeStaff", "TechnicalStaff"]
            },
            {
                "name": "VehicleISA",
                "type": "isa",
                "parent": "Vehicle",
                "children": ["Car", "Bus", "MaintenanceVehicle"]
            },
            {
                "name": "Enrolls",
                "type": "regular",
                "entities": ["Student", "Course"],
                "cardinalities": {"Student": "M", "Course": "M"},
                "attributes": [
                    {"name": "semester"},
                    {"name": "year"},
                    {"name": "grade"}
                ]
            },
            {
                "name": "Teaches",
                "type": "regular",
                "entities": ["Faculty", "Course"],
                "cardinalities": {"Faculty": "M", "Course": "M"},
                "attributes": [
                    {"name": "semester"},
                    {"name": "classroom"}
                ]
            },
            {
                "name": "WorksIn",
                "type": "regular",
                "entities": ["Faculty", "Department"],
                "cardinalities": {"Faculty": "M", "Department": "1"},
                "attributes": [
                    {"name": "start_date"}
                ]
            },
            {
                "name": "Manages",
                "type": "regular",
                "entities": ["Staff", "Department"],
                "cardinalities": {"Staff": "M", "Department": "M"},
                "attributes": [
                    {"name": "role"}
                ]
            },
            {
                "name": "Uses",
                "type": "regular",
                "entities": ["Staff", "Vehicle"],
                "cardinalities": {"Staff": "M", "Vehicle": "M"},
                "attributes": [
                    {"name": "usage_date"},
                    {"name": "purpose"}
                ]
            }
        ]
    }

    Returns:
        str: The filepath to the generated image file. On failure, no
        exception is raised; a message starting with "Error: " is returned
        instead (empty input, invalid JSON, a root that is not an object
        with an "entities" field, or any rendering error).
    """
    try:
        # None or whitespace-only input is treated as empty.
        if not json_input or not json_input.strip():
            return "Error: Empty input"
        data = json.loads(json_input)
        # A non-object root (list, number, string, ...) cannot carry "entities".
        if not isinstance(data, dict) or 'entities' not in data:
            raise ValueError("Missing required field: entities")

        dot = graphviz.Graph(comment='ER Diagram', engine='neato')
        dot.attr(
            bgcolor='white',
            pad='1.5',
            overlap='false',
            splines='true',
            sep='+25',
            esep='+15',
        )
        dot.attr('node', fontname='Arial', fontsize='10', color=_ER_BORDER_COLOR)
        dot.attr('edge', fontname='Arial', fontsize='9', color=_ER_EDGE_COLOR)

        for entity in data.get('entities') or []:
            _er_add_entity(dot, entity)
        for relationship in data.get('relationships') or []:
            _er_add_relationship(dot, relationship)

        # Reserve a unique path, then close the handle BEFORE rendering.
        # render() overwrites this path with the DOT source and, with
        # cleanup=True, deletes it afterwards. Deleting a still-open file
        # fails on Windows.
        with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp:
            base_path = tmp.name
        try:
            # render() returns the output path, "<base_path>.<output_format>".
            return dot.render(base_path, format=output_format, cleanup=True)
        except Exception:
            # cleanup only happens after a successful render, so drop the
            # leftover source/placeholder file before propagating the error.
            if os.path.exists(base_path):
                os.remove(base_path)
            raise
    except json.JSONDecodeError:
        return "Error: Invalid JSON format"
    except Exception as e:
        return f"Error: {str(e)}"