conzchunglfxsdu committed
Commit 18a7a37 · verified · 1 Parent(s): 5272254

Upload 2 files

Files changed (2)
  1. app.py +181 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,181 @@
+ import os
+ import base64
+ from typing import List, Tuple
+ from PIL import Image
+ from langchain_community.llms import HuggingFaceEndpoint
+ import gradio as gr
+ from pymongo import MongoClient
+ from bson import ObjectId
+ from aiohttp.client_exceptions import ClientResponseError
+
+
+ MONGOCONN = os.getenv("MONGOCONN", "mongodb://localhost:27017")
+ client = MongoClient(MONGOCONN)
+ db = client["hf-log"]  # Database name
+ collection = db["image_tagging_space"]  # Collection name
+
+ # Tokens used to delimit images and instruction turns in the assembled prompt
+ img_spec_token = "<|im_image|>"
+ img_join_token = "<|and|>"
+ sos_token = "[INST]"
+ eos_token = "[/INST]"
+
+
+ # Input shapes handled by LLaVA-NeXT image preprocessing; kept for reference,
+ # these constants are not used elsewhere in this file
+ ASPECT_RATIOS = [(1, 1), (1, 4), (4, 1)]
+ RESOLUTIONS = [(672, 672), (336, 1344), (1344, 336)]
+
+ # Resize an image to fit within the given bounds, saving a copy under /tmp
+ def resize_image(image_path: str, max_width: int = 300, max_height: int = 300) -> str:
+     img = Image.open(image_path)
+     img.thumbnail((max_width, max_height), Image.LANCZOS)
+     resized_image_path = f"/tmp/{os.path.basename(image_path)}"
+     img.save(resized_image_path)
+     return resized_image_path
+
+ # Encode an image file as a Base64 string
+ def encode_image_to_base64(image_path: str) -> str:
+     with open(image_path, "rb") as image_file:
+         return base64.b64encode(image_file.read()).decode("utf-8")
+
+ # Wrap Base64-encoded images in the image delimiter and join tokens
+ def img_to_prompt(images: List[str]) -> str:
+     encoded_images = [encode_image_to_base64(img) for img in images]
+     return img_spec_token + img_join_token.join(encoded_images) + img_spec_token
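+ # For example (sketch; Base64 payloads truncated):
+ #   img_to_prompt(["cat.png"]) -> "<|im_image|>iVBORw0KGgo...<|im_image|>"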
+
+ # Assemble the full instruction prompt: system role, then the user turn with
+ # the Base64 image block and the question, then the assistant cue
+ def combine_img_with_text(img_prompt: str, human_prompt: str, ai_role: str = "Answer questions as a professional designer") -> str:
+     system_prompt = sos_token + f"system\n{ai_role}" + eos_token
+     user_prompt = sos_token + f"user\n{img_prompt}<image>\n{human_prompt}" + eos_token
+     user_prompt += "assistant\n"
+     return system_prompt + user_prompt
+
+ # Return a fresh list of (user, assistant) tuples for the Chatbot component
+ def format_history(history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
+     return [(user_input, response) for user_input, response in history]
+
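+ # A sketch of the string combine_img_with_text assembles for one image and the
+ # question "What is in this picture?" (Base64 payload truncated):
+ #   [INST]system
+ #   Answer questions as a professional designer[/INST][INST]user
+ #   <|im_image|>iVBORw0K...<|im_image|><image>
+ #   What is in this picture?[/INST]assistant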
+ async def call_inference(user_prompt):
+     endpoint_url = "https://yzzwmsj8y9ji99i8.us-east-1.aws.endpoints.huggingface.cloud"
+     llm = HuggingFaceEndpoint(
+         endpoint_url=endpoint_url,
+         max_new_tokens=2000,
+         temperature=0.1,
+         do_sample=True,
+         use_cache=True,
+         timeout=300,
+     )
+     try:
+         response = await llm._acall(user_prompt)
+     except ClientResponseError as e:
+         return f"API call failed: {e.message}"
+     return response
+
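+ # Note: _acall is LangChain's internal async hook; the public equivalent
+ # would be `await llm.ainvoke(user_prompt)`.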
+ async def submit(message, history, doc_ids, last_image):
+     # Log the user message and files
+     print("User Message:", message["text"])
+     print("User Files:", message["files"])
+
+     image = None
+     image_filetype = None
+     if message["files"]:
+         # Files may arrive as dicts with a "path" key or as plain paths
+         image = message["files"][-1]["path"] if isinstance(message["files"][-1], dict) else message["files"][-1]
+         image_filetype = os.path.splitext(image)[1].lower()
+         image = resize_image(image)
+         last_image = (image, image_filetype)
+     else:
+         # No new upload: reuse the most recently used image
+         image, image_filetype = last_image
+
+     if not image:
+         yield format_history(history), gr.update(value=None, interactive=True), doc_ids, last_image, gr.Image(value=None)
+         return
+
+     human_prompt = message["text"]
+     img_prompt = img_to_prompt([image])
+     user_prompt = combine_img_with_text(img_prompt, human_prompt)
+
+     # Show the user input with a placeholder right away, before inference
+     # returns; gr.update clears the textbox without changing its type
+     history.append((human_prompt, "<processing>"))
+     yield format_history(history), gr.update(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)
+
+     # Call inference asynchronously
+     response = await call_inference(user_prompt)
+     selected_output = response.split("assistant\n")[-1].strip()
+
+     # Store the message, image prompt, response, and image file type in MongoDB
+     document = {
+         'image_prompt': img_prompt,
+         'user_prompt': human_prompt,
+         'response': selected_output,
+         'image_filetype': image_filetype,
+         'likes': 0,
+         'dislikes': 0,
+         'like_dislike_reason': None
+     }
+     result = collection.insert_one(document)
+     document_id = str(result.inserted_id)
+
+     # Log the storage in MongoDB
+     print(f"Stored in MongoDB with ID: {document_id}")
+
+     # Replace the placeholder with the real response and record the document ID
+     history[-1] = (human_prompt, selected_output)
+     doc_ids.append(document_id)
+
+     yield format_history(history), gr.update(value=None, interactive=True), doc_ids, last_image, gr.Image(value=image, show_label=False)
+
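+ # Gradio treats generator handlers as streaming updates: each yield above is
+ # rendered as it arrives, so the "<processing>" placeholder is visible while
+ # the endpoint call is in flight.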
+ def print_like_dislike(x: gr.LikeData, history, doc_ids, reason):
+     if not history:
+         return
+     # LikeData.index may be a [row, col] pair or a bare int depending on version
+     index = x.index[0] if isinstance(x.index, list) else x.index
+     document_id = doc_ids[index]
+     update_field = "likes" if x.liked else "dislikes"
+     collection.update_one({"_id": ObjectId(document_id)}, {"$inc": {update_field: 1}, "$set": {"like_dislike_reason": reason}})
+     print(f"Document ID: {document_id}, Liked: {x.liked}, Reason: {reason}")
+
+ def submit_reason_only(doc_ids, reason, selected_index, history):
+     if selected_index is None:
+         selected_index = len(history) - 1  # Default to the last message if none is selected
+     document_id = doc_ids[selected_index]
+     collection.update_one(
+         {"_id": ObjectId(document_id)},
+         {"$set": {"like_dislike_reason": reason}}
+     )
+     print(f"Document ID: {document_id}, Reason submitted: {reason}")
+     return "Reason submitted."
+
+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+     <img src="https://lfxdigital.com/wp-content/uploads/2021/02/LFX_Logo_Final-01.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
+     <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-NeXT-Mistral-7B-LFX</h1>
+     <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">This multimodal LLM is hosted by LFX</p>
+ </div>
+ """
+
+ with gr.Blocks(fill_height=True) as demo:
+     with gr.Row():
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1, height=600)
+             chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
+         with gr.Column(scale=1):
+             image_display = gr.Image(type="filepath", interactive=False, show_label=False, height=400)
+             reason_box = gr.Textbox(label="Reason for Like/Dislike (optional). Click a chat message to specify; otherwise the latest message is used.", visible=True)
+             submit_reason_btn = gr.Button("Submit Reason", visible=True)
+
+     history_state = gr.State([])
+     doc_ids_state = gr.State([])
+     last_image_state = gr.State((None, None))
+     selected_index_state = gr.State(None)  # Index of the message a reason applies to
+
+     def select_message(evt: gr.SelectData, history, doc_ids):
+         # SelectData.index may be an int or a [row, col] pair
+         selected_index = evt.index if isinstance(evt.index, int) else evt.index[0]
+         print(f"Selected Index: {selected_index}")
+         return gr.update(visible=True), selected_index
+
+     chat_input.submit(submit, inputs=[chat_input, history_state, doc_ids_state, last_image_state], outputs=[chatbot, chat_input, doc_ids_state, last_image_state, image_display])
+     chatbot.like(print_like_dislike, inputs=[history_state, doc_ids_state, reason_box], outputs=[])
+     chatbot.select(select_message, inputs=[history_state, doc_ids_state], outputs=[reason_box, selected_index_state])
+     submit_reason_btn.click(submit_reason_only, inputs=[doc_ids_state, reason_box, selected_index_state, history_state], outputs=[reason_box])
+
+ demo.queue(api_open=False)
+ demo.launch(show_api=False, share=True, debug=True)
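
For reviewing the feedback this app logs, a minimal sketch using the same database and collection names as app.py (the connection string is an assumption; substitute your own MONGOCONN value):

from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")  # assumption: your MONGOCONN
coll = client["hf-log"]["image_tagging_space"]
# Print prompts whose responses were disliked, with the submitted reason
for doc in coll.find({"dislikes": {"$gt": 0}}, {"user_prompt": 1, "like_dislike_reason": 1}):
    print(doc)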
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ langchain-community
+ spaces
+ pillow
+ accelerate
+ transformers
+ gradio
+ pymongo