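"""Gradio chat app for a LLaVA-NeXT multimodal model finetuned by LFX.

Users upload an image and ask questions about it; prompts are sent to a
Hugging Face Inference Endpoint, and every exchange is logged to MongoDB
together with like/dislike feedback and an optional free-text reason.
"""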
import os
import base64
from typing import List, Tuple

import gradio as gr
from aiohttp.client_exceptions import ClientResponseError
from bson import ObjectId
from langchain_community.llms import HuggingFaceEndpoint
from PIL import Image
from pymongo import MongoClient



MONGOCONN = os.getenv("MONGOCONN", "mongodb://localhost:27017")
client = MongoClient(MONGOCONN)
db = client["hf-log"]  # Database name
collection = db["image_tagging_space"]  # Collection name

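# Special tokens from the finetuned model's prompt template: Base64 image
# payloads are wrapped in img_spec_token and joined with img_join_token;
# sos/eos delimit ChatML-style conversation turns.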
img_spec_token = "<|im_image|>"
img_join_token = "<|and|>"
sos_token = "<|im_start|>"
eos_token = "<|im_end|>"


# Downscale an image into /tmp (currently unused; see the commented-out call in submit)
def resize_image(image_path: str, max_width: int = 300, max_height: int = 300) -> str:
    img = Image.open(image_path)
    img.thumbnail((max_width, max_height), Image.LANCZOS)
    resized_image_path = f"/tmp/{os.path.basename(image_path)}"
    img.save(resized_image_path)
    return resized_image_path

# Function to encode images to Base64
def encode_image_to_base64(image_path: str) -> str:
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

# Wrap Base64-encoded images in the model's special image tokens
def img_to_prompt(images: List[str]) -> str:
    encoded_images = [encode_image_to_base64(img) for img in images]
    return img_spec_token + img_join_token.join(encoded_images) + img_spec_token

# Combine image and text prompts into a ChatML-style conversation
def combine_img_with_text(img_prompt: str, human_prompt: str, ai_role: str = "Answer questions as a professional designer") -> str:
    system_prompt = sos_token + f"system\n{ai_role}" + eos_token
    user_prompt = sos_token + f"user\n{img_prompt}<image>\n{human_prompt}" + eos_token
    # Open the assistant turn so the model generates the reply (ChatML convention)
    user_prompt += sos_token + "assistant\n"
    return system_prompt + user_prompt
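
# Example of the assembled prompt (Base64 payload abbreviated):
#   <|im_start|>system\nAnswer questions as a professional designer<|im_end|>
#   <|im_start|>user\n<|im_image|>iVBORw0...<|im_image|><image>\n<question><|im_end|>
#   <|im_start|>assistant\n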

def format_history(history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
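    # Return a shallow copy of the (user, assistant) pairs for the Chatbot component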
    return [(user_input, response) for user_input, response in history]

async def call_inference(user_prompt):
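    # Dedicated Hugging Face Inference Endpoint hosting the finetuned model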
    endpoint_url = "https://rn65ru6q35e05iu0.us-east-1.aws.endpoints.huggingface.cloud"
    llm = HuggingFaceEndpoint(endpoint_url=endpoint_url, 
                              max_new_tokens=2000, 
                              temperature=0.1, 
                              do_sample=True,
                              use_cache=True,
                              timeout=300)
    try:
        # ainvoke is the public async entry point on LangChain LLMs
        response = await llm.ainvoke(user_prompt)
    except ClientResponseError as e:
        return f"API call failed: {e.message}"
    return response

async def submit(message, history, doc_ids, last_image):
    # Log the user message and files
    print("User Message:", message["text"])
    print("User Files:", message["files"])
    
    image = None
    image_filetype = None
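    # Prefer a newly uploaded file; otherwise fall back to the last image seen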
    if message["files"]:
        image = message["files"][-1]["path"] if isinstance(message["files"][-1], dict) else message["files"][-1]
        image_filetype = os.path.splitext(image)[1].lower()
        #image = resize_image(image)
        last_image = (image, image_filetype)
    else:
        image, image_filetype = last_image
    
    if not image:
        # Nothing to describe yet: reset the input and wait for an upload
        yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=None)
        return

    human_prompt = message['text']
    img_prompt = img_to_prompt([image])
    user_prompt = combine_img_with_text(img_prompt, human_prompt)

    # Yield the user's message with a placeholder so the UI updates immediately
    history.append((human_prompt, "<processing>"))
    yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)
    
    # Call inference asynchronously
    response = await call_inference(user_prompt)
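    # Keep only the text generated after the final assistant marker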
    selected_output = response.split("assistant\n")[-1].strip()

    # Store the message, image prompt, response, and image file type in MongoDB
    document = {
        'image_prompt': img_prompt,
        'user_prompt': human_prompt,
        'response': selected_output,
        'image_filetype': image_filetype,
        'likes': 0,
        'dislikes': 0,
        'like_dislike_reason': None
    }
    result = collection.insert_one(document)
    document_id = str(result.inserted_id)

    # Log the storage in MongoDB
    print(f"Stored in MongoDB with ID: {document_id}")

    # Update the chat history and document IDs
    history[-1] = (human_prompt, selected_output)
    doc_ids.append(document_id)

    yield format_history(history), gr.Textbox(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)

def print_like_dislike(x: gr.LikeData, history, doc_ids, reason):
    if not history:
        return
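    # Chatbot like events report index as [row, col]; the row maps 1:1 to doc_ids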
    index = x.index[0] if isinstance(x.index, list) else x.index
    document_id = doc_ids[index]
    update_field = "likes" if x.liked else "dislikes"
    collection.update_one({"_id": ObjectId(document_id)}, {"$inc": {update_field: 1}, "$set": {"like_dislike_reason": reason}})
    print(f"Document ID: {document_id}, Liked: {x.liked}, Reason: {reason}")

def submit_reason_only(doc_ids, reason, selected_index, history):
    if selected_index is None:
        selected_index = len(history) - 1  # Select the last message if no message is selected
    document_id = doc_ids[selected_index]
    collection.update_one(
        {"_id": ObjectId(document_id)},
        {"$set": {"like_dislike_reason": reason}}
    )
    print(f"Document ID: {document_id}, Reason submitted: {reason}")
    return f"Reason submitted."

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
    <img src="https://lfxdigital.com/wp-content/uploads/2021/02/LFX_Logo_Final-01.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
    <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA NeXT 34B-ft-v3 LFX</h1>
    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">This multimodal LLM is finetuned by LFX</p>
</div>
"""

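# UI layout: chatbot and multimodal input on the left; image preview and
# feedback controls on the right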
with gr.Blocks(fill_height=True) as demo:
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, height=600)
            chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
        with gr.Column(scale=1):
            image_display = gr.Image(type="filepath", interactive=False, show_label=False, height=400)
            reason_box = gr.Textbox(label="Reason for Like/Dislike (optional). Click a chat message to specify, or the latest message will be used.", visible=True)
            submit_reason_btn = gr.Button("Submit Reason", visible=True)

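    # Per-session state: chat history, MongoDB document ids, last uploaded
    # image, and the index of the message selected for feedback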
    history_state = gr.State([])
    doc_ids_state = gr.State([])
    last_image_state = gr.State((None, None))
    selected_index_state = gr.State(None)  # Index of the chat row the feedback reason applies to

    def select_message(evt: gr.SelectData, history, doc_ids):
        selected_index = evt.index if isinstance(evt.index, int) else evt.index[0]
        print(f"Selected Index: {selected_index}")  # Debugging print statement
        return gr.update(visible=True), selected_index

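    # Wire UI events to their handlers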
    chat_input.submit(submit, inputs=[chat_input, history_state, doc_ids_state, last_image_state], outputs=[chatbot, chat_input, doc_ids_state, last_image_state, image_display])
    chatbot.like(print_like_dislike, inputs=[history_state, doc_ids_state, reason_box], outputs=[])
    chatbot.select(select_message, inputs=[history_state, doc_ids_state], outputs=[reason_box, selected_index_state])
    submit_reason_btn.click(submit_reason_only, inputs=[doc_ids_state, reason_box, selected_index_state, history_state], outputs=[reason_box])

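# Queue requests (API docs disabled) and launch with a public share link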
demo.queue(api_open=False)
demo.launch(show_api=False, share=True, debug=True)