import json
import math
import os

import gradio as gr
from google import genai
from google.genai import types
from PIL import Image, ImageColor, ImageDraw, ImageFont

# Initialize Google Gemini client.
# NOTE: runs at import time and raises KeyError if GOOGLE_API_KEY is not set.
client = genai.Client(api_key=os.environ['GOOGLE_API_KEY'])
model_name = "gemini-2.0-flash-exp"

# Size of the canvas the flowchart is rendered on. It is also stated in the
# system instruction so the coordinates the model returns fit inside the image.
CANVAS_WIDTH = 1000
CANVAS_HEIGHT = 800

# Arrowhead geometry: length of each barb and half-angle of the head.
ARROW_SIZE = 10
ARROW_HALF_ANGLE = math.radians(30)

# draw_flowchart reads exactly the keys named here, so the prompt must name
# the top-level keys explicitly rather than leaving the structure to the model.
SYSTEM_INSTRUCTION = f"""
Return a JSON object describing a flowchart, with exactly two top-level keys:
"shapes" (a list) and "connections" (a list).
Use formal flowchart conventions with shapes like rectangles, ellipses, and diamonds.
Each shape should have attributes: id, label, x, y, width, height, type (e.g., 'rectangle', 'ellipse', 'diamond'), and color (a CSS color name or hex code such as '#aaddff').
x and y are the pixel coordinates of the shape's top-left corner; every shape must fit inside a {CANVAS_WIDTH}x{CANVAS_HEIGHT} pixel canvas whose origin (0, 0) is the top-left corner.
Each connection should have attributes: from (id), to (id), and type (e.g., 'arrow').
"""


# Function to parse JSON output from Gemini
def parse_json(json_output):
    """
    Parse JSON output from the Gemini model.

    If the text contains a Markdown fence line ```json, only the content of
    that fenced block is parsed; otherwise the whole text is parsed.

    Args:
        json_output: Raw model text (may be None if the model returned nothing).

    Returns:
        The decoded JSON value, or {} if the input is empty or not valid JSON.
    """
    if not json_output:
        return {}
    lines = json_output.splitlines()
    for i, line in enumerate(lines):
        # strip() tolerates indentation / trailing spaces around the fence.
        if line.strip() == "```json":
            json_output = "\n".join(lines[i + 1:])  # Remove everything before "```json"
            json_output = json_output.split("```")[0]  # Remove everything after the closing "```"
            break
    try:
        return json.loads(json_output)
    except (json.JSONDecodeError, TypeError) as e:
        print(f"Error parsing JSON: {e}")
        return {}


def _shape_box(shape):
    """Return (x, y, width, height) of *shape* as ints, or None if malformed."""
    try:
        x, y, w, h = (int(float(shape[key])) for key in ("x", "y", "width", "height"))
    except (KeyError, TypeError, ValueError):
        return None
    if w <= 0 or h <= 0:
        # Pillow rejects boxes whose second corner precedes the first.
        return None
    return x, y, w, h


def _safe_color(color, default="white"):
    """Return *color* if Pillow can parse it, otherwise *default*."""
    try:
        ImageColor.getrgb(color)
    except (ValueError, TypeError, AttributeError):
        return default
    return color


def _connection_endpoints(src, dst):
    """
    Choose start/end points for a connector between two (x, y, w, h) boxes.

    Target below the source: bottom-center -> top-center (top-down flow).
    Target above the source: top-center -> bottom-center (e.g. loop back-edges).
    Boxes overlapping vertically: side-center -> facing side-center.
    """
    sx, sy, sw, sh = src
    dx, dy, dw, dh = dst
    if dy >= sy + sh:
        return (sx + sw // 2, sy + sh), (dx + dw // 2, dy)
    if dy + dh <= sy:
        return (sx + sw // 2, sy), (dx + dw // 2, dy + dh)
    if dx + dw // 2 >= sx + sw // 2:
        return (sx + sw, sy + sh // 2), (dx, dy + dh // 2)
    return (sx, sy + sh // 2), (dx + dw, dy + dh // 2)


def _draw_arrowhead(draw, start, tip, size=ARROW_SIZE):
    """Draw a filled arrowhead at *tip*, pointing along the start -> tip direction."""
    (x1, y1), (x2, y2) = start, tip
    if (x1, y1) == (x2, y2):
        return  # A zero-length connector has no direction to point in.
    angle = math.atan2(y2 - y1, x2 - x1)
    barbs = [
        (x2 - size * math.cos(angle + side * ARROW_HALF_ANGLE),
         y2 - size * math.sin(angle + side * ARROW_HALF_ANGLE))
        for side in (-1, 1)
    ]
    draw.polygon([(x2, y2), *barbs], fill="black")


# Function to draw a flowchart
def draw_flowchart(image, flowchart_json):
    """
    Draw a flowchart on a copy of *image* based on JSON input.

    Args:
        image: PIL image used as the background; it is not modified.
        flowchart_json: Dict with "shapes" (id, label, x, y, width, height,
            type, color) and "connections" (from, to, optional type), as
            requested from the model in generate_flowchart.

    Returns:
        A new PIL image with the shapes and connectors drawn. Malformed shapes
        and connections referring to unknown ids are skipped with a printed
        warning instead of aborting the whole drawing.
    """
    im = image.copy()
    draw = ImageDraw.Draw(im)

    # Load default font
    try:
        font = ImageFont.load_default()
    except Exception as e:
        print(f"Error loading font: {e}")
        return im

    if not isinstance(flowchart_json, dict):
        print(f"Expected a JSON object for the flowchart, got {type(flowchart_json).__name__}")
        return im

    shapes = flowchart_json.get("shapes") or []
    connections = flowchart_json.get("connections") or []

    # shape id -> box, built once so each connection lookup is O(1).
    boxes = {}

    # Draw shapes
    for shape in shapes:
        box = _shape_box(shape)
        if box is None:
            print(f"Skipping malformed shape: {shape!r}")
            continue
        x, y, w, h = box
        shape_type = str(shape.get("type") or "rectangle").lower()
        label = str(shape.get("label", ""))
        color = _safe_color(shape.get("color", "white"))
        if "id" in shape:
            boxes[shape["id"]] = box

        if shape_type == "ellipse":
            draw.ellipse([x, y, x + w, y + h], fill=color, outline="black", width=3)
        elif shape_type == "diamond":
            points = [
                (x + w // 2, y),      # Top
                (x + w, y + h // 2),  # Right
                (x + w // 2, y + h),  # Bottom
                (x, y + h // 2),      # Left
            ]
            draw.polygon(points, fill=color, outline="black", width=3)
        else:
            # "rectangle" and any unrecognized type are drawn as a rectangle.
            draw.rectangle([x, y, x + w, y + h], fill=color, outline="black", width=3)

        # Center the label; subtracting the bbox origin compensates for the
        # font's internal offset so the visible text is truly centered.
        left, top, right, bottom = font.getbbox(label)
        text_x = x + (w - (right - left)) // 2 - left
        text_y = y + (h - (bottom - top)) // 2 - top
        draw.text((text_x, text_y), label, fill="black", font=font)

    # Draw connections (after shapes so connectors stay visible on top).
    for conn in connections:
        try:
            src = boxes[conn["from"]]
            dst = boxes[conn["to"]]
        except (KeyError, TypeError):
            print(f"Skipping connection with unknown or missing endpoints: {conn!r}")
            continue
        start, end = _connection_endpoints(src, dst)
        draw.line([start, end], fill="black", width=2)
        if str(conn.get("type", "arrow")).lower() == "arrow":
            _draw_arrowhead(draw, start, end)

    return im


def olddraw_flowchart(image, flowchart_json):
    """
    Deprecated alias kept for backward compatibility; delegates to draw_flowchart.

    The former implementation called ImageFont.getsize(), which was removed in
    Pillow 10, so it no longer ran on current Pillow releases.
    """
    return draw_flowchart(image, flowchart_json)


# Function to generate flowchart JSON via Gemini
def generate_flowchart(prompt):
    """
    Use Google Gemini to generate JSON for a flowchart.

    Returns:
        The parsed JSON (normally a dict with "shapes" and "connections"),
        or {} on any API or parse failure.
    """
    try:
        response = client.models.generate_content(
            model=model_name,
            contents=[prompt],
            config=types.GenerateContentConfig(
                system_instruction=SYSTEM_INSTRUCTION,
                temperature=0.5,
                # Ask for raw JSON; parse_json still copes with fenced output.
                response_mime_type="application/json",
            ),
        )
        text = response.text
    except Exception as e:  # API boundary: network, auth, quota, safety blocks.
        print(f"Error generating flowchart JSON: {e}")
        return {}
    print("Gemini Response:", text)
    return parse_json(text)


# Function to predict the flowchart
def predict_flowchart(prompt):
    """
    Generate a flowchart image based on the user's prompt.

    Always returns a PIL image: the rendered flowchart, or a blank canvas
    with the error message drawn in red if anything failed.
    """
    try:
        if not prompt or not prompt.strip():
            raise ValueError("Please enter a description of the flowchart.")

        # Generate the flowchart JSON
        flowchart_json = generate_flowchart(prompt)
        if not flowchart_json or not isinstance(flowchart_json, dict):
            raise ValueError("Could not generate flowchart JSON.")

        # Create a blank image to draw on
        image = Image.new("RGB", (CANVAS_WIDTH, CANVAS_HEIGHT), "white")
        return draw_flowchart(image, flowchart_json)
    except Exception as e:  # UI boundary: always hand Gradio an image.
        print(f"Error during processing: {e}")
        error_image = Image.new("RGB", (CANVAS_WIDTH, CANVAS_HEIGHT), "white")
        ImageDraw.Draw(error_image).text((50, 50), f"Error: {e}", fill="red")
        return error_image


# Define the Gradio interface for flowcharts
def gradio_interface_flowcharts():
    """
    Build the Gradio Blocks app for flowchart generation.

    Returns:
        The (not yet launched) gr.Blocks demo.
    """
    with gr.Blocks(theme=gr.themes.Glass(secondary_hue="blue")) as demo:
        gr.Markdown("# Flowchart Generator with Gemini")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input Section")
                input_prompt = gr.Textbox(lines=2, label="Input Prompt", placeholder="Describe the flowchart process.")
                submit_btn = gr.Button("Generate Flowchart")
            with gr.Column():
                gr.Markdown("### Output Section")
                output_image = gr.Image(type="pil", label="Output Flowchart")

        # Event to generate flowcharts
        submit_btn.click(
            predict_flowchart,
            inputs=[input_prompt],
            outputs=[output_image],
        )

    return demo


# Run the app
if __name__ == "__main__":
    demo = gradio_interface_flowcharts()
    demo.launch()