import graphviz
from fastapi import FastAPI, Response
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import InferenceClient
from pydantic import BaseModel

# FastAPI application; the roadmap endpoint below registers on it via @app.post.
app = FastAPI()

# Inference client for the hosted model used to write roadmaps.
# Created once at import time and shared by every call to generate_roadmap.
client = InferenceClient("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")

class CourseRequest(BaseModel):
    """Request body for POST /generate/: the course to build a roadmap for."""

    # Interpolated verbatim into the model prompt by format_prompt.
    # NOTE(review): no length or emptiness constraint is declared, so any
    # string (including "") is accepted and sent to the model.
    course_name: str

def format_prompt(course_name: str) -> str:
    """Build the instruction text sent to the model for one course.

    The course name is inserted as-is, wrapped in single quotes.
    """
    template = (
        "As an expert in education, please generate a detailed roadmap "
        "for the course '{}'. Include key topics."
    )
    return template.format(course_name)

def generate_roadmap(item: CourseRequest, max_new_tokens: int = 200) -> str:
    """Ask the model for a roadmap of ``item.course_name`` and return its text.

    Args:
        item: Request carrying the course name to build the prompt from.
        max_new_tokens: Upper bound on tokens the model may generate. The
            default of 200 matches the previously hard-coded value.

    Returns:
        The generated roadmap text as a single string.
    """
    prompt = format_prompt(item.course_name)
    # Without ``stream=True``, text_generation returns the generated text as
    # a plain ``str``. The previous loop iterated that string one character at
    # a time and crashed on ``response.token``, so use the result directly.
    return client.text_generation(prompt, max_new_tokens=max_new_tokens)

def create_diagram(roadmap_text: str):
    """Build a linear flowchart with one node per non-blank line of text.

    Args:
        roadmap_text: Free-form roadmap text (the model's output); each
            non-blank line is treated as one topic.

    Returns:
        A ``graphviz.Digraph`` whose nodes are named "0", "1", ... in line
        order, labelled with the stripped line text, and linked each to the
        next. Text with no non-blank lines yields an empty graph.
    """
    dot = graphviz.Digraph()

    # Model output routinely contains blank separator lines. Skip them so the
    # chart has no empty boxes, and number the remaining topics contiguously
    # so every edge joins two real topics.
    # NOTE(review): labels are passed to graphviz unescaped; backslash
    # sequences in model output may be interpreted by DOT — confirm whether
    # escaping is needed.
    topics = [line.strip() for line in roadmap_text.splitlines()]
    topics = [topic for topic in topics if topic]

    for index, topic in enumerate(topics):
        dot.node(str(index), topic)
        if index:
            dot.edge(str(index - 1), str(index))  # chain topics in order

    return dot

@app.post("/generate/")
async def generate_roadmap_endpoint(course_request: CourseRequest):
    """Generate a roadmap for the requested course and return it as a PNG.

    The model call and the Graphviz rendering are synchronous, so both are
    run in the threadpool rather than directly on the event loop.

    Returns:
        A ``Response`` with ``media_type="image/png"`` holding the chart.
    """
    roadmap_text = await run_in_threadpool(generate_roadmap, course_request)
    diagram = create_diagram(roadmap_text)

    # Render straight to bytes. A fixed output path such as /tmp/roadmap is
    # shared by every concurrent request, so one request could read another's
    # image. Also, ``cleanup=True`` only removed the DOT source, which left
    # PNG files behind.
    png_bytes = await run_in_threadpool(diagram.pipe, format="png")
    return Response(content=png_bytes, media_type="image/png")