# llava-4bit / app.py
# Last commit by whan12: "Update app.py" (239c1b0, verified) - file size 5.61 kB
import os
import string
import copy
import gradio as gr
import PIL.Image
import torch
from transformers import BitsAndBytesConfig, pipeline
import re
import time
import random
# Markdown heading rendered at the top of the demo page (volcano + flexed biceps emoji).
DESCRIPTION = "# LLaVA 🌋💪 - Now with Arnold Mode!"
# Hugging Face Hub checkpoint used to build the image-to-text pipeline.
model_id = "llava-hf/llava-1.5-7b-hf"
# Load the weights in 4-bit via bitsandbytes, running computation in float16.
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
# Single shared pipeline used by infer(); the quantization settings are
# forwarded to the model's from_pretrained call through model_kwargs.
pipe = pipeline(
    "image-to-text",
    model=model_id,
    model_kwargs={"quantization_config": quantization_config},
)
def extract_response_pairs(text):
    """Split a ``USER:``/``ASSISTANT:`` transcript into conversation pairs.

    Each ``USER:`` turn is paired with the ``ASSISTANT:`` turn that follows it.
    An assistant turn with no preceding user turn is ignored. A trailing user
    turn that has no reply yet is also ignored. Messages are stripped of
    surrounding whitespace and of leading colons.

    Args:
        text: Raw transcript text.

    Returns:
        A list of ``[user_message, assistant_reply]`` lists in transcript
        order; empty if no complete exchange is found.
    """
    # With a capturing group, re.split keeps the markers. After dropping the
    # text before the first marker, the list alternates marker, content, ...
    parts = re.split(r"(USER:|ASSISTANT:)", text)[1:]
    conv_list = []
    pending_user = None
    # Pair each marker with its own content before discarding anything, so an
    # empty message can never shift later messages into the wrong role.
    for marker, content in zip(parts[0::2], parts[1::2]):
        message = content.strip().lstrip(":")
        if marker == "USER:":
            pending_user = message
        elif pending_user is not None:
            conv_list.append([pending_user, message])
            pending_user = None
    return conv_list
def add_text(history, text):
    """Queue a new user message in the chat history.

    Args:
        history: Existing list of ``[user, bot]`` chat turns (not modified).
        text: The message the user just entered.

    Returns:
        A tuple of (new history with ``[text, None]`` appended, ``""``); the
        empty string is meant to clear the input field after submission.
    """
    pending_turn = [text, None]  # reply slot stays empty until the bot answers
    updated_history = history + [pending_turn]
    cleared_input = ""
    return updated_history, cleared_input
# Catchphrases; one is chosen at random and appended to every Arnold-mode reply.
ARNOLD_PHRASES = (
    "Come with me if you want to lift!",
    "I'll be back... after my protein shake.",
    "Hasta la vista, baby weight!",
    "Get to da choppa... I mean, da squat rack!",
    "You lack discipline! But don't worry, I'm here to pump you up!",
)

# Whole-word vocabulary swaps. The optional trailing "s" keeps plurals working
# ("gyms" -> "iron paradises") while leaving words like "gymnastics" alone.
_ARNOLD_VOCABULARY = (
    (re.compile(r"\bgym(?=s?\b)"), "iron paradise"),
    (re.compile(r"\bexercise(?=s?\b)"), "pump iron"),
    (re.compile(r"\bworkout(?=s?\b)"), "sculpt your physique"),
)

# A period not followed by a word character: sentence ends and ellipses, but
# not decimals ("2.5") or dotted names ("example.com").
_SENTENCE_PERIOD = re.compile(r"\.(?!\w)")


def arnold_speak(text):
    """Rewrite *text* in an over-the-top Arnold Schwarzenegger style.

    Sentence-ending periods become exclamation marks. The whole words "gym",
    "exercise" and "workout" (and their plural forms) are swapped for fitness
    slang. A random catchphrase from ``ARNOLD_PHRASES`` is appended.

    Args:
        text: The reply text to transform; may be empty.

    Returns:
        The transformed text followed by a space and a catchphrase, or just the
        catchphrase when *text* is empty.
    """
    text = _SENTENCE_PERIOD.sub("!", text)  # more enthusiastic punctuation
    for pattern, replacement in _ARNOLD_VOCABULARY:
        text = pattern.sub(replacement, text)
    phrase = random.choice(ARNOLD_PHRASES)
    # Avoid a leading space when there is nothing to prefix.
    return f"{text} {phrase}" if text else phrase
def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
    """Run the LLaVA image-to-text pipeline once and return its decoded text.

    Args:
        image: Image passed to the pipeline (the UI supplies a PIL image).
        prompt: Full text prompt, expected to contain an ``<image>`` placeholder.
        temperature: Sampling temperature.
        length_penalty: Forwarded to ``generate``. Transformers only applies it
            to beam-based generation, so it likely has no effect here (no
            beams are configured).
        repetition_penalty: Penalty for repeated tokens (1.0 means none).
        max_length: Maximum number of *new* tokens to generate.
        min_length: Minimum number of *new* tokens to generate.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The pipeline's ``generated_text``, which may include an echo of the
        prompt. If inference raised, a human-readable error message instead.
    """
    generate_kwargs = {
        # Without sampling, generate() decodes greedily and ignores temperature/top_p.
        "do_sample": True,
        "temperature": temperature,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        # max_length/min_length would count the prompt too (system prompt,
        # history, image placeholder), so a long conversation could leave no
        # room for a reply. Bound the newly generated tokens instead. Sliders
        # may deliver floats, so cast to int.
        "max_new_tokens": int(max_length),
        "min_new_tokens": int(min_length),
        "top_p": top_p,
    }
    try:
        outputs = pipe(images=image, prompt=prompt, generate_kwargs=generate_kwargs)
        return outputs[0]["generated_text"]
    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing the app.
        print(f"Error during inference: {str(e)}")
        return f"An error occurred during inference: {str(e)}"
def _format_history(history):
    """Render earlier chat turns as ``USER: ...`` / ``ASSISTANT: ...`` lines.

    Each entry is a ``[user_message, bot_message]`` pair as stored by the
    Chatbot component. ``None`` halves are skipped.
    """
    lines = []
    for user_message, bot_message in history:
        if user_message is not None:
            lines.append(f"USER: {user_message}")
        if bot_message is not None:
            lines.append(f"ASSISTANT: {bot_message}")
    return "\n".join(lines)


def _extract_reply(generated_text, prompt):
    """Return only the newly generated assistant turn from pipeline output.

    The decoded output may start with an echo of the prompt. When it contains
    at least as many ``ASSISTANT:`` markers as the prompt, everything up to the
    prompt's last marker is dropped. Any follow-up ``USER:`` turn the model
    kept writing is also cut off.
    """
    marker = "ASSISTANT:"
    prompt_markers = prompt.count(marker)
    if prompt_markers and generated_text.count(marker) >= prompt_markers:
        generated_text = generated_text.split(marker, prompt_markers)[-1]
    return generated_text.split("USER:", 1)[0].strip()


def bot(history, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
    """Answer the current message about *image* and stream the reply into the chat.

    Args:
        history: List of ``[user, bot]`` turns from the Chatbot (not mutated).
        text_input: The user's new message.
        image: The uploaded image, or ``None`` if none is available yet.
        temperature, length_penalty, repetition_penalty, max_length,
        min_length, top_p: Generation settings forwarded to ``infer``.
        arnold_mode: When true, use the Arnold persona and post-process the
            reply with ``arnold_speak``.

    Yields:
        Successive versions of the chat history. If validation fails, a single
        history with a notice appended. Otherwise, first the history with the
        new user turn, then one update per revealed character of the reply.
    """
    history = list(history or [])
    if not (text_input or "").strip():
        yield history + [["Please input text", None]]
        return
    if image is None:
        yield history + [["Please input image or wait for image to be uploaded before clicking submit.", None]]
        return

    system_prompt = "You are a helpful AI assistant. " if not arnold_mode else "You are Arnold Schwarzenegger, the famous bodybuilder and actor. Respond in his iconic style, using his catchphrases and focusing on fitness and motivation."
    # Keep role labels on earlier turns so the model can tell who said what.
    prompt_parts = [system_prompt]
    chat_history = _format_history(history)
    if chat_history:
        prompt_parts.append(chat_history)
    prompt_parts.append(f"USER: <image>\n{text_input}\nASSISTANT:")
    prompt = "\n".join(prompt_parts)

    generated = infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
    response = _extract_reply(generated, prompt)
    if arnold_mode:
        response = arnold_speak(response)

    history.append([text_input, ""])
    # Show the user's message immediately, even if the reply turns out empty.
    yield history
    # Typewriter effect: reveal the reply one character at a time.
    for end in range(1, len(response) + 1):
        history[-1][1] = response[:end]
        time.sleep(0.05)
        yield history
with gr.Blocks() as demo:
    # --- Header ---------------------------------------------------------------
    gr.Markdown(DESCRIPTION)
    gr.Markdown(
        "## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚑️\n"
        "See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava."
    )

    # --- Conversation, image upload and controls ------------------------------
    chatbot = gr.Chatbot()
    with gr.Row():
        image = gr.Image(type="pil")
        with gr.Column():
            text_input = gr.Textbox(label="Chat Input", lines=3)
            arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode")
            # Generation settings passed to bot(); collapsed by default.
            with gr.Accordion(label="Advanced settings", open=False):
                temperature = gr.Slider(label="Temperature", minimum=0.5, maximum=1.0, value=1.0, step=0.1)
                length_penalty = gr.Slider(label="Length Penalty", minimum=-1.0, maximum=2.0, value=1.0, step=0.2)
                repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=5.0, value=1.5, step=0.5)
                max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, value=200, step=1)
                min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, value=1, step=1)
                top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.9, step=0.1)
    with gr.Row():
        clear_button = gr.Button("Clear")
        submit_button = gr.Button("Submit", variant="primary")

    # --- Event wiring ---------------------------------------------------------
    bot_inputs = [
        chatbot, text_input, image,
        temperature, length_penalty, repetition_penalty,
        max_length, min_length, top_p,
        arnold_mode,
    ]

    def _empty_textbox():
        """Return an empty string so the chat input is cleared after a reply."""
        return ""

    def _reset_conversation():
        """Return an empty chat history and no image."""
        return [], None

    submit_event = submit_button.click(fn=bot, inputs=bot_inputs, outputs=chatbot)
    submit_event.then(fn=_empty_textbox, outputs=text_input)
    clear_button.click(_reset_conversation, outputs=[chatbot, image], queue=False)

    # --- Example inputs (image path, prompt) ----------------------------------
    examples = [
        ["./examples/baklava.png", "How to make this pastry?"],
        ["./examples/bee.png", "Describe this image."],
    ]
    gr.Examples(examples=examples, inputs=[image, text_input])
if __name__ == "__main__":
    # Enable Gradio's request queue (at most 10 pending events), then start
    # the server with debug output.
    queued_demo = demo.queue(max_size=10)
    queued_demo.launch(debug=True)