# conzchunglfxsdu's picture
# 5c765b5 verified
import os
import base64
import json
import pymongo
from typing import List, Optional, Dict, Any, Tuple
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from langchain_community.llms import HuggingFaceEndpoint
import gradio as gr
from pymongo import MongoClient
from bson import ObjectId
import asyncio
from PIL import Image, ImageOps
from aiohttp.client_exceptions import ClientResponseError
# MongoDB connection string; taken from the MONGOCONN environment variable,
# falling back to a local instance when the variable is unset.
MONGOCONN = os.getenv("MONGOCONN", "mongodb://localhost:27017")
# The client, database and collection handles are created once at import time
# and shared by every handler in this module.
client = MongoClient(MONGOCONN)
db = client["hf-log"] # Database name
collection = db["image_tagging_space"] # Collection name
# Marker tokens used when assembling the model prompt:
# img_spec_token wraps the encoded image payload, img_join_token separates
# multiple encoded images, and sos_token / eos_token open and close each
# system/user turn.
img_spec_token = "<|im_image|>"
img_join_token = "<|and|>"
sos_token = "<|im_start|>"
eos_token = "<|im_end|>"
def resize_image(image_path: str, max_width: int = 300, max_height: int = 300) -> str:
    """Write a downscaled copy of an image to /tmp and return the copy's path.

    The image is shrunk in place with ``Image.thumbnail`` so that it fits
    inside a ``max_width`` x ``max_height`` box, then saved under the same
    base name in ``/tmp``.

    Args:
        image_path: Path of the image file to read.
        max_width: Width of the bounding box, in pixels.
        max_height: Height of the bounding box, in pixels.

    Returns:
        The path of the resized copy (``/tmp/<basename of image_path>``).
    """
    resized_image_path = f"/tmp/{os.path.basename(image_path)}"
    # The context manager closes the source file handle once the copy is saved.
    with Image.open(image_path) as img:
        img.thumbnail((max_width, max_height), Image.LANCZOS)
        # Without this save the returned path would not refer to a real file.
        img.save(resized_image_path)
    return resized_image_path
def encode_image_to_base64(image_path: str) -> str:
    """Return the contents of a file as a Base64 text string.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The Base64 encoding of the file's bytes, decoded to ``str``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(image_path, "rb") as image_file:
        # b64encode takes bytes and returns bytes; encode the file contents
        # and decode the ASCII result so callers get the annotated ``str``.
        return base64.b64encode(image_file.read()).decode("utf-8")
def img_to_prompt(images: List[str]) -> str:
    """Build the image section of the prompt from a list of image paths.

    Each path is Base64-encoded with ``encode_image_to_base64``; the encoded
    strings are joined with ``img_join_token`` and the result is wrapped in
    ``img_spec_token`` on both sides.
    """
    payload = img_join_token.join(map(encode_image_to_base64, images))
    return f"{img_spec_token}{payload}{img_spec_token}"
def combine_img_with_text(img_prompt: str, human_prompt: str, ai_role: str = "Answer questions as a professional designer") -> str:
    """Assemble the full model prompt from the image prompt and the user's text.

    The result is a system turn carrying ``ai_role``, followed by a user turn
    carrying ``img_prompt``, the literal ``<image>`` marker and
    ``human_prompt``. Both turns are delimited by ``sos_token`` / ``eos_token``.
    A trailing ``assistant\\n`` header is appended as-is, with no ``sos_token``
    in front of it.
    """
    segments = [
        # System turn.
        sos_token,
        f"system\n{ai_role}",
        eos_token,
        # User turn: image payload first, then the question.
        sos_token,
        f"user\n{img_prompt}<image>\n{human_prompt}",
        eos_token,
        # Open header for the model's reply.
        "assistant\n",
    ]
    return "".join(segments)
def format_history(history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Return a new list of ``(user_input, response)`` tuples copied from ``history``.

    Every entry must unpack into exactly two items; the entries are rebuilt as
    tuples, so the returned list is independent of the input list.
    """
    pairs = []
    for entry in history:
        user_input, response = entry
        pairs.append((user_input, response))
    return pairs
async def call_inference(user_prompt):
    """Send the prompt to the Hugging Face inference endpoint and return its reply.

    Args:
        user_prompt: The fully assembled prompt string.

    Returns:
        The text returned by the endpoint, or a string of the form
        ``"API call failed: <message>"`` when the HTTP call raises
        ``ClientResponseError``.
    """
    # NOTE(review): the endpoint URL is blank in this file and must be filled
    # in for the deployment. Further HuggingFaceEndpoint keyword arguments may
    # have followed the trailing comma in the original call — confirm against
    # the deployed configuration.
    endpoint_url = ""
    llm = HuggingFaceEndpoint(endpoint_url=endpoint_url)
    try:
        response = await llm._acall(user_prompt)
    except ClientResponseError as e:
        # Report the failure as chat text instead of crashing the handler.
        return f"API call failed: {e.message}"
    return response
async def submit(message, history, doc_ids, last_image):
    """Handle one chat submission: query the model, log the exchange, update state.

    Args:
        message: Multimodal textbox payload; read via ``message["text"]`` and
            ``message["files"]`` (each file is either a path string or a dict
            with a ``"path"`` key).
        history: List of ``(user_prompt, response)`` tuples; mutated in place.
        doc_ids: List of MongoDB document ids (as strings), one per answered
            message; mutated in place.
        last_image: ``(path, file_extension)`` of the most recent upload, or
            ``(None, None)`` when nothing has been uploaded yet.

    Returns:
        A 5-tuple ``(chat history, cleared textbox, doc_ids, last_image,
        image component)`` matching the outputs wired to this handler.
    """
    # Log the user message and files
    print("User Message:", message["text"])
    print("User Files:", message["files"])

    # A new upload replaces the remembered image; otherwise reuse the last one.
    if message["files"]:
        newest = message["files"][-1]
        image = newest["path"] if isinstance(newest, dict) else newest
        image_filetype = os.path.splitext(image)[1].lower()
        last_image = (image, image_filetype)
    image, image_filetype = last_image

    # Without any image there is nothing to ask the model about.
    if not image:
        return format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=None)

    human_prompt = message['text']
    img_prompt = img_to_prompt([image])
    user_prompt = combine_img_with_text(img_prompt, human_prompt)

    # Placeholder entry; it is replaced below once the model has answered.
    history.append((human_prompt, "<processing>"))

    # Call inference asynchronously
    response = await call_inference(user_prompt)
    # Keep only the text after the last "assistant\n" header, in case the
    # endpoint echoes the prompt back.
    selected_output = response.split("assistant\n")[-1].strip()

    # Store the message, image prompt, response, and image file type in MongoDB
    document = {
        'image_prompt': img_prompt,
        'user_prompt': human_prompt,
        'response': selected_output,
        'image_filetype': image_filetype,
        'likes': 0,
        'dislikes': 0,
        'like_dislike_reason': None,
    }
    result = collection.insert_one(document)
    document_id = str(result.inserted_id)

    # Log the storage in MongoDB
    print(f"Stored in MongoDB with ID: {document_id}")

    # Update the chat history and document IDs. The id must be recorded here:
    # the like/dislike and reason handlers look documents up via doc_ids[index].
    history[-1] = (human_prompt, selected_output)
    doc_ids.append(document_id)

    return format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)
def print_like_dislike(x: gr.LikeData, history, doc_ids, reason):
    """Record a like or dislike on the MongoDB document behind a chat message.

    Args:
        x: Like event data; ``x.index`` identifies the message and ``x.liked``
            says whether it was a like (truthy) or a dislike.
        history: Current chat history; nothing is recorded when it is empty.
        doc_ids: MongoDB document ids (as strings), indexed by message position.
        reason: Free-text reason stored alongside the vote.
    """
    if not history:
        return

    # The index may arrive as a bare int or as a (message, part) pair;
    # only the message position is needed.
    index = x.index[0] if isinstance(x.index, (list, tuple)) else x.index
    if index >= len(doc_ids):
        # No stored document for this message, so there is nothing to update.
        print(f"No stored document for message index {index}; vote ignored.")
        return

    document_id = doc_ids[index]
    update_field = "likes" if x.liked else "dislikes"
    collection.update_one(
        {"_id": ObjectId(document_id)},
        {"$inc": {update_field: 1}, "$set": {"like_dislike_reason": reason}},
    )
    print(f"Document ID: {document_id}, Liked: {x.liked}, Reason: {reason}")
def submit_reason_only(doc_ids, reason, selected_index, history):
    """Store a like/dislike reason on a message's MongoDB document.

    Args:
        doc_ids: MongoDB document ids (as strings), indexed by message position.
        reason: Free-text reason to store.
        selected_index: Position of the selected chat message, or ``None`` to
            target the latest message.
        history: Current chat history; used only to find the latest message.

    Returns:
        A status string for display.
    """
    if selected_index is None:
        selected_index = len(history) - 1  # Select the last message if no message is selected

    # Covers both an empty chat (index -1) and a message with no stored document.
    if not 0 <= selected_index < len(doc_ids):
        return "No stored message to attach a reason to."

    document_id = doc_ids[selected_index]
    collection.update_one(
        {"_id": ObjectId(document_id)},
        {"$set": {"like_dislike_reason": reason}},
    )
    print(f"Document ID: {document_id}, Reason submitted: {reason}")
    return "Reason submitted."
# HTML shown inside the empty chatbot; it is passed as `placeholder=` to
# gr.Chatbot in the UI definition.
# NOTE(review): the <img> src is blank in this file — supply the logo URL.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA NeXT 34B-ft-v3 LFX</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">This multimodal LLM is finetuned by LFX</p>
</div>
"""
# UI definition: a chat column (chatbot + multimodal input) beside a side
# column (last uploaded image + like/dislike reason controls).
with gr.Blocks(fill_height=True) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, height=600)
            chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
        with gr.Column(scale=1):
            image_display = gr.Image(type="filepath", interactive=False, show_label=False, height=400)
            reason_box = gr.Textbox(label="Reason for Like/Dislike (optional). Click a chat message to specify, or the latest message will be used.", visible=True)
            submit_reason_btn = gr.Button("Submit Reason", visible=True)

    # State holders passed into and out of the event handlers.
    history_state = gr.State([])
    doc_ids_state = gr.State([])
    last_image_state = gr.State((None, None))
    selected_index_state = gr.State(None)  # Initializing the state

    def select_message(evt: gr.SelectData, history, doc_ids):
        """Remember which chat message was clicked and keep the reason box visible."""
        selected_index = evt.index if isinstance(evt.index, int) else evt.index[0]
        print(f"Selected Index: {selected_index}")  # Debugging print statement
        return gr.update(visible=True), selected_index

    # Each handler is registered separately; the input/output lists match the
    # handlers' parameters and return values.
    chat_msg = chat_input.submit(
        submit,
        inputs=[chat_input, history_state, doc_ids_state, last_image_state],
        outputs=[chatbot, chat_input, doc_ids_state, last_image_state, image_display],
    )
    chatbot.like(
        print_like_dislike,
        inputs=[history_state, doc_ids_state, reason_box],
        outputs=[],
    )
    chatbot.select(
        select_message,
        inputs=[history_state, doc_ids_state],
        outputs=[reason_box, selected_index_state],  # Using the state
    )
    submit_reason_btn.click(
        submit_reason_only,
        inputs=[doc_ids_state, reason_box, selected_index_state, history_state],
        outputs=[reason_box],  # Using the state
    )

demo.launch(show_api=False, share=True, debug=True)