# violenceblip/app.py
import argparse
import datetime
import json
import os
import time
import random
import gradio as gr
import requests
import base64
from io import BytesIO
import re
from PIL import Image, ImageDraw
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler
from conversation import (default_conversation, conv_templates,
                          SeparatorStyle)
from constants import LOGDIR
from utils import (build_logger, server_error_msg,
                   violates_moderation, moderation_msg)
import hashlib
import urllib.request
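
# Fetch the CLIP projector (mmproj) weights that the LLaVA-1.5 chat handler needs
# alongside the GGUF language model; skip the download when the file is already cached.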
if not os.path.exists("./mmproj-model-f16.gguf"):
    urllib.request.urlretrieve(
        "https://huggingface.co/Galunid/ShareGPT4V-gguf/resolve/main/mmproj-model-f16.gguf?download=true",
        "./mmproj-model-f16.gguf")
chat_handler = Llava15ChatHandler(clip_model_path="./mmproj-model-f16.gguf")

llm = Llama.from_pretrained(
    repo_id="Galunid/ShareGPT4V-gguf",
    filename="ShareGPT4V-f16.gguf",
    chat_handler=chat_handler,
    verbose=False,
    n_ctx=2048,       # n_ctx should be increased to accommodate the image embedding
    logits_all=True,  # needed to make llava work
)

logger = build_logger("gradio_web_server", "gradio_web_server.log")
headers = {"User-Agent": "Wafer Defect Detection with LLM Classification and Analyze Client"}
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}
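
# Conversation and vote records are appended as JSON lines to a per-day file under LOGDIR.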
def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name

get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    default_models = []
    dropdown_update = gr.Dropdown(
        choices=default_models,
        value=default_models[0] if len(default_models) > 0 else ""
    )
    state = default_conversation.copy()
    return state, dropdown_update


def load_demo_refresh_model_list(request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}")
    state = default_conversation.copy()
    default_models = []
    dropdown_update = gr.Dropdown(
        choices=default_models,
        value=default_models[0] if len(default_models) > 0 else ""
    )
    return state, dropdown_update
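
# Voting callbacks: log the vote together with the full conversation state, then
# clear the textbox and disable the three vote buttons.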
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response(state, model_selector, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3
def regenerate(state, image_process_mode, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    if len(state.messages) > 0:
        state.messages[-1][-1] = None
        prev_human_msg = state.messages[-2]
        if type(prev_human_msg[1]) in (tuple, list):
            prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
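
# Queue the user's turn: enforce the length cut-offs, attach the image (appending an
# <image> token if the prompt lacks one), and add an empty assistant slot for http_bot.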
def add_text(state, text, image, image_process_mode, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if '<image>' not in text:
            text = text + '\n<image>'
        text = (text, image, image_process_mode)
        # Start a fresh conversation whenever a new image arrives.
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
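
# Main inference callback: encode the uploaded image as a base64 data URI, query the
# local ShareGPT4V model through llama.cpp's chat API, then parse any normalized
# coordinates in the reply and draw them as boxes on a copy of the image.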
def http_bot(state, model_selector, request: gr.Request):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    output = ""
    image_base64 = ""
    prompt = state.get_prompt()
    try:
        all_images = state.get_images(return_pil=True)
        output_image = None
        for image in all_images:
            # Only the most recent image is kept and sent to the model.
            output_image = image.copy()
            buffered = BytesIO()
            image.save(buffered, format="JPEG")
            image_base64 = f"data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode()}"
        output = llm.create_chat_completion(
            max_tokens=1024,
            messages=[
                {"role": "system", "content": "You are an assistant who perfectly answers all user requests."},
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": image_base64}},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]
        )
        output = output["choices"][0]["message"]["content"]
        logger.info(f"model output: {output}")
        # Treat each complete group of four decimals in the reply as one normalized
        # box (x1, y1, x2, y2) and draw it on the output image.
        bboxes = re.findall(r"\d+\.\d+", output)
        logger.info(f"parsed boxes: {bboxes}")
        for i in range(0, len(bboxes) - 3, 4):
            width, height = output_image.size
            draw = ImageDraw.Draw(output_image)
            draw.rectangle(
                [(float(bboxes[i]) * width, float(bboxes[i + 1]) * height),
                 (float(bboxes[i + 2]) * width, float(bboxes[i + 3]) * height)],
                fill="#ffff33", outline="red")
        text = output
        if '<image>' not in text:
            text = text + '\n<image>'
        output = (text, output_image, "Default")
        # Fill the assistant placeholder appended by add_text.
        state.messages[-1][-1] = output
    except Exception as e:
        logger.error(f"{e}")
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

title_markdown = ("""
# BLIP
""")
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
def build_demo(embed_mode, cur_dir=None):
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="BLIP", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        models = ["Propose Solution", "Baseline 1", "Baseline 2", "Baseline 3"]

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False,
                        visible=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/0.jpg", "What are the violent acts, and what are the coordinates (x, y, w, h) for drawing bounding boxes in the image?"],
                    [f"{cur_dir}/examples/1.jpg", "What are the violent acts, and what are the coordinates (x, y, w, h) for drawing bounding boxes in the image?"],
                    [f"{cur_dir}/examples/2.jpg", "What are the violent acts, and what are the coordinates (x, y, w, h) for drawing bounding boxes in the image?"],
                ], inputs=[imagebox, textbox])

            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="BLIP",
                    height=940,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=7):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="πŸ‘ Upvote")
                    downvote_btn = gr.Button(value="πŸ‘Ž Downvote")
                    flag_btn = gr.Button(value="⚠️ Flag")
                    regenerate_btn = gr.Button(value="πŸ”„ Regenerate")
                    clear_btn = gr.Button(value="πŸ—‘οΈ Clear")

        url_params = gr.JSON(visible=False)
        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        # http_bot is a generator, so the .then() steps must go through the queue.
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector],
            [state, chatbot] + btn_list,
        )
        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector],
            [state, chatbot] + btn_list,
        )
        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector],
            [state, chatbot] + btn_list,
        )
        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                js=get_window_url_params
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency-count", type=int, default=16)
    parser.add_argument("--model-list-mode", type=str, default="reload",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    demo = build_demo(args.embed)
    demo.queue(
        api_open=False
    ).launch(
        share=args.share
    )
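
# Typical local usage, per the argparse flags above:
#   python app.py --model-list-mode reload --share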