Graphify / wbs_diagram_generator.py
ZahirJS's picture
Create wbs_diagram_generator.py
11561b8 verified
raw
history blame
6.7 kB
import graphviz
import json
from tempfile import NamedTemporaryFile
import os
from graph_generator_utils import add_nodes_and_edges # Reusing common utility
# Fraction of the remaining distance to white that is added per depth level when
# deriving phase/task fill colors from the base color (the original comments say
# this scheme was adapted from graph_generator_utils.add_nodes_and_edges).
_LIGHTENING_FACTOR = 0.12
_EDGE_COLOR = '#4a4a4a'
_HEX_DIGITS = frozenset('0123456789abcdefABCDEF')


def _normalize_hex_color(color) -> str:
    """Validate a hex color string and return it as '#RRGGBB' or '#RRGGBBAA'.

    '#RGB' shorthand is expanded to '#RRGGBB' and surrounding whitespace is
    stripped; six- and eight-digit inputs are otherwise returned unchanged.

    Raises:
        ValueError: if ``color`` is not a string of one of those forms.
    """
    text = color.strip() if isinstance(color, str) else ''
    digits = text[1:]
    # Check the digit set explicitly: int(..., 16) alone would also accept
    # signs, whitespace and underscores and produce bogus channel values.
    if (not text.startswith('#')
            or len(digits) not in (3, 6, 8)
            or not _HEX_DIGITS.issuperset(digits)):
        raise ValueError(f"Invalid base color (expected '#RRGGBB'): {color!r}")
    if len(digits) == 3:
        return '#' + ''.join(ch * 2 for ch in digits)
    return text


def _lighten(hex_color: str, depth: int) -> str:
    """Blend the RGB part of ``hex_color`` toward white by ``depth * _LIGHTENING_FACTOR``."""
    channels = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))  # any alpha pair is ignored
    r, g, b = (min(255, c + int((255 - c) * depth * _LIGHTENING_FACTOR)) for c in channels)
    return f'#{r:02x}{g:02x}{b:02x}'


def _font_color(depth: int) -> str:
    """Return the label color for a node at ``depth``: white until the fill is 60% lightened."""
    # NOTE(review): this depends only on depth, not on the base color's own
    # lightness, so a light base_color yields white text on a light fill.
    return 'white' if depth * _LIGHTENING_FACTOR < 0.6 else 'black'


def _normalize_phases(phases) -> list:
    """Validate ``phases`` and return ``[(phase_id, phase_label, [(task_id, task_label), ...]), ...]``.

    IDs and labels are converted to ``str``. Every ID must be unique and must not
    be 'project_root' (the root node's ID); Graphviz would otherwise silently
    merge the nodes and draw a graph that is no longer a tree.

    Raises:
        ValueError: if ``phases`` or a phase's 'tasks' is not a list, a phase or
            task is not an object with truthy 'id' and 'label', or an ID repeats.
    """
    if not isinstance(phases, list):
        raise ValueError("'phases' must be a list")
    seen_ids = {'project_root'}

    def claim(node_id) -> str:
        # Register a node ID, rejecting any ID already used in this diagram.
        node_id = str(node_id)
        if node_id in seen_ids:
            raise ValueError(f"Duplicate node id: {node_id!r}")
        seen_ids.add(node_id)
        return node_id

    normalized = []
    for phase in phases:
        if not isinstance(phase, dict) or not all([phase.get('id'), phase.get('label')]):
            raise ValueError(f"Invalid phase: {phase}")
        phase_id = claim(phase['id'])
        tasks = phase.get('tasks') or []  # missing or null 'tasks' means no tasks
        if not isinstance(tasks, list):
            raise ValueError(f"Invalid tasks for phase {phase_id!r}: {tasks}")
        task_rows = []
        for task in tasks:
            if not isinstance(task, dict) or not all([task.get('id'), task.get('label')]):
                raise ValueError(f"Invalid task: {task}")
            task_rows.append((claim(task['id']), str(task['label'])))
        normalized.append((phase_id, str(phase['label']), task_rows))
    return normalized


def _render_png(dot) -> str:
    """Render ``dot`` to a new temporary PNG file and return that file's path."""
    # Only reserve a unique base name here; the handle is closed before Graphviz
    # writes the DOT source to this path and (cleanup=True) deletes it, because an
    # open handle can block that removal on Windows.
    with NamedTemporaryFile(delete=False, prefix='wbs_') as tmp:
        base_path = tmp.name
    try:
        # render() saves the DOT source to base_path, writes the image next to
        # it with a '.png' extension, and returns the image's path.
        return dot.render(filename=base_path, format='png', cleanup=True)
    except Exception:
        # Cleanup only happens on success; don't leave the DOT source behind.
        if os.path.exists(base_path):
            os.remove(base_path)
        raise


def generate_wbs_diagram(json_input: str, base_color: str) -> str:
    """
    Generates a Work Breakdown Structure (WBS) Diagram from JSON input.

    The diagram has a root node (the project title), one node per phase connected
    to the root, and one node per task connected to its phase. Phase and task
    fills are progressively lighter shades of ``base_color``.

    Args:
        json_input (str): A JSON string describing the WBS structure.
            It must follow the Expected JSON Format Example below. Every phase
            and task needs a non-empty 'id' and 'label'; IDs must be unique
            across the whole diagram and must not be 'project_root'. 'tasks'
            may be omitted.
        base_color (str): The hexadecimal color string (e.g., '#19191a') for the base
            color of the nodes, from which a gradient will be generated.
            '#RRGGBB', '#RGB' and '#RRGGBBAA' are accepted (alpha is applied to
            the root node only).

    Returns:
        str: The filepath to the generated PNG image file. On any failure
            (empty input, invalid JSON, invalid structure or color, rendering
            error) a message starting with "Error: " is returned instead; this
            function does not raise.

    Expected JSON Format Example:
        {
            "project_title": "Software Development Project",
            "phases": [
                {
                    "id": "phase_prep",
                    "label": "1. Preparation",
                    "tasks": [
                        {"id": "task_vision", "label": "1.1. Identify Vision"},
                        {"id": "task_design", "label": "1.2. Design & Staffing"}
                    ]
                },
                {
                    "id": "phase_plan",
                    "label": "2. Planning",
                    "tasks": [
                        {"id": "task_cost", "label": "2.1. Cost Analysis"},
                        {"id": "task_benefit", "label": "2.2. Benefit Analysis"},
                        {"id": "task_risk", "label": "2.3. Risk Assessment"}
                    ]
                },
                {
                    "id": "phase_dev",
                    "label": "3. Development",
                    "tasks": [
                        {"id": "task_change", "label": "3.1. Change Management"},
                        {"id": "task_impl", "label": "3.2. Implementation"},
                        {"id": "task_beta", "label": "3.3. Beta Testing"}
                    ]
                }
            ]
        }
    """
    try:
        if not json_input or not json_input.strip():
            return "Error: Empty input"
        data = json.loads(json_input)
        if not isinstance(data, dict):
            raise ValueError("JSON root must be an object")
        if 'project_title' not in data or 'phases' not in data:
            raise ValueError("Missing required fields: project_title or phases")

        # Validate everything before building the graph so bad input fails fast.
        phases = _normalize_phases(data['phases'])
        root_color = _normalize_hex_color(base_color)

        # Every phase is one level below the root and every task two levels
        # below, so these styles are loop-invariant and computed once.
        phase_depth = 1
        task_depth = phase_depth + 1
        phase_fill = _lighten(root_color, phase_depth)
        phase_font = _font_color(phase_depth)
        task_fill = _lighten(root_color, task_depth)
        task_font = _font_color(task_depth)
        task_font_size = str(max(9, 14 - task_depth * 2))

        dot = graphviz.Digraph(
            name='WBSDiagram',
            format='png',
            graph_attr={
                'rankdir': 'TB',     # Top-to-Bottom hierarchy
                'splines': 'ortho',  # Orthogonal (axis-aligned) edge segments
                'bgcolor': 'white',  # White background
                'pad': '0.5',        # Padding
                'ranksep': '0.8',    # Vertical separation between ranks
                'nodesep': '0.5',    # Horizontal separation between nodes
            },
        )

        # Project title node (root of the hierarchy).
        title = data['project_title']
        dot.node(
            'project_root',
            None if title is None else str(title),
            shape='box',
            style='filled,rounded',
            fillcolor=root_color,
            fontcolor='white',
            fontsize='18',
        )

        for phase_id, phase_label, tasks in phases:
            dot.node(
                phase_id,
                phase_label,
                shape='box',
                style='filled,rounded',
                fillcolor=phase_fill,
                fontcolor=phase_font,
                fontsize='14',
            )
            dot.edge('project_root', phase_id, color=_EDGE_COLOR, arrowhead='none')

            for task_id, task_label in tasks:
                dot.node(
                    task_id,
                    task_label,
                    shape='box',
                    style='filled,rounded',
                    fillcolor=task_fill,
                    fontcolor=task_font,
                    fontsize=task_font_size,
                )
                dot.edge(phase_id, task_id, color=_EDGE_COLOR, arrowhead='none')

            if tasks:
                # Group the phase with its tasks. Re-declaring the existing nodes
                # (with no new attributes) makes them members of the cluster.
                # NOTE(review): earlier comments described this as aligning the
                # tasks into a vertical column, but per the Graphviz `rank` docs
                # rank='same' puts every node of the subgraph, including the
                # phase itself, on one rank (a horizontal row under
                # rankdir=TB). The 'cluster_' prefix also draws a box around the
                # group. Kept unchanged to preserve the rendered layout; confirm
                # the intended design.
                with dot.subgraph(name=f'cluster_{phase_id}') as c:
                    c.attr(rank='same')
                    c.node(phase_id)
                    for task_id, _ in tasks:
                        c.node(task_id)

        return _render_png(dot)
    except json.JSONDecodeError:
        return "Error: Invalid JSON format"
    except Exception as e:
        return f"Error: {str(e)}"