# Keyven's picture
# Update handle text and generate
# c57f245
# raw
# history blame
# 8.63 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import re
import copy
import secrets
from pathlib import Path
# Constants
# Regex for a <box>...</box> bounding-box tag; [\s\S]*? matches lazily across newlines.
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
# ASCII punctuation; handle_text_input uses it to decide whether to drop a trailing mark.
PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

# Initialize model and tokenizer (both load the same Hub checkpoint).
_MODEL_ID = "Qwen/Qwen-VL-Chat-Int4"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    device_map="auto",
    trust_remote_code=True,
).eval()
# Per-character HTML escapes applied to lines inside fenced code blocks.
# "&" is escaped too, so literal entity text in code (e.g. "&lt;") is shown
# verbatim instead of being decoded by the browser. None of the replacement
# strings contains a character that is itself in the table, so a single
# translate pass is safe.
_CODE_ESCAPES = str.maketrans({
    "&": "&amp;",
    "`": r"\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def format_text(text):
    """Format text as HTML for rendering in the chat UI.

    Empty lines are dropped and the remaining lines are joined with ``<br>``.
    Each line containing a ``` fence toggles a code block: an opening fence
    becomes ``<pre><code class="language-<lang>">`` (``<lang>`` is whatever
    follows the last backtick on that line) and a closing fence becomes
    ``<br></code></pre>``. Lines inside a code block are HTML-escaped
    (including ``&``) so Markdown/HTML characters render literally; lines
    outside code blocks are left untouched.

    Args:
        text: The raw text to format.

    Returns:
        A single HTML string containing no ``\\n`` characters.
    """
    lines = [line for line in text.split("\n") if line != ""]
    fences = 0  # fence lines seen so far; odd means we are inside a code block
    for i, line in enumerate(lines):
        if "```" in line:
            fences += 1
            if fences % 2 == 1:
                lang = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{lang}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            # The first line is emitted as-is; later lines get a <br> separator.
            if fences % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)
def _split_history(task_history):
    """Convert the UI task history into ``(history, message)`` for ``model.chat``.

    Image entries (stored as ``(path,)`` tuples) become
    ``Picture N: <img>path</img>`` lines prepended to the next text query, so
    every resulting turn is ``(text_with_pictures, answer)``. The last turn's
    text is returned separately as the new message; images after the last
    text query are not included.
    """
    turns = []
    pic_idx = 1
    pending = ""
    for q, a in copy.deepcopy(task_history):
        if isinstance(q, (tuple, list)):
            pending += f'Picture {pic_idx}: <img>{q[0]}</img>' + '\n'
            pic_idx += 1
        else:
            turns.append((pending + q, a))
            pending = ""
    return turns[:-1], turns[-1][0]


def _save_temp_image(image):
    """Save *image* as a JPEG in a fresh random directory under /tmp; return the path."""
    temp_dir = Path("/tmp") / secrets.token_hex(20)
    temp_dir.mkdir(exist_ok=True, parents=True)
    filename = temp_dir / f"tmp{secrets.token_hex(5)}.jpg"
    image.save(str(filename))
    return str(filename)


def get_chat_response(chatbot, task_history):
    """Ask the model to answer the latest query and update the chat in place.

    Args:
        chatbot: Chat display history, a list of ``(user, bot)`` pairs whose
            last entry is the pending user message.
        task_history: Raw conversation state, a list of ``(query, answer)``
            pairs; image uploads are ``((path,), None)``. The last entry's
            answer is filled in.

    Returns:
        The updated ``chatbot`` list. When
        ``tokenizer.draw_bbox_on_latest_picture`` returns an image, the user
        entry is paired with that image and any answer text left after
        removing ``<ref>``/``<box>`` markup is appended as a separate message.
        If the latest entry is not a text query, nothing is asked and
        ``chatbot`` is returned unchanged.
    """
    if not task_history or isinstance(task_history[-1][0], (tuple, list)):
        # Nothing to answer: the latest entry is an image upload, not a query.
        return chatbot
    chat_query = chatbot[-1][0]
    query = task_history[-1][0]

    history, message = _split_history(task_history)
    response, history = model.chat(tokenizer, message, history=history)
    image = tokenizer.draw_bbox_on_latest_picture(response, history)

    if image is not None:
        chatbot[-1] = (format_text(chat_query), (_save_temp_image(image),))
        # The boxes are drawn on the image, so strip the grounding markup from the text.
        chat_response = response.replace("<ref>", "").replace("</ref>", "")
        chat_response = re.sub(BOX_TAG_PATTERN, "", chat_response)
        if chat_response != "":
            chatbot.append((None, chat_response))
    else:
        chatbot[-1] = (format_text(chat_query), response)

    # NOTE(review): the HTML-formatted answer is what later turns feed back to
    # the model as history (via _split_history) — confirm this is intended.
    task_history[-1] = (query, format_text(response))
    return chatbot
def handle_text_input(history, task_history, text):
    """Queue the user's text query in both chat histories.

    An empty input is replaced by the default query
    "Describe the image for me...". A single trailing punctuation mark that
    is not preceded by another punctuation character is dropped from the
    copy stored for the model; the chat display shows the text as typed.

    The answer itself is produced by the ``get_chat_response`` step chained
    after this handler, so this function never calls the model.

    Args:
        history: Chat display history (list of ``(user, bot)`` pairs).
        task_history: Raw conversation state (list of ``(query, answer)``).
        text: Current value of the input textbox.

    Returns:
        ``(history, task_history, "")`` — new lists with the pending turn
        appended, plus an empty string for clearing the input box.
    """
    if not text:
        text = "Describe the image for me..."
    task_text = text
    if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
        task_text = text[:-1]
    history = history + [(format_text(text), None)]
    task_history = task_history + [(task_text, None)]
    return history, task_history, ""
def handle_file_upload(history, task_history, file):
    """Record an uploaded image in both chat histories.

    The image is stored as a one-element ``(path,)`` tuple with no answer;
    ``get_chat_response`` treats tuple entries as pictures.

    Args:
        history: Chat display history (list of ``(user, bot)`` pairs).
        task_history: Raw conversation state (list of ``(query, answer)``).
        file: The upload — either a path string or an object exposing the
            path as ``.name`` (e.g. a tempfile wrapper).

    Returns:
        ``(history, task_history)`` as new lists with the image entry appended.
    """
    path = file if isinstance(file, str) else file.name
    entry = ((path,), None)
    return history + [entry], task_history + [entry]
def clear_input():
    """Build a Gradio update that resets the input textbox to an empty string."""
    empty_value = ""
    return gr.update(value=empty_value)
def clear_history(task_history):
    """Reset the conversation.

    Empties ``task_history`` in place and returns a fresh empty list to
    replace the chatbot display.
    """
    del task_history[:]
    return list()
def handle_regeneration(chatbot, task_history, input_field):
    """Handle the Regenerate button.

    With an empty input box, the default query "Describe this image for me..."
    is asked as a new turn (useful right after uploading an image). Otherwise
    the last answer is discarded and the model is asked the last query again.

    Args:
        chatbot: Chat display history (list of ``(user, bot)`` pairs); mutated.
        task_history: Raw conversation state (list of ``(query, answer)``);
            mutated in place because it is not among this handler's outputs.
        input_field: The current string value of the input textbox.

    Returns:
        The updated ``chatbot`` list (unchanged when there is nothing to do).
    """
    print("Regenerate clicked")
    print("Before:", task_history, chatbot)
    if not task_history:
        return chatbot
    if not input_field:
        predefined_query = "Describe this image for me..."
        chatbot.append((format_text(predefined_query), None))
        task_history.append((predefined_query, None))
    else:
        item = task_history[-1]
        if item[1] is None:
            # The last turn has no answer yet, so there is nothing to regenerate.
            return chatbot
        task_history[-1] = (item[0], None)
        chatbot_item = chatbot.pop(-1)
        if chatbot_item[0] is None:
            # The answer was shown as (query, image) + (None, text); reset the query row.
            chatbot[-1] = (chatbot[-1][0], None)
        else:
            chatbot.append((chatbot_item[0], None))
    print("After:", task_history, chatbot)
    return get_chat_response(chatbot, task_history)
# Custom CSS
css = '''
@import url('https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0/css/bootstrap.min.css');
.gradio-button, .gradio-upload-button {
border: none;
border-radius: 4px;
cursor: pointer;
font-size: 16px;
margin: 2px;
}
.gradio-button {
background-color: #008CBA;
color: white;
}
.gradio-button:hover {
background-color: #005f5f;
}
.gradio-upload-button input {
display: none;
}
.gradio-upload-button {
background-color: #008CBA;
color: white;
padding: 10px 20px;
}
.gradio-upload-button:hover {
background-color: #005f5f;
}
.control-width {
width: 100%;
}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Qwen-VL Multimodal-Vision-Insight")
    gr.Markdown(
        "## Developed by Keyvan Hardani (Keyvven on [Twitter](https://twitter.com/Keyvven))\n"
        "Special thanks to [@Artificialguybr](https://twitter.com/artificialguybr) for the inspiration from his code.\n"
        "### Qwen-VL: A Multimodal Large Vision Language Model by Alibaba Cloud\n"
    )
    chatbot = gr.Chatbot(label='Qwen-VL-Chat', elem_classes="control-height", height=520)
    query = gr.Textbox(lines=2, label='Input')
    task_history = gr.State([])
    with gr.Row():
        # gr.Column sizes by relative `scale`; it has no `width` parameter.
        with gr.Column(scale=4):
            upload_btn = gr.UploadButton("📁 Upload", file_types=["image"], elem_classes="control-width")
        with gr.Column(scale=2):
            submit_btn = gr.Button("🚀 Submit", elem_classes="control-width")
        with gr.Column(scale=2):
            regen_btn = gr.Button("🤔️ Regenerate", elem_classes="control-width")
        with gr.Column(scale=2):
            clear_btn = gr.Button("🧹 Clear History", elem_classes="control-width")
    gr.Markdown("### Key Features:\n- **Strong Performance**: Surpasses existing LVLMs on multiple English benchmarks including Zero-shot Captioning and VQA.\n- **Multi-lingual Support**: Supports English, Chinese, and multi-lingual conversation.\n- **High Resolution**: Utilizes 448*448 resolution for fine-grained recognition and understanding.")

    # handle_text_input returns (chatbot, task_history, ""), so all three
    # outputs are wired; the "" clears the input box.
    submit_btn.click(handle_text_input, [chatbot, task_history, query], [chatbot, task_history, query]).then(
        get_chat_response, [chatbot, task_history], [chatbot], show_progress=True
    )
    submit_btn.click(clear_input, [], [query])
    clear_btn.click(clear_history, [task_history], [chatbot], show_progress=True)
    regen_btn.click(handle_regeneration, [chatbot, task_history, query], [chatbot], show_progress=True)
    upload_btn.upload(handle_file_upload, [chatbot, task_history, upload_btn], [chatbot, task_history], show_progress=True)

demo.launch()