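"""Gradio web demo that chats about images with a ShareGPT4V (LLaVA-style) GGUF
model served locally through llama-cpp-python. The UI structure follows the
LLaVA gradio_web_server demo."""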
import argparse
import base64
import datetime
import json
import os
import re
import time
import urllib.request
from io import BytesIO

import gradio as gr
import requests
from PIL import Image, ImageDraw
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

from conversation import (default_conversation, conv_templates,
                          SeparatorStyle)
from constants import LOGDIR
from utils import (build_logger, server_error_msg,
                   violates_moderation, moderation_msg)
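# Download the CLIP vision projector (mmproj) that the LLaVA-1.5 chat handler
# needs in order to embed images for the ShareGPT4V model.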
urllib.request.urlretrieve(
    "https://huggingface.co/Galunid/ShareGPT4V-gguf/resolve/main/mmproj-model-f16.gguf?download=true",
    "./mmproj-model-f16.gguf")

chat_handler = Llava15ChatHandler(clip_model_path="./mmproj-model-f16.gguf")
llm = Llama.from_pretrained(
    repo_id="Galunid/ShareGPT4V-gguf",
    filename="ShareGPT4V-f16.gguf",
    chat_handler=chat_handler,
    verbose=False,
    n_ctx=2048,  # n_ctx should be increased to accommodate the image embedding
    logits_all=True,  # needed to make llava work
)
logger = build_logger("gradio_web_server", "gradio_web_server.log")

headers = {"User-Agent": "Wafer Defect Detection with LLM Classification and Analysis Client"}

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
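# Model priority sort keys kept from the upstream LLaVA demo; not referenced
# elsewhere in this file.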
priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}
def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name
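# JS hook that forwards the page's URL query parameters into the app on load.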
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
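    """Initialize conversation state and the (currently empty) model dropdown on page load."""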
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") | |
default_models = [] | |
dropdown_update = gr.Dropdown( | |
choices=default_models, | |
value=default_models[0] if len(default_models) > 0 else "" | |
) | |
state = default_conversation.copy() | |
return state, dropdown_update | |
def load_demo_refresh_model_list(request: gr.Request):
    logger.info(f"load_demo_refresh_model_list. ip: {request.client.host}")
    state = default_conversation.copy()
    default_models = []
    dropdown_update = gr.Dropdown(
        choices=default_models,
        value=default_models[0] if len(default_models) > 0 else ""
    )
    return state, dropdown_update
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
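    """Append a vote record (timestamp, type, model, state, ip) to the daily conversation log."""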
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
def upvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response(state, model_selector, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3
def regenerate(state, image_process_mode, request: gr.Request):
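    """Drop the last assistant reply so the following http_bot call regenerates it."""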
logger.info(f"regenerate. ip: {request.client.host}") | |
if len(state.messages) > 0: | |
state.messages[-1][-1] = None | |
prev_human_msg = state.messages[-2] | |
if type(prev_human_msg[1]) in (tuple, list): | |
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) | |
state.skip_next = False | |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 | |
def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def add_text(state, text, image, image_process_mode, request: gr.Request):
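    """Validate the user turn, apply moderation and length cut-offs, and append it to the state."""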
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") | |
if len(text) <= 0 and image is None: | |
state.skip_next = True | |
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 | |
if args.moderate: | |
flagged = violates_moderation(text) | |
if flagged: | |
state.skip_next = True | |
return (state, state.to_gradio_chatbot(), moderation_msg, None) + ( | |
no_change_btn,) * 5 | |
text = text[:1536] # Hard cut-off | |
if image is not None: | |
text = text[:1200] # Hard cut-off for images | |
if '<image>' not in text: | |
# text = '<Image><image></Image>' + text | |
text = text + '\n<image>' | |
text = (text, image, image_process_mode) | |
if len(state.get_images(return_pil=True)) > 0: | |
state = default_conversation.copy() | |
state.append_message(state.roles[0], text) | |
state.append_message(state.roles[1], None) | |
state.skip_next = False | |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 | |
def http_bot(state, model_selector, request: gr.Request):
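    """Run a llama.cpp chat completion for the current state and yield the
    updated conversation; any normalized bounding-box floats in the reply are
    drawn onto the input image.
    """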
logger.info(f"http_bot. ip: {request.client.host}") | |
start_tstamp = time.time() | |
model_name = model_selector | |
output = "" | |
image_base64 = "" | |
prompt = state.get_prompt() | |
try: | |
all_images = state.get_images(return_pil=True) | |
output_image = None | |
for image in all_images: | |
output_image = image.copy() | |
buffered = BytesIO() | |
image.save(buffered, format="JPEG") | |
image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" | |
output = llm.create_chat_completion( | |
max_tokens=1024, | |
messages = [ | |
{"role": "system", "content": "You are an assistant who perfectly answer all user request."}, | |
{ | |
"role": "user", | |
"content": [ | |
{"type": "image_url", "image_url": {"url": image_base64 }}, | |
{"type" : "text", "text": f"""{prompt}"""} | |
] | |
} | |
] | |
) | |
output = output["choices"][0]["message"]["content"] | |
print(output) | |
bboxes = re.findall("\d+\.\d+", output) | |
print(bboxes) | |
print(output, state.messages[-1][-1]) | |
        # Scale each normalized box to pixel coordinates and draw its outline.
        for i in range(0, len(bboxes), 4):
            width, height = output_image.size
            draw = ImageDraw.Draw(output_image)
            draw.rectangle(
                [(float(bboxes[i]) * width, float(bboxes[i + 1]) * height),
                 (float(bboxes[i + 2]) * width, float(bboxes[i + 3]) * height)],
                outline="red", width=3)
        text = output
        if '<image>' not in text:
            text = text + '\n<image>'
        output = (text, output_image, "Default")
        state.messages[-1][-1] = output
    except Exception as e:
        logger.error(f"{e}")
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    finish_tstamp = time.time()
    logger.info(f"http_bot finished in {round(finish_tstamp - start_tstamp, 2)}s")
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
title_markdown = ("""
# BLIP
""")

block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""
def build_demo(embed_mode, cur_dir=None):
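    """Assemble the Gradio Blocks UI and register all event handlers."""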
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="BLIP", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        if not embed_mode:
            gr.Markdown(title_markdown)
        models = ["Propose Solution", "Baseline 1", "Baseline 2", "Baseline 3"]
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False,
                        visible=False)
                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)
                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/0.jpg", "What are the violent acts, and what are the coordinates (x,y,w,h) for drawing bounding boxes in the image?"],
                    [f"{cur_dir}/examples/1.jpg", "What are the violent acts, and what are the coordinates (x,y,w,h) for drawing bounding boxes in the image?"],
                    [f"{cur_dir}/examples/2.jpg", "What are the violent acts, and what are the coordinates (x,y,w,h) for drawing bounding boxes in the image?"],
                ], inputs=[imagebox, textbox])
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="BLIP",
                    height=940,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=7):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote")
                    downvote_btn = gr.Button(value="👎 Downvote")
                    flag_btn = gr.Button(value="⚠️ Flag")
                    regenerate_btn = gr.Button(value="🔄 Regenerate")
                    clear_btn = gr.Button(value="🗑️ Clear")
        url_params = gr.JSON(visible=False)
        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector],
            [state, chatbot] + btn_list,
            # http_bot is a generator, so it must run on the Gradio queue.
        )
        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector],
            [state, chatbot] + btn_list,
        )
        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector],
            [state, chatbot] + btn_list,
        )
        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                js=get_window_url_params
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
    return demo
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--host", type=str, default="0.0.0.0")
    # parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=16)
    parser.add_argument("--model-list-mode", type=str, default="reload",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    demo = build_demo(args.embed)
    demo.queue(
        # Cap concurrent generations at --concurrency-count.
        default_concurrency_limit=args.concurrency_count,
        api_open=False
    ).launch(
        # server_name=args.host,
        # server_port=args.port,
        share=args.share
    )