import os
import base64
from typing import List, Tuple

import gradio as gr
from aiohttp.client_exceptions import ClientResponseError
from bson import ObjectId
from langchain_community.llms import HuggingFaceEndpoint
from PIL import Image
from pymongo import MongoClient
MONGOCONN = os.getenv("MONGOCONN", "mongodb://localhost:27017")
client = MongoClient(MONGOCONN)
db = client["hf-log"] # Database name
collection = db["image_tagging_space"] # Collection name
img_spec_token = "<|im_image|>"
img_join_token = "<|and|>"
sos_token = "<|im_start|>"
eos_token = "<|im_end|>"
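# These special tokens delimit the Base64 image payload and the ChatML-style turns.
# NOTE: this layout is assumed to match the chat template the fine-tuned endpoint was
# trained on; adjust the tokens if the endpoint expects a different format.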
# Function to resize image
def resize_image(image_path: str, max_width: int = 300, max_height: int = 300) -> str:
img = Image.open(image_path)
img.thumbnail((max_width, max_height), Image.LANCZOS)
resized_image_path = f"/tmp/{os.path.basename(image_path)}"
img.save(resized_image_path)
return resized_image_path
# Function to encode images to Base64
def encode_image_to_base64(image_path: str) -> str:
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
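# Illustrative usage (assumes a local file such as "/tmp/sample.jpg" exists):
#   small_path = resize_image("/tmp/sample.jpg")      # longest side capped at 300 px
#   b64_string = encode_image_to_base64(small_path)   # Base64 text, safe to embed in a prompt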
# Wrap Base64-encoded images in the image special tokens
def img_to_prompt(images: List[str]) -> str:
encoded_images = [encode_image_to_base64(img) for img in images]
return img_spec_token + img_join_token.join(encoded_images) + img_spec_token
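# Example shape of the result for a single image (Base64 payload shortened here):
#   "<|im_image|>iVBORw0KGgoAAAANSUhEUg...<|im_image|>"
# Multiple images are joined with <|and|> between the encoded payloads.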
# Combine the image prompt and user text into the ChatML-style chat template
def combine_img_with_text(img_prompt: str, human_prompt: str, ai_role: str = "Answer questions as a professional designer") -> str:
system_prompt = sos_token + f"system\n{ai_role}" + eos_token
user_prompt = sos_token + f"user\n{img_prompt}<image>\n{human_prompt}" + eos_token
    # Open the assistant turn; the reply is later split on this marker (see submit())
    user_prompt += "assistant\n"
return system_prompt + user_prompt
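# Example of the assembled prompt (image payload elided), assuming the defaults above:
#   <|im_start|>system
#   Answer questions as a professional designer<|im_end|><|im_start|>user
#   <|im_image|>...<|im_image|><image>
#   What colours dominate this design?<|im_end|>assistant
# The trailing "assistant\n" opens the turn the model is expected to complete.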
def format_history(history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
return [(user_input, response) for user_input, response in history]
async def call_inference(user_prompt):
endpoint_url = "https://rn65ru6q35e05iu0.us-east-1.aws.endpoints.huggingface.cloud"
llm = HuggingFaceEndpoint(endpoint_url=endpoint_url,
max_new_tokens=2000,
temperature=0.1,
do_sample=True,
use_cache=True,
timeout=300)
    try:
        # Call the endpoint's async completion directly; aiohttp errors are surfaced below
        response = await llm._acall(user_prompt)
except ClientResponseError as e:
return f"API call failed: {e.message}"
return response
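# Standalone usage sketch (illustrative only; the Gradio handler below awaits this directly):
#   import asyncio
#   prompt = combine_img_with_text(img_to_prompt(["/tmp/sample.jpg"]), "Describe this design")
#   answer = asyncio.run(call_inference(prompt))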
async def submit(message, history, doc_ids, last_image):
# Log the user message and files
print("User Message:", message["text"])
print("User Files:", message["files"])
image = None
image_filetype = None
if message["files"]:
image = message["files"][-1]["path"] if isinstance(message["files"][-1], dict) else message["files"][-1]
image_filetype = os.path.splitext(image)[1].lower()
#image = resize_image(image)
last_image = (image, image_filetype)
else:
image, image_filetype = last_image
    if not image:
        # Nothing to describe yet: clear the input box and leave the rest of the state unchanged
        yield format_history(history), gr.MultimodalTextbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=None)
        return
human_prompt = message['text']
img_prompt = img_to_prompt([image])
user_prompt = combine_img_with_text(img_prompt, human_prompt)
    # Show the user turn with a placeholder immediately; the handler is an async
    # generator, so this first yield updates the UI before inference finishes
    history.append((human_prompt, "<processing>"))
    yield format_history(history), gr.MultimodalTextbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)
# Call inference asynchronously
response = await call_inference(user_prompt)
selected_output = response.split("assistant\n")[-1].strip()
# Store the message, image prompt, response, and image file type in MongoDB
document = {
'image_prompt': img_prompt,
'user_prompt': human_prompt,
'response': selected_output,
'image_filetype': image_filetype,
'likes': 0,
'dislikes': 0,
'like_dislike_reason': None
}
result = collection.insert_one(document)
document_id = str(result.inserted_id)
# Log the storage in MongoDB
print(f"Stored in MongoDB with ID: {document_id}")
# Update the chat history and document IDs
history[-1] = (human_prompt, selected_output)
doc_ids.append(document_id)
    yield format_history(history), gr.MultimodalTextbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)
def print_like_dislike(x: gr.LikeData, history, doc_ids, reason):
if not history:
return
index = x.index[0] if isinstance(x.index, list) else x.index
document_id = doc_ids[index]
update_field = "likes" if x.liked else "dislikes"
collection.update_one({"_id": ObjectId(document_id)}, {"$inc": {update_field: 1}, "$set": {"like_dislike_reason": reason}})
print(f"Document ID: {document_id}, Liked: {x.liked}, Reason: {reason}")
def submit_reason_only(doc_ids, reason, selected_index, history):
if selected_index is None:
selected_index = len(history) - 1 # Select the last message if no message is selected
document_id = doc_ids[selected_index]
collection.update_one(
{"_id": ObjectId(document_id)},
{"$set": {"like_dislike_reason": reason}}
)
print(f"Document ID: {document_id}, Reason submitted: {reason}")
return f"Reason submitted."
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://lfxdigital.com/wp-content/uploads/2021/02/LFX_Logo_Final-01.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA NeXT 34B-ft-v3 LFX</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">This multimodal LLM is finetuned by LFX</p>
</div>
"""
with gr.Blocks(fill_height=True) as demo:
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, height=600)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
with gr.Column(scale=1):
image_display = gr.Image(type="filepath", interactive=False, show_label=False, height=400)
reason_box = gr.Textbox(label="Reason for Like/Dislike (optional). Click a chat message to specify, or the latest message will be used.", visible=True)
submit_reason_btn = gr.Button("Submit Reason", visible=True)
history_state = gr.State([])
doc_ids_state = gr.State([])
last_image_state = gr.State((None, None))
selected_index_state = gr.State(None) # Initializing the state
def select_message(evt: gr.SelectData, history, doc_ids):
selected_index = evt.index if isinstance(evt.index, int) else evt.index[0]
print(f"Selected Index: {selected_index}") # Debugging print statement
return gr.update(visible=True), selected_index
chat_msg = chat_input.submit(submit, inputs=[chat_input, history_state, doc_ids_state, last_image_state], outputs=[chatbot, chat_input, doc_ids_state, last_image_state, image_display])
chatbot.like(print_like_dislike, inputs=[history_state, doc_ids_state, reason_box], outputs=[])
chatbot.select(select_message, inputs=[history_state, doc_ids_state], outputs=[reason_box, selected_index_state]) # Using the state
submit_reason_btn.click(submit_reason_only, inputs=[doc_ids_state, reason_box, selected_index_state, history_state], outputs=[reason_box]) # Using the state
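    # Note: the gr.LikeData / gr.SelectData arguments are injected automatically from the
    # event based on the handlers' type annotations, so they are not listed in `inputs` above.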
demo.queue(api_open=False)
demo.launch(show_api=False, share=True, debug=True)