File size: 6,373 Bytes
687c335
 
 
 
 
 
202de3a
687c335
 
 
 
 
 
 
 
 
202de3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687c335
202de3a
 
 
 
 
 
 
 
 
 
 
 
 
687c335
 
202de3a
 
 
687c335
202de3a
 
687c335
202de3a
687c335
202de3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dab1246
687c335
202de3a
687c335
 
202de3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687c335
202de3a
 
 
 
 
 
 
 
 
687c335
202de3a
 
 
 
 
 
687c335
202de3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687c335
 
202de3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687c335
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import base64
import markdown
import gradio as gr
from openai import OpenAI
from dotenv import load_dotenv
from typing import List, Dict

# Load credentials from a local .env file (expects XAI_API_KEY to be set).
load_dotenv()
XAI_API_KEY = os.getenv("XAI_API_KEY")

# xAI's API is OpenAI-compatible, so the OpenAI client is reused with a
# custom base URL pointing at api.x.ai.
client = OpenAI(
    api_key=XAI_API_KEY,
    base_url="https://api.x.ai/v1",
)

def build_system_prompt() -> dict:
    """Return the system message that seeds every API conversation.

    NOTE(review): the prompt wording is provisional and may be iterated on.
    """
    prompt_text = (
        "You are Grok Vision, created by xAI. You're designed to understand and describe images and answer text-based queries. "
        "Use all previous conversation context to provide clear, positive, and helpful responses. "
        "Respond in markdown format when appropriate."
    )
    return {"role": "system", "content": prompt_text}

def encode_image(image_path: str) -> str:
    """Encode a local JPEG/PNG image file as a base64 data URL.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A ``data:<mime>;base64,<payload>`` URL suitable for the vision API.

    Raises:
        ValueError: If the file exceeds 10MB or is not a JPEG/PNG.
    """
    # Reject oversized files up front; 10MB is the advertised upload cap.
    if os.path.getsize(image_path) > 10 * 1024 * 1024:
        raise ValueError("Image exceeds maximum size of 10MB.")
    # Only JPEG and PNG are supported; map file extension -> MIME type.
    mime_by_ext = {".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".png": "image/png"}
    ext = os.path.splitext(image_path)[1].lower()
    mime_type = mime_by_ext.get(ext)
    if mime_type is None:
        raise ValueError("Unsupported image format. Only JPEG and PNG are allowed.")
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded_string}"

def process_input(user_text: str, user_image_paths: List[str]) -> tuple[str, List[str]]:
    """Split a raw message into plain text and a list of image URLs.

    Any whitespace-separated token beginning with "http" is pulled out as an
    image URL; uploaded local files are base64-encoded and appended after.
    """
    cleaned = user_text.strip() if user_text else ""
    image_urls: List[str] = []
    kept_words: List[str] = []
    # Route each token: URLs go to image_urls, everything else stays as text.
    for token in cleaned.split():
        target = image_urls if token.startswith("http") else kept_words
        target.append(token)
    text_out = " ".join(kept_words)
    # Local uploads become data URLs (JPEG/PNG only, enforced by encode_image).
    for path in (user_image_paths or []):
        if path:
            image_urls.append(encode_image(path))

    return text_out, image_urls

def create_message_content(text: str, image_urls: List[str]) -> list[dict]:
    """Build the multimodal content list for one user message.

    Image parts come first (requested at high detail), followed by the text
    part when the message has any text.
    """
    parts: list[dict] = [
        {
            "type": "image_url",
            "image_url": {
                "url": url,
                "detail": "high"
            }
        }
        for url in image_urls
    ]
    if text:
        parts.append({
            "type": "text",
            "text": text
        })
    return parts

def stream_response(history: List[Dict], user_text: str, user_image_paths: List[str]):
    """Stream the model's reply into the Gradio chatbot.

    Appends the new user turn to `history` (mutated in place), rebuilds the
    full API message list from the stored conversation, then yields
    progressively longer snapshots of the chat as tokens arrive.

    Args:
        history: Conversation state; user entries may carry an "image_urls" key.
        user_text: Raw textbox contents (may contain inline image URLs).
        user_image_paths: Paths of uploaded image files, or None.

    Yields:
        A messages-format list for the Chatbot component, ending with the
        partially streamed assistant reply.
    """
    # Separate plain text from image URLs / uploaded files.
    user_text, image_urls = process_input(user_text, user_image_paths)
    if not user_text and not image_urls:
        # Nothing usable submitted; surface an inline assistant error instead.
        history.append({"role": "assistant", "content": "Please provide text or at least one image (JPEG/PNG only)."})
        yield history
        return
    # Rebuild the API payload from scratch: system prompt + all stored turns.
    messages = [build_system_prompt()]
    for entry in history:
        if entry["role"] == "user":
            # Re-attach any images stored alongside earlier user turns.
            content = create_message_content(entry["content"], entry.get("image_urls", []))
            messages.append({"role": "user", "content": content})
        elif entry["role"] == "assistant":
            messages.append({"role": "assistant", "content": entry["content"]})
    new_content = create_message_content(user_text, image_urls)
    messages.append({"role": "user", "content": new_content})
    # Persist the new user turn. The assistant turn is streamed into a COPY
    # below; merging it back into state is left to update_and_clear().
    history.append({"role": "user", "content": user_text, "image_urls": image_urls})
    stream = client.chat.completions.create(
        model="grok-2-vision-1212",
        messages=messages,
        stream=True,
        temperature=0.01,  # near-deterministic replies
    )
    response_text = ""
    temp_history = history.copy()
    temp_history.append({"role": "assistant", "content": ""})
    for chunk in stream:
        delta_content = chunk.choices[0].delta.content
        if delta_content is not None:
            # Accumulate tokens and re-yield the whole chat for live updates.
            response_text += delta_content
            temp_history[-1] = {"role": "assistant", "content": response_text}
            yield temp_history

def clear_inputs_and_chat():
    """Reset the chatbot display, state list, textbox, and file input."""
    fresh_chat: list = []
    fresh_state: list = []
    return fresh_chat, fresh_state, "", None

def update_and_clear(history: List[Dict], streamed_response: List[Dict]) -> tuple[List[Dict], str, None]:
    """Merge the streamed assistant reply into state and clear the inputs.

    stream_response() appends the user turn to `history` but streams the
    assistant turn only into the copy yielded to the chatbot, so the final
    assistant message must be merged back here. The previous implementation
    *replaced* history[-1], which overwrote the user message (and raised
    IndexError on an empty history); this version appends instead, and only
    replaces when the last state entry is a stale assistant turn.

    Args:
        history: Persistent conversation state (mutated in place).
        streamed_response: Messages currently shown in the Chatbot.

    Returns:
        The updated history, plus "" and None to clear the textbox and
        file-upload components.
    """
    if streamed_response:
        latest = streamed_response[-1]
        if latest.get("role") == "assistant":
            if not history or history[-1].get("role") != "assistant":
                # Normal turn: state ends with the user message; append reply.
                history.append(latest)
            elif history[-1]["content"] != latest["content"]:
                # State already holds an assistant turn (e.g. the inline error
                # path); refresh it with the final streamed text.
                history[-1] = latest
    return history, "", None

with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
        .chatbot-container {max-height: 80vh; overflow-y: auto;}
        .input-container {margin-top: 20px;}
        .title {text-align: center; margin-bottom: 20px;}
    """
) as demo:
    gr.Markdown(
        """
        # Grok 2 Vision Chatbot 𝕏
        
        Interact with Grok 2 Vision you can do:
        - πŸ“Έ Upload one or more images (Max 10MB each)
        - πŸ”— Provide image URLs in your message (`https://example.com/image1.jpg)
        - ✍️ Ask text-only questions
        - πŸ’¬ Chat history is preserved.
        """
    )
    
    with gr.Column(elem_classes="chatbot-container"):
        chatbot = gr.Chatbot(
            label="Conversation",
            type="messages",
            bubble_full_width=False
        )
    
    with gr.Row(elem_classes="input-container"):
        with gr.Column(scale=1):
            image_input = gr.File(
                file_count="multiple", 
                file_types=[".jpg", ".jpeg", ".png"], 
                label="Upload JPEG or PNG Images",
                height=300,
                interactive=True
            )
        with gr.Column(scale=3):
            message_input = gr.Textbox(
                label="Your Message",
                placeholder="Type your question or paste JPEG/PNG image URLs",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
    
    state = gr.State([])

    submit_btn.click(
        fn=stream_response,
        inputs=[state, message_input, image_input],
        outputs=chatbot,
        queue=True
    ).then(
        fn=update_and_clear,
        inputs=[state, chatbot],
        outputs=[state, message_input, image_input]
    )
    
    clear_btn.click(
        fn=clear_inputs_and_chat,
        inputs=[],
        outputs=[chatbot, state, message_input, image_input]
    )

if __name__ == "__main__":
    demo.launch()