Upload 2 files
- app.py +181 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,181 @@
import os
import base64
from typing import List, Tuple

import gradio as gr
from PIL import Image
from bson import ObjectId
from pymongo import MongoClient
from langchain_community.llms import HuggingFaceEndpoint
from aiohttp.client_exceptions import ClientResponseError

# Consolidated from the original import list: the duplicate PIL and pymongo
# imports and the unused json, asyncio, ImageOps, and transformers
# (LlavaNextProcessor, LlavaNextForConditionalGeneration) imports were dropped;
# inference runs on a remote endpoint, so no model is loaded locally.

MONGOCONN = os.getenv("MONGOCONN", "mongodb://localhost:27017")
client = MongoClient(MONGOCONN)
db = client["hf-log"]  # Database name
collection = db["image_tagging_space"]  # Collection name

img_spec_token = "<|im_image|>"
img_join_token = "<|and|>"
sos_token = "[INST]"
eos_token = "[/INST]"


# LLaVA-NeXT-style tile shapes; defined here but not used elsewhere in this file.
ASPECT_RATIOS = [(1, 1), (1, 4), (4, 1)]
RESOLUTIONS = [(672, 672), (336, 1344), (1344, 336)]

# Function to resize image
def resize_image(image_path: str, max_width: int = 300, max_height: int = 300) -> str:
    img = Image.open(image_path)
    img.thumbnail((max_width, max_height), Image.LANCZOS)
    resized_image_path = f"/tmp/{os.path.basename(image_path)}"
    img.save(resized_image_path)
    return resized_image_path


# Function to encode images to Base64
def encode_image_to_base64(image_path: str) -> str:
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

# Generate prompt from images using empty tokens
def img_to_prompt(images: List[str]) -> str:
    encoded_images = [encode_image_to_base64(img) for img in images]
    return img_spec_token + img_join_token.join(encoded_images) + img_spec_token


# Combine image and text prompts using empty tokens
def combine_img_with_text(img_prompt: str, human_prompt: str, ai_role: str = "Answer questions as a professional designer") -> str:
    system_prompt = sos_token + f"system\n{ai_role}" + eos_token
    user_prompt = sos_token + f"user\n{img_prompt}<image>\n{human_prompt}" + eos_token
    user_prompt += "assistant\n"
    return system_prompt + user_prompt


def format_history(history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    # Copy the (user, assistant) pairs so Gradio receives a fresh list.
    return [(user_input, response) for user_input, response in history]

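# Illustrative note (an editor's sketch, not part of the original upload): for
# a single image whose Base64 payload is "iVBOR...", combine_img_with_text()
# with the prompt "Describe this" produces a string of the form
#
#   [INST]system\nAnswer questions as a professional designer[/INST]
#   [INST]user\n<|im_image|>iVBOR...<|im_image|><image>\nDescribe this[/INST]assistant\n
#
# (shown on two lines here; the real string has no line break between the two
# [INST] blocks). The custom endpoint is assumed to decode the
# <|im_image|>-delimited Base64 back into pixels; this framing is specific to
# this deployment, not a stock LLaVA-NeXT chat template.
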
async def call_inference(user_prompt):
    endpoint_url = "https://yzzwmsj8y9ji99i8.us-east-1.aws.endpoints.huggingface.cloud"
    llm = HuggingFaceEndpoint(endpoint_url=endpoint_url,
                              max_new_tokens=2000,
                              temperature=0.1,
                              do_sample=True,
                              use_cache=True,
                              timeout=300)
    try:
        # Use the public async API rather than the private _acall() the original
        # called; both take the prompt string and return the completion text.
        response = await llm.ainvoke(user_prompt)
    except ClientResponseError as e:
        return f"API call failed: {e.message}"
    return response

async def submit(message, history, doc_ids, last_image):
    # Log the user message and files
    print("User Message:", message["text"])
    print("User Files:", message["files"])

    image = None
    image_filetype = None
    if message["files"]:
        image = message["files"][-1]["path"] if isinstance(message["files"][-1], dict) else message["files"][-1]
        image_filetype = os.path.splitext(image)[1].lower()
        image = resize_image(image)
        last_image = (image, image_filetype)
    else:
        # No new upload: reuse the image from the previous turn, if any.
        image, image_filetype = last_image

    if not image:
        yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=None)
        return

    human_prompt = message['text']
    img_prompt = img_to_prompt([image])
    user_prompt = combine_img_with_text(img_prompt, human_prompt)

    # Show the user turn with a placeholder immediately. The original assigned
    # this tuple to an unused `outputs` variable and only returned at the end,
    # so the "<processing>" placeholder never rendered; yielding fixes that.
    history.append((human_prompt, "<processing>"))
    yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)

    # Call inference asynchronously
    response = await call_inference(user_prompt)
    selected_output = response.split("assistant\n")[-1].strip()

    # Store the message, image prompt, response, and image file type in MongoDB
    document = {
        'image_prompt': img_prompt,
        'user_prompt': human_prompt,
        'response': selected_output,
        'image_filetype': image_filetype,
        'likes': 0,
        'dislikes': 0,
        'like_dislike_reason': None
    }
    result = collection.insert_one(document)
    document_id = str(result.inserted_id)

    # Log the storage in MongoDB
    print(f"Stored in MongoDB with ID: {document_id}")

    # Update the chat history and document IDs
    history[-1] = (human_prompt, selected_output)
    doc_ids.append(document_id)

    yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)

def print_like_dislike(x: gr.LikeData, history, doc_ids, reason):
    if not history:
        return
    index = x.index[0] if isinstance(x.index, list) else x.index
    document_id = doc_ids[index]
    update_field = "likes" if x.liked else "dislikes"
    collection.update_one({"_id": ObjectId(document_id)}, {"$inc": {update_field: 1}, "$set": {"like_dislike_reason": reason}})
    print(f"Document ID: {document_id}, Liked: {x.liked}, Reason: {reason}")

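# Editor's sketch (not in the original upload): the per-document counters set
# above can later be tallied across the whole collection with a standard
# pymongo aggregation, e.g.
#
#   totals = list(collection.aggregate([
#       {"$group": {"_id": None,
#                   "likes": {"$sum": "$likes"},
#                   "dislikes": {"$sum": "$dislikes"}}}]))
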
def submit_reason_only(doc_ids, reason, selected_index, history):
    if not doc_ids:
        # Guard added: indexing an empty list would raise before any chat turn.
        return "No message to attach a reason to."
    if selected_index is None:
        selected_index = len(history) - 1  # Select the last message if no message is selected
    document_id = doc_ids[selected_index]
    collection.update_one(
        {"_id": ObjectId(document_id)},
        {"$set": {"like_dislike_reason": reason}}
    )
    print(f"Document ID: {document_id}, Reason submitted: {reason}")
    return "Reason submitted."

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <img src="https://lfxdigital.com/wp-content/uploads/2021/02/LFX_Logo_Final-01.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-NeXT-Mistral-7B-LFX</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">This multimodal LLM is hosted by LFX</p>
</div>
"""

with gr.Blocks(fill_height=True) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, height=600)
            chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
        with gr.Column(scale=1):
            image_display = gr.Image(type="filepath", interactive=False, show_label=False, height=400)
            reason_box = gr.Textbox(label="Reason for Like/Dislike (optional). Click a chat message to specify, or the latest message will be used.", visible=True)
            submit_reason_btn = gr.Button("Submit Reason", visible=True)

    history_state = gr.State([])
    doc_ids_state = gr.State([])
    last_image_state = gr.State((None, None))
    selected_index_state = gr.State(None)  # Tracks which chat message a reason applies to

    def select_message(evt: gr.SelectData, history, doc_ids):
        selected_index = evt.index if isinstance(evt.index, int) else evt.index[0]
        print(f"Selected Index: {selected_index}")  # Debugging
        return gr.update(visible=True), selected_index

    chat_msg = chat_input.submit(submit, inputs=[chat_input, history_state, doc_ids_state, last_image_state], outputs=[chatbot, chat_input, doc_ids_state, last_image_state, image_display])
    chatbot.like(print_like_dislike, inputs=[history_state, doc_ids_state, reason_box], outputs=[])
    chatbot.select(select_message, inputs=[history_state, doc_ids_state], outputs=[reason_box, selected_index_state])
    submit_reason_btn.click(submit_reason_only, inputs=[doc_ids_state, reason_box, selected_index_state, history_state], outputs=[reason_box])

demo.queue(api_open=False)
demo.launch(show_api=False, share=True, debug=True)  # share=True is ignored when running on Spaces
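
A deployment note, offered as an editor's assumption rather than anything stated in the commit: MONGOCONN is read from the environment at import time, and a bad or missing connection string otherwise only surfaces when the first insert_one runs. A startup connectivity check would make such misconfiguration obvious immediately; the helper below is a hedged sketch, not part of the upload:

    from pymongo import MongoClient
    from pymongo.errors import PyMongoError

    def mongo_reachable(uri: str, timeout_ms: int = 2000) -> bool:
        # Ping with a short server-selection timeout instead of
        # hanging until the first write fails.
        try:
            MongoClient(uri, serverSelectionTimeoutMS=timeout_ms).admin.command("ping")
            return True
        except PyMongoError:
            return False
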
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch
langchain
spaces
pillow
accelerate
transformers
gradio
pymongo
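
A closing observation on dependencies: app.py imports from langchain_community and aiohttp, neither of which is listed here. As of langchain 0.1.x both are pulled in transitively, but that is version luck rather than a declaration, and leaving everything unpinned means a breaking release can take the Space down on its next rebuild. A pinned variant naming them explicitly might look like the following (the exact version numbers are illustrative assumptions, not taken from the commit):

    torch==2.2.2
    langchain==0.1.16
    langchain-community==0.0.34
    spaces==0.26.2
    pillow==10.3.0
    accelerate==0.29.3
    transformers==4.40.1
    gradio==4.28.3
    pymongo==4.6.3
    aiohttp==3.9.5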