"""Gradio web server for image chat with a local ShareGPT4V (LLaVA-style) model.

Runs llama.cpp via llama-cpp-python, sends the user's image and prompt to the
model, and draws any bounding boxes found in the reply onto the image.
"""
import argparse
import base64
import datetime
import json
import os
import re
import time
import urllib.request
from io import BytesIO

import gradio as gr
from PIL import ImageDraw
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

from conversation import default_conversation
from constants import LOGDIR
from utils import (build_logger, server_error_msg,
                   violates_moderation, moderation_msg)

# Download the CLIP projector weights once; skip the download if the file
# is already present from a previous run.
MMPROJ_PATH = "./mmproj-model-f16.gguf"
if not os.path.exists(MMPROJ_PATH):
    urllib.request.urlretrieve(
        "https://huggingface.co/Galunid/ShareGPT4V-gguf/resolve/main/mmproj-model-f16.gguf?download=true",
        MMPROJ_PATH)

chat_handler = Llava15ChatHandler(clip_model_path=MMPROJ_PATH)
llm = Llama.from_pretrained(
    repo_id="Galunid/ShareGPT4V-gguf",
    filename="ShareGPT4V-f16.gguf",
    chat_handler=chat_handler,
    verbose=False,
    n_ctx=2048,       # n_ctx should be increased to accommodate the image embedding
    logits_all=True,  # needed to make llava work
)

logger = build_logger("gradio_web_server", "gradio_web_server.log")

headers = {"User-Agent": "Wafer Defect Detection with LLM Classification and Analyze Client"}

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}


def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""


def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    default_models = []
    dropdown_update = gr.Dropdown(
        choices=default_models,
        value=default_models[0] if len(default_models) > 0 else ""
    )
    state = default_conversation.copy()
    return state, dropdown_update


def load_demo_refresh_model_list(request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}")
    state = default_conversation.copy()
    default_models = []
    dropdown_update = gr.Dropdown(
        choices=default_models,
        value=default_models[0] if len(default_models) > 0 else ""
    )
    return state, dropdown_update


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
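
# vote_last_response appends one record per vote as JSON Lines; an
# illustrative entry (values are made up, not taken from a real log):
#   {"tstamp": 1700000000.1234, "type": "upvote",
#    "model": "Propose Solution", "state": {...}, "ip": "127.0.0.1"}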
ip: {request.client.host}") if len(state.messages) > 0: state.messages[-1][-1] = None prev_human_msg = state.messages[-2] if type(prev_human_msg[1]) in (tuple, list): prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) state.skip_next = False return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def clear_history(request: gr.Request): logger.info(f"clear_history. ip: {request.client.host}") state = default_conversation.copy() return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def add_text(state, text, image, image_process_mode, request: gr.Request): logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") if len(text) <= 0 and image is None: state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 if args.moderate: flagged = violates_moderation(text) if flagged: state.skip_next = True return (state, state.to_gradio_chatbot(), moderation_msg, None) + ( no_change_btn,) * 5 text = text[:1536] # Hard cut-off if image is not None: text = text[:1200] # Hard cut-off for images if '' not in text: # text = '' + text text = text + '\n' text = (text, image, image_process_mode) if len(state.get_images(return_pil=True)) > 0: state = default_conversation.copy() state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def http_bot(state, model_selector, request: gr.Request): logger.info(f"http_bot. ip: {request.client.host}") start_tstamp = time.time() model_name = model_selector output = "" image_base64 = "" prompt = state.get_prompt() try: all_images = state.get_images(return_pil=True) output_image = None for image in all_images: output_image = image.copy() buffered = BytesIO() image.save(buffered, format="JPEG") image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" output = llm.create_chat_completion( max_tokens=1024, messages = [ {"role": "system", "content": "You are an assistant who perfectly answer all user request."}, { "role": "user", "content": [ {"type": "image_url", "image_url": {"url": image_base64 }}, {"type" : "text", "text": f"""{prompt}"""} ] } ] ) output = output["choices"][0]["message"]["content"] print(output) bboxes = re.findall("\d+\.\d+", output) print(bboxes) print(output, state.messages[-1][-1]) for i in range(0, len(bboxes), 4): width, height = output_image.size img1 = ImageDraw.Draw(output_image) img1.rectangle([(float(bboxes[i]) * width, float(bboxes[i+1]) * height), (float(bboxes[i+2]) * width, float(bboxes[i+3]) * height)] , fill ="#ffff33", outline ="red") text = output if '' not in text: # text = '' + text text = text + '\n' output = (text, output_image, "Default") print(output, state.messages[-1][-1]) state.append_message(state.roles[1], output) # state.messages[-1][-1] = output except Exception as e: logger.error(f"{e}") state.messages[-1][-1] = server_error_msg yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) return # if output != "": # if type(state.messages[-1][-1]) is not tuple: # state.messages[-1][-1] = state.messages[-1][-1][:-1] # finish_tstamp = time.time() # logger.info(f"{output}") yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 title_markdown = (""" # BLIP """) block_css = """ #buttons button { min-width: min(120px,100%); } """ def build_demo(embed_mode, cur_dir=None): textbox = gr.Textbox(show_label=False, placeholder="Enter 
text and press ENTER", container=False) with gr.Blocks(title="BLIP", theme=gr.themes.Default(), css=block_css) as demo: state = gr.State() if not embed_mode: gr.Markdown(title_markdown) models = ["Propose Solution", "Baseline 1", "Baseline 2", "Baseline 3"] with gr.Row(): with gr.Column(scale=3): with gr.Row(elem_id="model_selector_row"): model_selector = gr.Dropdown( choices=models, value=models[0] if len(models) > 0 else "", interactive=True, show_label=False, container=False, visible=False) imagebox = gr.Image(type="pil") image_process_mode = gr.Radio( ["Crop", "Resize", "Pad", "Default"], value="Default", label="Preprocess for non-square image", visible=False) if cur_dir is None: cur_dir = os.path.dirname(os.path.abspath(__file__)) gr.Examples(examples=[ [f"{cur_dir}/examples/0.jpg", "What are the violence acts and give me the coordinates (x,y,w,h) to draw bounding box in the image?"], [f"{cur_dir}/examples/1.jpg", "What are the violence acts and give me the coordinates (x,y,w,h) to draw bounding box in the image?"], [f"{cur_dir}/examples/2.jpg", "What are the violence acts and give me the coordinates (x,y,w,h) to draw bounding box in the image?"], # [f"{cur_dir}/examples/0.png", "Wafer Defect Type: No-Defect"], ], inputs=[imagebox, textbox]) with gr.Column(scale=7): chatbot = gr.Chatbot( elem_id="chatbot", label="BLIP", height=940, layout="panel", ) with gr.Row(): with gr.Column(scale=7): textbox.render() with gr.Column(scale=1, min_width=50): submit_btn = gr.Button(value="Send", variant="primary") with gr.Row(elem_id="buttons") as button_row: upvote_btn = gr.Button(value="👍 Upvote") downvote_btn = gr.Button(value="👎 Downvote") flag_btn = gr.Button(value="⚠️ Flag") regenerate_btn = gr.Button(value="🔄 Regenerate") clear_btn = gr.Button(value="🗑️ Clear") url_params = gr.JSON(visible=False) # Register listeners btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] upvote_btn.click( upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False ) downvote_btn.click( downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False ) flag_btn.click( flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False ) regenerate_btn.click( regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False ).then( http_bot, # [state, model_selector, temperature, top_p, max_output_tokens, gen_image, use_ocr], [state, model_selector], [state, chatbot] + btn_list, # concurrency_limit=concurrency_count queue=False ) clear_btn.click( clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False ) textbox.submit( add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False ).then( http_bot, # [state, model_selector, temperature, top_p, max_output_tokens, gen_image, use_ocr], [state, model_selector], [state, chatbot] + btn_list, # concurrency_limit=concurrency_count ) submit_btn.click( add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list, queue=False ).then( http_bot, # [state, model_selector, temperature, top_p, max_output_tokens, gen_image, use_ocr], [state, model_selector], [state, chatbot] + btn_list, # concurrency_limit=concurrency_count queue=False ) if args.model_list_mode == "once": demo.load( load_demo, [url_params], [state, model_selector], _js=get_window_url_params ) elif 
args.model_list_mode == "reload": demo.load( load_demo_refresh_model_list, None, [state, model_selector], queue=False ) else: raise ValueError(f"Unknown model list mode: {args.model_list_mode}") return demo if __name__ == "__main__": parser = argparse.ArgumentParser() # parser.add_argument("--host", type=str, default="0.0.0.0") # parser.add_argument("--port", type=int) # parser.add_argument("--controller-url", type=str, default="http://localhost:21001") parser.add_argument("--concurrency-count", type=int, default=16) parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"]) parser.add_argument("--share", action="store_true") parser.add_argument("--moderate", action="store_true") parser.add_argument("--embed", action="store_true") args = parser.parse_args() logger.info(f"args: {args}") logger.info(args) demo = build_demo(args.embed) demo.queue( api_open=False ).launch( # server_name=args.host, # server_port=args.port, share=args.share )