import os
import string
import copy
import re
import time
import random

import gradio as gr
import PIL.Image
import torch
from transformers import BitsAndBytesConfig, pipeline

DESCRIPTION = "# LLaVA 🌋💪 - Now with Arnold Mode!"

model_id = "llava-hf/llava-1.5-7b-hf"

# 4-bit quantization so the 7B model fits in modest GPU memory.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

pipe = pipeline(
    "image-to-text",
    model=model_id,
    model_kwargs={"quantization_config": quantization_config},
)


def extract_response_pairs(text):
    """Split a raw transcript into [user, assistant] turn pairs.

    `text` is expected to interleave "USER:" and "ASSISTANT:" role markers.
    Returns a list of two-element lists suitable as a gr.Chatbot value.
    Unpaired trailing turns are dropped.
    """
    parts = re.split(r"(USER:|ASSISTANT:)", text)[1:]
    # Keep only the message bodies (every second surviving token is a marker).
    contents = [p.strip() for p in parts if p.strip()][1::2]
    conv_list = []
    for i in range(0, len(contents) - 1, 2):
        conv_list.append([contents[i].lstrip(":"), contents[i + 1].lstrip(":")])
    return conv_list


def add_text(history, text):
    """Append the user's message to the chat history and clear the textbox."""
    history = history + [[text, None]]
    return history, ""  # Clear the input field after submission


def arnold_speak(text):
    """Rewrite `text` in an over-the-top Arnold Schwarzenegger voice."""
    arnold_phrases = [
        "Come with me if you want to lift!",
        "I'll be back... after my protein shake.",
        "Hasta la vista, baby weight!",
        "Get to da choppa... I mean, da squat rack!",
        "You lack discipline! But don't worry, I'm here to pump you up!",
    ]
    text = text.replace(".", "!")  # More enthusiastic punctuation
    text = text.replace("gym", "iron paradise")
    text = text.replace("exercise", "pump iron")
    text = text.replace("workout", "sculpt your physique")
    # Add random Arnold phrase to the end
    text += " " + random.choice(arnold_phrases)
    return text


def infer(image, prompt, temperature, length_penalty, repetition_penalty,
          max_length, min_length, top_p):
    """Run the LLaVA pipeline on (image, prompt).

    Returns the raw generated text (which includes the prompt echo), or an
    error string if inference fails — callers display whatever comes back,
    so a best-effort string beats a crash here.
    """
    try:
        outputs = pipe(
            images=image,
            prompt=prompt,
            generate_kwargs={
                "temperature": temperature,
                "length_penalty": length_penalty,
                "repetition_penalty": repetition_penalty,
                "max_length": max_length,
                "min_length": min_length,
                "top_p": top_p,
            },
        )
        return outputs[0]["generated_text"]
    except Exception as e:
        print(f"Error during inference: {str(e)}")
        return f"An error occurred during inference: {str(e)}"


def bot(history, text_input, image, temperature, length_penalty,
        repetition_penalty, max_length, min_length, top_p, arnold_mode):
    """Generator wired to the Submit button: validate, infer, stream the reply.

    Yields successive chat histories so the Chatbot shows the answer being
    typed out character by character.
    """
    if text_input == "":
        yield history + [["Please input text", None]]
        return
    if image is None:
        yield history + [["Please input image or wait for image to be uploaded before clicking submit.", None]]
        return

    # Replay previous turns with role tags so the model sees the same
    # USER/ASSISTANT structure the prompt template promises it.
    chat_history = "".join(
        f"USER: {user}\nASSISTANT: {assistant}\n"
        for user, assistant in history
        if user is not None and assistant is not None
    )

    system_prompt = (
        "You are a helpful AI assistant. "
        if not arnold_mode
        else "You are Arnold Schwarzenegger, the famous bodybuilder and actor. Respond in his iconic style, using his catchphrases and focusing on fitness and motivation."
    )

    # llava-1.5 requires the literal "<image>" placeholder in the prompt so
    # the processor knows where to splice in the image embeddings.
    prompt = f"{system_prompt}\n{chat_history}USER: <image>\n{text_input}\nASSISTANT:"

    response = infer(image, prompt, temperature, length_penalty,
                     repetition_penalty, max_length, min_length, top_p)
    # The pipeline echoes the whole prompt back; keep only the new reply.
    response = response.split("ASSISTANT:")[-1].strip()

    if arnold_mode:
        response = arnold_speak(response)

    history.append([text_input, ""])
    # Stream one character at a time for a typewriter effect.
    for i in range(len(response)):
        history[-1][1] = response[: i + 1]
        time.sleep(0.05)
        yield history


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚡️ See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")

    chatbot = gr.Chatbot()
    with gr.Row():
        image = gr.Image(type="pil")
        with gr.Column():
            text_input = gr.Textbox(label="Chat Input", lines=3)
            arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode")

    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.5, maximum=1.0, value=1.0, step=0.1)
        length_penalty = gr.Slider(label="Length Penalty", minimum=-1.0, maximum=2.0, value=1.0, step=0.2)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=5.0, value=1.5, step=0.5)
        max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, value=200, step=1)
        min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, value=1, step=1)
        top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.9, step=0.1)

    with gr.Row():
        clear_button = gr.Button("Clear")
        submit_button = gr.Button("Submit", variant="primary")

    submit_button.click(
        fn=bot,
        inputs=[chatbot, text_input, image, temperature, length_penalty,
                repetition_penalty, max_length, min_length, top_p, arnold_mode],
        outputs=chatbot,
    ).then(
        fn=lambda: "",
        outputs=text_input,
    )

    clear_button.click(lambda: ([], None), outputs=[chatbot, image], queue=False)

    examples = [
        ["./examples/baklava.png", "How to make this pastry?"],
        ["./examples/bee.png", "Describe this image."],
    ]
    gr.Examples(examples=examples, inputs=[image, text_input])

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)