import gradio as gr
import os
import base64
from io import BytesIO
from mistralai import Mistral
from pydantic import BaseModel, Field, ValidationError
from datasets import load_dataset
from PIL import Image
import json
import sqlite3
from datetime import datetime
# Load the dataset
ds = load_dataset("svjack/pokemon-blip-captions-en-zh")
ds = ds["train"]
# Load environment variables
api_key = os.environ.get('MISTRAL_API_KEY')
if not api_key:
    raise ValueError("MISTRAL_API_KEY is not set in the environment variables.")
# Create sample history
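# These {en, zh} pairs are embedded in the system prompt as few-shot style examples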
hist = [str({"en": ds[i]["en_text"], "zh": ds[i]["zh_text"]}) for i in range(8)]
hist_str = "\n".join(hist)
# Define the Caption model
class Caption(BaseModel):
    en: str = Field(...,
                    description="English caption of image",
                    max_length=84)
    zh: str = Field(...,
                    description="Chinese caption of image",
                    max_length=64)
# Initialize the Mistral client
client = Mistral(api_key=api_key)
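# Send the base64-encoded image to Pixtral together with the few-shot prompt and
# parse the JSON response into a Caption object; returns None if parsing fails.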
def generate_caption(image):
    # Convert the PIL image to a base64-encoded JPEG string
    image = image.convert("RGB")  # JPEG does not support an alpha channel
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')

    messages = [
        {
            "role": "system",
            "content": f'''
            You are a highly accurate image to caption transformer.
            Describe the image content in English and Chinese respectively. Make sure to FOCUS on item CATEGORY and COLOR!
            Do NOT provide NAMES! KEEP it SHORT!
            While adhering to the following JSON schema: {Caption.model_json_schema()}
            Following are some samples you should adhere to for style and tone:
            {hist_str}
            '''
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the image in English and Chinese"
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}"
                }
            ]
        }
    ]

    # Ask Pixtral for a JSON object and validate it against the Caption schema
    chat_response = client.chat.complete(
        model="pixtral-12b-2409",
        messages=messages,
        response_format={
            "type": "json_object",
        },
    )

    response_content = chat_response.choices[0].message.content
    try:
        caption_dict = json.loads(response_content)
        return Caption(**caption_dict)
    except (json.JSONDecodeError, ValidationError) as e:
        print(f"Error parsing caption: {e}")
        return None
# Initialize SQLite database
def init_db():
    conn = sqlite3.connect('feedback.db')
    c = conn.cursor()
    c.execute('''CREATE TABLE IF NOT EXISTS thumbs_up
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  timestamp TEXT,
                  input_data TEXT,
                  output_data TEXT)''')
    conn.commit()
    conn.close()
init_db()
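# Gradio callback: run the captioner on the uploaded image and format the result for display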
def process_image(image):
    if image is None:
        return "Please upload an image first."
    result = generate_caption(image)
    if result:
        return f"English caption: {result.en}\nChinese caption: {result.zh}"
    else:
        return "Failed to generate caption. Please check the API call or network connectivity."
def thumbs_up(image, caption):
    if image is None or not caption:
        gr.Warning("Please generate a caption before giving feedback.")
        return

    # Convert image to base64 string for storage
    image = image.convert("RGB")  # JPEG does not support an alpha channel
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    conn = sqlite3.connect('feedback.db')
    c = conn.cursor()
    c.execute("INSERT INTO thumbs_up (timestamp, input_data, output_data) VALUES (?, ?, ?)",
              (datetime.now().isoformat(), img_str, caption))
    conn.commit()
    conn.close()

    print("Thumbs up data saved to database.")
    # gr.Notification does not exist in Gradio; gr.Info shows a toast notification instead
    gr.Info("Thank you for your feedback!")
# Create Gradio interface
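# Custom CSS to highlight the Submit button; applied via gr.Blocks(css=custom_css) below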
custom_css = """
.highlight-btn {
background-color: #3498db !important;
border-color: #3498db !important;
color: white !important;
}
.highlight-btn:hover {
background-color: #2980b9 !important;
border-color: #2980b9 !important;
}
"""
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# Image Captioner")
    gr.Markdown("Upload an image to generate captions in English and Chinese. Use the 'Thumbs Up' button if you like the result!")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", elem_classes=["highlight-btn"])
        with gr.Column(scale=1):
            output_text = gr.Textbox()
            thumbs_up_btn = gr.Button("Thumbs Up")

    clear_btn.click(fn=lambda: None, inputs=None, outputs=input_image)
    submit_btn.click(fn=process_image, inputs=input_image, outputs=output_text)
    thumbs_up_btn.click(fn=thumbs_up, inputs=[input_image, output_text], outputs=None)
# Launch the interface
iface.launch(share=True)