import os
import re

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from groq import Groq

# Initialize FastAPI app
app = FastAPI()

# Serve static files for assets
app.mount("/static", StaticFiles(directory="static"), name="static")
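# With this mount, GET /static/<path> serves files from the local ./static
# directory (e.g. GET /static/infographic_gen.html returns the frontend page).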

# Initialize the inference client (Groq is used here; the Hugging Face
# InferenceClient is kept below as a commented-out alternative)
# client = InferenceClient()
client = Groq()

# Pydantic model for API input
class InfographicRequest(BaseModel):
    description: str
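
# Illustrative request body (hypothetical example, not from the original file):
#   {"description": "Global coffee consumption trends, 2010-2024"}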

# Load the system instructions and prompt template from environment variables
SYSTEM_INSTRUCT = os.getenv("SYSTEM_INSTRUCTOR")
PROMPT_TEMPLATE = os.getenv("PROMPT_TEMPLATE")
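
# Hypothetical values for illustration only; the real strings come from the
# environment. Note PROMPT_TEMPLATE must contain "{description}" because it is
# filled via str.format() below:
#   SYSTEM_INSTRUCTOR: "Expand the user's idea into a detailed infographic brief."
#   PROMPT_TEMPLATE:   "Return a complete HTML infographic for: {description}"
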
async def extract_code_blocks(markdown_text):
    """
    Extracts code blocks from the given Markdown text.

    Args:
        markdown_text (str): The Markdown content as a string.

    Returns:
        list: A list of code blocks extracted from the Markdown.
    """
    # Regex to match code blocks (fenced with triple backticks)
    code_block_pattern = re.compile(r'```.*?\n(.*?)```', re.DOTALL)

    # Find all code blocks
    code_blocks = code_block_pattern.findall(markdown_text)
    
    return code_blocks
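
# Usage sketch (hypothetical input; call from an async context):
#   blocks = await extract_code_blocks("```html\n<div>chart</div>\n```")
#   # -> ["<div>chart</div>\n"]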

async def generate_infographic_details(request: InfographicRequest):
    """Expand the user's description into detailed infographic content."""
    description = request.description
    generated_completion = client.chat.completions.create(
        model="llama-3.1-70b-versatile",
        messages=[
            {"role": "system", "content": SYSTEM_INSTRUCT},
            {"role": "user", "content": description},
        ],
        temperature=0.5,
        max_tokens=5000,
        top_p=1,
        stream=False,
        stop=None,
    )
    return generated_completion.choices[0].message.content

# Route to serve the HTML template
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    with open("static/infographic_gen.html") as f:
        return HTMLResponse(f.read())

# Route to handle infographic generation
@app.post("/generate")
async def generate_infographic(request: InfographicRequest):
    description = await generate_infographic_details(request)
    prompt = PROMPT_TEMPLATE.format(description=description)

    try:
        messages = [{"role": "user", "content": prompt}]
        stream = client.chat.completions.create(
            model="Qwen/Qwen2.5-Coder-32B-Instruct",
            messages=messages,
            temperature=0.4,
            max_tokens=6000,
            top_p=0.7,
            stream=True,
        )

        # Accumulate the streamed response; a chunk's delta content may be None,
        # so fall back to an empty string to avoid a TypeError
        generated_text = ""
        for chunk in stream:
            generated_text += chunk.choices[0].delta.content or ""

        print(generated_text)
        code_blocks = await extract_code_blocks(generated_text)
        if code_blocks:
            return JSONResponse(content={"html": code_blocks[0]})
        else:
            return JSONResponse(content={"error": "No generation"}, status_code=500)

    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
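

# Local run sketch, not part of the original file: assumes uvicorn is
# installed; port 7860 is a common Hugging Face Spaces default, adjust as needed.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)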