linm1 commited on
Commit
1e3ede4
·
verified ·
1 Parent(s): 1a7b044

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +163 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from dotenv import load_dotenv
4
+ import base64
5
+ from io import BytesIO
6
+ from mistralai import Mistral
7
+ from pydantic import BaseModel, Field
8
+ from datasets import load_dataset
9
+ from PIL import Image
10
+ import json
11
+ import sqlite3
12
+ from datetime import datetime
13
+
14
# Load the Pokemon caption dataset used for few-shot style examples.
ds = load_dataset("svjack/pokemon-blip-captions-en-zh")
ds = ds["train"]

# Load environment variables from a .env file (if present), then read the key.
# BUG FIX: load_dotenv was imported but never called, so a local .env file
# was silently ignored and only pre-set environment variables worked.
load_dotenv()
api_key = os.environ.get('MISTRAL_API_KEY')

if not api_key:
    raise ValueError("MISTRAL_API_KEY is not set in the environment variables.")

# Build a small few-shot sample of caption pairs to steer the model's
# style and tone (injected into the system prompt below).
hist = [str({"en": ds[i]["en_text"], "zh": ds[i]["zh_text"]}) for i in range(8)]
hist_str = "\n".join(hist)
27
+
28
class Caption(BaseModel):
    """Bilingual image caption: a short English text and a short Chinese text."""

    # The field descriptions and length limits feed into the JSON schema
    # that is embedded in the model's system prompt.
    en: str = Field(..., description="English caption of image", max_length=84)
    zh: str = Field(..., description="Chinese caption of image", max_length=64)
36
+
37
# Initialize the Mistral API client used by generate_caption() below.
client = Mistral(api_key=api_key)
39
+
40
def generate_caption(image):
    """Generate a bilingual (en/zh) caption for a PIL image via Pixtral.

    Args:
        image: A PIL.Image.Image to caption.

    Returns:
        A Caption instance on success, or None when the model's reply is
        not valid JSON or does not satisfy the Caption schema.
    """
    # BUG FIX: JPEG cannot encode alpha/palette modes (e.g. RGBA PNG
    # uploads), so image.save(..., format="JPEG") would raise. Convert first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # Encode the image as a base64 JPEG for the data-URL the API expects.
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')

    messages = [
        {
            "role": "system",
            "content": f'''
            You are a highly accurate image to caption transformer.
            Describe the image content in English and Chinese respectively. Make sure to FOCUS on item CATEGORY and COLOR!
            Do NOT provide NAMES! KEEP it SHORT!
            While adhering to the following JSON schema: {Caption.model_json_schema()}
            Following are some samples you should adhere to for style and tone:
            {hist_str}
            '''
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the image in English and Chinese"
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}"
                }
            ]
        }
    ]

    chat_response = client.chat.complete(
        model="pixtral-12b-2409",
        messages=messages,
        response_format={
            "type": "json_object",
        }
    )

    response_content = chat_response.choices[0].message.content

    # BUG FIX: catch ValueError, not only json.JSONDecodeError —
    # pydantic's ValidationError (raised by Caption(**...) when the model
    # returns valid JSON that violates the schema) is a ValueError subclass,
    # and json.JSONDecodeError is one too, so both failure modes are covered.
    try:
        caption_dict = json.loads(response_content)
        return Caption(**caption_dict)
    except ValueError as e:
        print(f"Error decoding JSON: {e}")
        return None
89
+
90
# Initialize SQLite database
def init_db(db_path='feedback.db'):
    """Create the SQLite feedback database and its thumbs_up table if absent.

    Safe to call repeatedly (CREATE TABLE IF NOT EXISTS makes it idempotent).

    Args:
        db_path: Path of the SQLite database file. Defaults to 'feedback.db'
            so existing callers are unaffected.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''CREATE TABLE IF NOT EXISTS thumbs_up
                        (id INTEGER PRIMARY KEY AUTOINCREMENT,
                         timestamp TEXT,
                         input_data TEXT,
                         output_data TEXT)''')
        conn.commit()
    finally:
        # BUG FIX: the connection leaked if execute() raised; always close.
        conn.close()
101
+
102
+ init_db()
103
+
104
def process_image(image):
    """Gradio handler: turn an uploaded PIL image into a bilingual caption string.

    Returns a user-facing message when no image is uploaded or captioning fails.
    """
    if image is None:
        return "Please upload an image first."

    caption = generate_caption(image)
    if not caption:
        return "Failed to generate caption. Please check the API call or network connectivity."
    return f"English caption: {caption.en}\nChinese caption: {caption.zh}"
114
+
115
def thumbs_up(image, caption):
    """Persist a liked (image, caption) pair into the feedback database.

    Args:
        image: The PIL image that was captioned (None if nothing is uploaded).
        caption: The caption text currently shown to the user.

    Returns:
        None (the click handler is wired with outputs=None, so the return
        value is ignored by Gradio).
    """
    # BUG FIX: guard against clicks before any image is uploaded —
    # image.save(...) on None raised AttributeError.
    if image is None:
        gr.Warning("Please upload an image and generate a caption first.")
        return None

    # JPEG cannot encode alpha/palette modes; convert so .save() does not raise.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    # Convert image to a base64 string for storage.
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    conn = sqlite3.connect('feedback.db')
    try:
        conn.execute("INSERT INTO thumbs_up (timestamp, input_data, output_data) VALUES (?, ?, ?)",
                     (datetime.now().isoformat(), img_str, caption))
        conn.commit()
    finally:
        # BUG FIX: the connection leaked if the insert raised; always close.
        conn.close()
    print("Thumbs up data saved to database.")
    # BUG FIX: gr.Notification does not exist in Gradio (AttributeError at
    # runtime); gr.Info() is the supported way to show a toast message.
    gr.Info("Thank you for your feedback!")
    return None
129
+
130
# Create Gradio interface
# CSS for the .highlight-btn class attached to the Submit button below.
# NOTE(review): custom_css is defined here but gr.Blocks() is created without
# css=custom_css — confirm whether the styling was meant to be applied.
custom_css = """
.highlight-btn {
    background-color: #3498db !important;
    border-color: #3498db !important;
    color: white !important;
}
.highlight-btn:hover {
    background-color: #2980b9 !important;
    border-color: #2980b9 !important;
}
"""
142
+
143
# Build the Gradio UI.
# BUG FIX: custom_css was defined above but never passed to gr.Blocks(), so
# the .highlight-btn styling on the Submit button was dead code.
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("# Image Captioner")
    gr.Markdown("Upload an image to generate captions in English and Chinese. Use the 'Thumbs Up' button if you like the result!")

    with gr.Row():
        with gr.Column(scale=1):
            # Left column: image upload plus Clear / Submit controls.
            input_image = gr.Image(type="pil")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", elem_classes=["highlight-btn"])

        with gr.Column(scale=1):
            # Right column: generated caption text and feedback button.
            output_text = gr.Textbox()
            thumbs_up_btn = gr.Button("Thumbs Up")

    # Wire events: Clear resets the image, Submit generates the caption,
    # Thumbs Up records positive feedback (no output component).
    clear_btn.click(fn=lambda: None, inputs=None, outputs=input_image)
    submit_btn.click(fn=process_image, inputs=input_image, outputs=output_text)
    thumbs_up_btn.click(fn=thumbs_up, inputs=[input_image, output_text], outputs=None)

# Launch the interface
iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
anthropic
openai>=1.1.0
mistralai
pydantic
docstring-parser
rich
aiohttp
ruff==0.1.7
pre-commit==3.5.0
pyright==1.1.360
typer
cohere
datasets
gradio
Pillow
python-dotenv