import os
import base64
from typing import List, Tuple

import gradio as gr
from aiohttp.client_exceptions import ClientResponseError
from bson import ObjectId
from langchain_community.llms import HuggingFaceEndpoint
from PIL import Image
from pymongo import MongoClient
MONGOCONN = os.getenv("MONGOCONN", "mongodb://localhost:27017")
client = MongoClient(MONGOCONN)
db = client["hf-log"] # Database name
collection = db["image_tagging_space"] # Collection name
img_spec_token = "<|im_image|>"
img_join_token = "<|and|>"
sos_token = "<|im_start|>"
eos_token = "<|im_end|>"
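# These are custom special tokens; the assumption is that the fine-tuned endpoint's
# tokenizer was trained with them: ChatML-style <|im_start|>/<|im_end|> turn markers
# plus <|im_image|>/<|and|> delimiters around Base64-encoded image payloads.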
# Function to resize an image in place (currently unused: the call in submit() is commented out)
def resize_image(image_path: str, max_width: int = 300, max_height: int = 300) -> str:
    img = Image.open(image_path)
    img.thumbnail((max_width, max_height), Image.LANCZOS)
    resized_image_path = f"/tmp/{os.path.basename(image_path)}"
    img.save(resized_image_path)
    return resized_image_path
# Function to encode images to Base64
def encode_image_to_base64(image_path: str) -> str:
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
# Generate an image prompt by Base64-encoding the images and wrapping them in the special tokens
def img_to_prompt(images: List[str]) -> str:
    encoded_images = [encode_image_to_base64(img) for img in images]
    return img_spec_token + img_join_token.join(encoded_images) + img_spec_token
# Combine the image prompt and the user's text into a ChatML-style prompt using the special tokens
def combine_img_with_text(img_prompt: str, human_prompt: str, ai_role: str = "Answer questions as a professional designer") -> str:
    system_prompt = sos_token + f"system\n{ai_role}" + eos_token
    user_prompt = sos_token + f"user\n{img_prompt}<image>\n{human_prompt}" + eos_token
    user_prompt += "assistant\n"
    return system_prompt + user_prompt
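# For a single image and question, combine_img_with_text() produces a string of roughly
# this shape (illustrative question, Base64 payload abbreviated; the exact template is
# whatever the fine-tuned endpoint expects):
#   <|im_start|>system
#   Answer questions as a professional designer<|im_end|><|im_start|>user
#   <|im_image|>iVBORw0K...<|im_image|><image>
#   What colours dominate this layout?<|im_end|>assistant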
# Pass the chat history through as (user, assistant) tuples, the format gr.Chatbot expects
def format_history(history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    return [(user_input, response) for user_input, response in history]
async def call_inference(user_prompt):
    endpoint_url = "https://rn65ru6q35e05iu0.us-east-1.aws.endpoints.huggingface.cloud"
    llm = HuggingFaceEndpoint(
        endpoint_url=endpoint_url,
        max_new_tokens=2000,
        temperature=0.1,
        do_sample=True,
        use_cache=True,
        timeout=300,
    )
    try:
        response = await llm._acall(user_prompt)
    except ClientResponseError as e:
        return f"API call failed: {e.message}"
    return response
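# Hypothetical local smoke test for the endpoint call (not executed by the app):
#   import asyncio
#   prompt = combine_img_with_text(img_to_prompt(["/path/to/example.png"]), "Describe this image")
#   print(asyncio.run(call_inference(prompt)))
# Note: _acall is a private LangChain method; the public llm.ainvoke() is the more
# conventional entry point if this ever breaks on a langchain upgrade.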
async def submit(message, history, doc_ids, last_image):
    # Log the user message and files
    print("User Message:", message["text"])
    print("User Files:", message["files"])
    image = None
    image_filetype = None
    if message["files"]:
        image = message["files"][-1]["path"] if isinstance(message["files"][-1], dict) else message["files"][-1]
        image_filetype = os.path.splitext(image)[1].lower()
        #image = resize_image(image)
        last_image = (image, image_filetype)
    else:
        # No new upload: reuse the most recently uploaded image
        image, image_filetype = last_image
    if not image:
        yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=None)
        return
    human_prompt = message['text']
    img_prompt = img_to_prompt([image])
    user_prompt = combine_img_with_text(img_prompt, human_prompt)
    # Show the user input immediately with a placeholder while inference runs
    history.append((human_prompt, "<processing>"))
    yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)
    # Call inference asynchronously
    response = await call_inference(user_prompt)
    selected_output = response.split("assistant\n")[-1].strip()
    # Store the message, image prompt, response, and image file type in MongoDB
    document = {
        'image_prompt': img_prompt,
        'user_prompt': human_prompt,
        'response': selected_output,
        'image_filetype': image_filetype,
        'likes': 0,
        'dislikes': 0,
        'like_dislike_reason': None
    }
    result = collection.insert_one(document)
    document_id = str(result.inserted_id)
    # Log the storage in MongoDB
    print(f"Stored in MongoDB with ID: {document_id}")
    # Update the chat history and document IDs
    history[-1] = (human_prompt, selected_output)
    doc_ids.append(document_id)
    yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)
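# Each chat turn is persisted as one MongoDB document; the feedback handlers below
# mutate that same document, so rated answers can later be pulled with a query along
# the lines of (sketch):
#   collection.find({"$or": [{"likes": {"$gt": 0}}, {"dislikes": {"$gt": 0}}]})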
def print_like_dislike(x: gr.LikeData, history, doc_ids, reason):
    if not history:
        return
    index = x.index[0] if isinstance(x.index, list) else x.index
    document_id = doc_ids[index]
    update_field = "likes" if x.liked else "dislikes"
    collection.update_one({"_id": ObjectId(document_id)}, {"$inc": {update_field: 1}, "$set": {"like_dislike_reason": reason}})
    print(f"Document ID: {document_id}, Liked: {x.liked}, Reason: {reason}")
def submit_reason_only(doc_ids, reason, selected_index, history):
    if selected_index is None:
        selected_index = len(history) - 1  # Select the last message if no message is selected
    document_id = doc_ids[selected_index]
    collection.update_one(
        {"_id": ObjectId(document_id)},
        {"$set": {"like_dislike_reason": reason}}
    )
    print(f"Document ID: {document_id}, Reason submitted: {reason}")
    return "Reason submitted."
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<img src="https://lfxdigital.com/wp-content/uploads/2021/02/LFX_Logo_Final-01.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA NeXT 34B-ft-v3 LFX</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">This multimodal LLM is finetuned by LFX</p>
</div>
"""
with gr.Blocks(fill_height=True) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, height=600)
            chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
        with gr.Column(scale=1):
            image_display = gr.Image(type="filepath", interactive=False, show_label=False, height=400)
            reason_box = gr.Textbox(label="Reason for Like/Dislike (optional). Click a chat message to specify, or the latest message will be used.", visible=True)
            submit_reason_btn = gr.Button("Submit Reason", visible=True)

    history_state = gr.State([])
    doc_ids_state = gr.State([])
    last_image_state = gr.State((None, None))
    selected_index_state = gr.State(None)  # Initializing the state

    def select_message(evt: gr.SelectData, history, doc_ids):
        selected_index = evt.index if isinstance(evt.index, int) else evt.index[0]
        print(f"Selected Index: {selected_index}")  # Debugging print statement
        return gr.update(visible=True), selected_index

    chat_msg = chat_input.submit(submit, inputs=[chat_input, history_state, doc_ids_state, last_image_state], outputs=[chatbot, chat_input, doc_ids_state, last_image_state, image_display])
    chatbot.like(print_like_dislike, inputs=[history_state, doc_ids_state, reason_box], outputs=[])
    chatbot.select(select_message, inputs=[history_state, doc_ids_state], outputs=[reason_box, selected_index_state])  # Using the state
    submit_reason_btn.click(submit_reason_only, inputs=[doc_ids_state, reason_box, selected_index_state, history_state], outputs=[reason_box])  # Using the state
demo.queue(api_open=False)
demo.launch(show_api=False, share=True, debug=True)