File size: 4,553 Bytes
e629bd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae7526
e629bd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae7526
17233cd
5ae7526
e629bd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae7526
 
e629bd5
6a7243e
e629bd5
2991f27
e629bd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a7243e
e629bd5
 
 
 
5ae7526
e629bd5
 
 
 
 
 
 
 
5be508e
e629bd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ae7526
6a7243e
e629bd5
 
 
 
5ae7526
e629bd5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
import matplotlib.pyplot as plt
import io
import base64
from PIL import Image
from groq import Groq

# Helper: load a file from disk and return it as a base64-encoded string
def encode_image(image_path):
    """Return the contents of *image_path* as a UTF-8 base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

# Function to run the user input code and display the plot
def run_code(code, groq_api_key):
    """Execute user-supplied matplotlib code and describe the resulting plot.

    Parameters
    ----------
    code : str
        Python source that is expected to draw using ``plt``/``ax``.
    groq_api_key : str
        Groq API key used for both model calls.

    Returns
    -------
    tuple
        ``(PIL.Image, vision_description, text_description)`` on success,
        or ``(None, error_message, None)`` on failure.
    """
    fig = None  # bound early so the finally-clause is safe even if setup fails
    try:
        # One client is sufficient: both completions hit the same Groq API.
        client = Groq(api_key=groq_api_key)

        fig, ax = plt.subplots()

        # SECURITY: exec() runs arbitrary user code with full interpreter
        # privileges. Acceptable only for a trusted local demo — do not
        # expose this endpoint publicly without real sandboxing.
        exec(code, {"plt": plt, "ax": ax})

        # Render the current figure into an in-memory PNG buffer
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight")
        buf.seek(0)

        # Open the rendered image and shrink it to the display size
        img = Image.open(buf)
        max_width, max_height = 600, 400  # Set maximum display size
        img.thumbnail((max_width, max_height))  # Resize to fit (in place)

        # Persist the resized image, then base64-encode it for the vision model
        img.save("plot.png")
        base64_image = encode_image("plot.png")

        # Vision model: describe the rendered plot together with the code
        llamavision_completion = client.chat.completions.create(
            model="llama-3.2-11b-vision-preview",
            messages=[{
                "role": "user",
                "content": [
                    {"type": "text", "text": f"Describe the plot values with image and the code provided to you. Code: {code}\n"},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{base64_image}"}
                    },
                ],
            }],
            temperature=0.5,
            max_tokens=4096,
            top_p=1,
            stream=False,
            stop=None,
        )
        llamavision_description = llamavision_completion.choices[0].message.content

        # Text model: analyze the plotting code itself (streamed response)
        llama_completion = client.chat.completions.create(
            model="llama-3.2-90b-text-preview",
            messages=[
                {
                    "role": "system",
                    "content": "What are the details of the plot provided by the user. Point out the important things in the dataset. What is the purpose of this dataset and how to interpret it. Analyze it."
                },
                {
                    "role": "user",
                    "content": code
                }
            ],
            temperature=0,
            max_tokens=4096,
            top_p=1,
            stream=True,
            stop=None,
        )

        # Accumulate the streamed chunks into one description string
        llama_description = "".join(
            chunk.choices[0].delta.content or "" for chunk in llama_completion
        )

        return img, llamavision_description, llama_description

    except Exception as e:
        # Surface the error to the UI rather than crashing the app
        return None, f"Error: {str(e)}", None
    finally:
        # Only close the figure if it was actually created; the original
        # code raised NameError here when Groq() or subplots() failed.
        if fig is not None:
            plt.close(fig)

# Define the Gradio interface.
# Layout: API-key row -> (code editor | plot image) row -> submit button ->
# two description textboxes. The click handler wires run_code to the widgets.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header + hallucination disclaimer shown above the controls
    gr.Markdown("""## Plot and Describe - Inference Powered by [Groq](https://groq.com/)

    **⚠️ Disclaimer:** Generative models may hallucinate or produce incorrect outputs. This tool is built for demonstration purposes only and should not be relied upon for critical analysis or decision-making.
    """, elem_id="disclaimer")

    with gr.Row():
        # Password-masked field; the key is passed per-request to run_code
        api_key_input = gr.Textbox(
            label="Groq API Key", type="password", placeholder="Enter your Groq API key here"
        )

    with gr.Row():
        # Left: user-supplied plotting code; right: rendered matplotlib output
        code_input = gr.Code(
            language="python", lines=20, label="Input Code"
        )
        output_image = gr.Image(type="pil", label="Chart will be displayed here")

    submit_btn = gr.Button("Submit")

    with gr.Row():
        # Read-only outputs for the two model descriptions
        output_llamavision_text = gr.Textbox(label="Description from Llama 3.2 Vision", interactive=False)
        output_llama_text = gr.Textbox(label="Description from Llama 3.2 Text", interactive=False)

    # Wire the button: inputs/outputs must match run_code's signature and
    # its (image, vision_description, text_description) return tuple
    submit_btn.click(
        fn=run_code,
        inputs=[code_input, api_key_input],
        outputs=[output_image, output_llamavision_text, output_llama_text]
    )

# Launch the interface (blocks until the server is stopped)
demo.launch()