import os
import base64
import json
import pymongo
from typing import List, Optional, Dict, Any, Tuple
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from langchain_community.llms import HuggingFaceEndpoint
import gradio as gr
from pymongo import MongoClient
from bson import ObjectId
import asyncio
from PIL import Image, ImageOps
from aiohttp.client_exceptions import ClientResponseError

# MongoDB connection for logging prompts, responses, and feedback
MONGOCONN = os.getenv("MONGOCONN", "mongodb://localhost:27017")
client = MongoClient(MONGOCONN)
db = client["hf-log"]  # Database name
collection = db["image_tagging_space"]  # Collection name

# Special tokens used to assemble the multimodal prompt
img_spec_token = "<|im_image|>"
img_join_token = "<|and|>"
sos_token = "<|im_start|>"
eos_token = "<|im_end|>"


# Resize an image (preserving aspect ratio) and save it under /tmp
def resize_image(image_path: str, max_width: int = 300, max_height: int = 300) -> str:
    img = Image.open(image_path)
    img.thumbnail((max_width, max_height), Image.LANCZOS)
    resized_image_path = f"/tmp/{os.path.basename(image_path)}"
    img.save(resized_image_path)
    return resized_image_path


# Encode an image file to a Base64 string
def encode_image_to_base64(image_path: str) -> str:
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


# Generate the image part of the prompt, wrapped in the special image tokens
def img_to_prompt(images: List[str]) -> str:
    encoded_images = [encode_image_to_base64(img) for img in images]
    return img_spec_token + img_join_token.join(encoded_images) + img_spec_token


# Combine the image and text prompts into a chat-formatted prompt
def combine_img_with_text(img_prompt: str, human_prompt: str,
                          ai_role: str = "Answer questions as a professional designer") -> str:
    system_prompt = sos_token + f"system\n{ai_role}" + eos_token
    user_prompt = sos_token + f"user\n{img_prompt}\n{human_prompt}" + eos_token
    user_prompt += "assistant\n"
    return system_prompt + user_prompt


def format_history(history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    return [(user_input, response) for user_input, response in history]


async def call_inference(user_prompt):
    endpoint_url = "https://rn65ru6q35e05iu0.us-east-1.aws.endpoints.huggingface.cloud"
    llm = HuggingFaceEndpoint(endpoint_url=endpoint_url, max_new_tokens=2000, temperature=0.1,
                              do_sample=True, use_cache=True, timeout=300)
    try:
        response = await llm._acall(user_prompt)
    except ClientResponseError as e:
        return f"API call failed: {e.message}"
    return response


async def submit(message, history, doc_ids, last_image):
    # Log the user message and files
    print("User Message:", message["text"])
    print("User Files:", message["files"])

    image = None
    image_filetype = None
    if message["files"]:
        image = message["files"][-1]["path"] if isinstance(message["files"][-1], dict) else message["files"][-1]
        image_filetype = os.path.splitext(image)[1].lower()
        # image = resize_image(image)
        last_image = (image, image_filetype)
    else:
        image, image_filetype = last_image

    if not image:
        yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=None)
        return

    human_prompt = message["text"]
    img_prompt = img_to_prompt([image])
    user_prompt = combine_img_with_text(img_prompt, human_prompt)

    # Show the user input immediately while inference runs
    history.append((human_prompt, ""))
    yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)

    # Call inference asynchronously
    response = await call_inference(user_prompt)
    selected_output = response.split("assistant\n")[-1].strip()

    # Store the message, image prompt, response, and image file type in MongoDB
    document = {
        "image_prompt": img_prompt,
        "user_prompt": human_prompt,
        "response": selected_output,
        "image_filetype": image_filetype,
        "likes": 0,
        "dislikes": 0,
        "like_dislike_reason": None,
    }
    result = collection.insert_one(document)
    document_id = str(result.inserted_id)
    # Log the storage in MongoDB
    print(f"Stored in MongoDB with ID: {document_id}")

    # Update the chat history and document IDs
    history[-1] = (human_prompt, selected_output)
    doc_ids.append(document_id)
    yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)


def print_like_dislike(x: gr.LikeData, history, doc_ids, reason):
    if not history:
        return
    index = x.index[0] if isinstance(x.index, list) else x.index
    document_id = doc_ids[index]
    update_field = "likes" if x.liked else "dislikes"
    collection.update_one({"_id": ObjectId(document_id)},
                          {"$inc": {update_field: 1}, "$set": {"like_dislike_reason": reason}})
    print(f"Document ID: {document_id}, Liked: {x.liked}, Reason: {reason}")


def submit_reason_only(doc_ids, reason, selected_index, history):
    if selected_index is None:
        selected_index = len(history) - 1  # Fall back to the last message if none is selected
    document_id = doc_ids[selected_index]
    collection.update_one(
        {"_id": ObjectId(document_id)},
        {"$set": {"like_dislike_reason": reason}}
    )
    print(f"Document ID: {document_id}, Reason submitted: {reason}")
    return "Reason submitted."


PLACEHOLDER = """

LLaVA NeXT 34B-ft-v3 LFX

This multimodal LLM is fine-tuned by LFX

""" with gr.Blocks(fill_height=True) as demo: with gr.Row(): with gr.Column(scale=3): chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, height=600) chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False) with gr.Column(scale=1): image_display = gr.Image(type="filepath", interactive=False, show_label=False, height=400) reason_box = gr.Textbox(label="Reason for Like/Dislike (optional). Click a chat message to specify, or the latest message will be used.", visible=True) submit_reason_btn = gr.Button("Submit Reason", visible=True) history_state = gr.State([]) doc_ids_state = gr.State([]) last_image_state = gr.State((None, None)) selected_index_state = gr.State(None) # Initializing the state def select_message(evt: gr.SelectData, history, doc_ids): selected_index = evt.index if isinstance(evt.index, int) else evt.index[0] print(f"Selected Index: {selected_index}") # Debugging print statement return gr.update(visible=True), selected_index chat_msg = chat_input.submit(submit, inputs=[chat_input, history_state, doc_ids_state, last_image_state], outputs=[chatbot, chat_input, doc_ids_state, last_image_state, image_display]) chatbot.like(print_like_dislike, inputs=[history_state, doc_ids_state, reason_box], outputs=[]) chatbot.select(select_message, inputs=[history_state, doc_ids_state], outputs=[reason_box, selected_index_state]) # Using the state submit_reason_btn.click(submit_reason_only, inputs=[doc_ids_state, reason_box, selected_index_state, history_state], outputs=[reason_box]) # Using the state demo.queue(api_open=False) demo.launch(show_api=False, share=True, debug=True)