File size: 5,613 Bytes
8400add
 
95dbe7e
8400add
 
 
78a718b
8400add
95dbe7e
78a718b
8400add
78a718b
e5327ee
78a718b
 
 
 
 
 
11e466e
8400add
458ccb5
 
 
 
 
95dbe7e
458ccb5
8400add
8ed6e93
78a718b
239c1b0
78a718b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239c1b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2a93d4
239c1b0
 
 
78a718b
239c1b0
 
78a718b
239c1b0
78a718b
239c1b0
78a718b
239c1b0
c2a93d4
239c1b0
78a718b
 
 
c2a93d4
239c1b0
 
 
c2a93d4
239c1b0
8400add
239c1b0
8400add
a3a174a
 
239c1b0
 
 
8400add
78a718b
239c1b0
 
 
 
8400add
239c1b0
 
 
 
 
 
8400add
95dbe7e
239c1b0
 
 
 
 
 
 
 
 
 
8400add
11e466e
239c1b0
11e466e
78a718b
 
 
 
 
8400add
 
66716db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import string
import copy
import gradio as gr
import PIL.Image
import torch
from transformers import BitsAndBytesConfig, pipeline
import re
import time
import random

# Title markdown shown at the top of the Gradio page.
DESCRIPTION = "# LLaVA πŸŒ‹πŸ’ͺ - Now with Arnold Mode!"

# LLaVA 1.5 (7B) checkpoint served through the transformers image-to-text pipeline.
model_id = "llava-hf/llava-1.5-7b-hf"
# Load the weights in 4-bit (bitsandbytes) to cut VRAM usage; matmuls run in fp16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
# Module-level singleton used by infer(); constructed once at import time.
pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})

def extract_response_pairs(text):
    """Parse a raw transcript into [user_text, assistant_text] pairs.

    Splits *text* on the literal "USER:" / "ASSISTANT:" role markers and
    pairs consecutive utterances, e.g.
    "USER: hi ASSISTANT: hello" -> [["hi", "hello"]].

    Args:
        text: Full generated transcript containing the role markers.

    Returns:
        List of two-element [user_text, assistant_text] lists. A trailing
        user turn with no assistant reply is dropped.
    """
    # The capturing group keeps the markers in the split result; [1:]
    # discards any preamble before the first marker.
    turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
    turns = [turn.strip() for turn in turns if turn.strip()]
    # Even indices are the markers, odd indices the utterance content.
    # Hoist the slice once instead of recomputing turns[1::2] on every access.
    contents = turns[1::2]
    conv_list = []
    for i in range(0, len(contents) - 1, 2):
        conv_list.append([contents[i].lstrip(":"), contents[i + 1].lstrip(":")])
    return conv_list

def add_text(history, text):
    """Append *text* as a new user turn with no reply yet.

    Returns:
        Tuple of (updated history, "") — the empty string clears the
        Gradio text input after submission.
    """
    return history + [[text, None]], ""

def arnold_speak(text):
    """Rewrite *text* in an over-the-top Arnold Schwarzenegger voice.

    Turns periods into exclamation marks, swaps gym vocabulary for
    Arnold-isms, and appends one randomly chosen catchphrase.
    """
    catchphrases = (
        "Come with me if you want to lift!",
        "I'll be back... after my protein shake.",
        "Hasta la vista, baby weight!",
        "Get to da choppa... I mean, da squat rack!",
        "You lack discipline! But don't worry, I'm here to pump you up!",
    )
    # Apply the substitutions in the same order as always.
    substitutions = (
        (".", "!"),                      # more enthusiastic punctuation
        ("gym", "iron paradise"),
        ("exercise", "pump iron"),
        ("workout", "sculpt your physique"),
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    # Tack a random Arnold phrase onto the end.
    return text + " " + random.choice(catchphrases)

def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
    """Run one image+prompt generation through the module-level pipeline.

    All sampling/length knobs are forwarded verbatim via generate_kwargs.

    Returns:
        The generated text, or a human-readable error string if the
        pipeline raises — the UI displays whatever comes back.
    """
    generation_options = {
        "temperature": temperature,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "max_length": max_length,
        "min_length": min_length,
        "top_p": top_p,
    }
    try:
        result = pipe(images=image, prompt=prompt, generate_kwargs=generation_options)
        return result[0]["generated_text"]
    except Exception as e:
        # Deliberate broad catch: surface failures in the chat rather than
        # crashing the Gradio worker.
        print(f"Error during inference: {str(e)}")
        return f"An error occurred during inference: {str(e)}"

def bot(history, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
    """Drive one chat turn and stream the reply character by character.

    Validates the inputs, builds the full prompt (system prompt +
    flattened chat history + new user turn), runs inference, optionally
    Arnold-ifies the answer, then yields the updated history once per
    character for a typing effect.
    """
    # Guard clauses: show a message in the chat instead of raising.
    if text_input == "":
        yield history + [["Please input text", None]]
        return
    if image is None:
        yield history + [["Please input image or wait for image to be uploaded before clicking submit.", None]]
        return

    # Flatten [[user, assistant], ...] into a single space-joined string.
    chat_history = " ".join(
        turn for pair in history for turn in pair if turn is not None
    )

    if arnold_mode:
        system_prompt = "You are Arnold Schwarzenegger, the famous bodybuilder and actor. Respond in his iconic style, using his catchphrases and focusing on fitness and motivation."
    else:
        system_prompt = "You are a helpful AI assistant. "

    prompt = f"{system_prompt}\n{chat_history}\nUSER: <image>\n{text_input}\nASSISTANT:"

    response = infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
    if arnold_mode:
        response = arnold_speak(response)

    # Stream: grow the last assistant message one character per yield.
    history.append([text_input, ""])
    for end in range(1, len(response) + 1):
        history[-1][1] = response[:end]
        time.sleep(0.05)
        yield history

# ---------------------------------------------------------------------------
# Gradio UI: chat window, image upload, Arnold-mode toggle, generation
# sliders, and the wiring of the submit/clear buttons to bot().
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚑️
    See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
    
    chatbot = gr.Chatbot()
    
    # Image upload sits beside the chat input and the Arnold-mode toggle.
    with gr.Row():
        image = gr.Image(type="pil")
        with gr.Column():
            text_input = gr.Textbox(label="Chat Input", lines=3)
            arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode")
    
    # Sampling/length knobs forwarded verbatim to infer()'s generate_kwargs.
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Slider(label="Temperature", minimum=0.5, maximum=1.0, value=1.0, step=0.1)
        length_penalty = gr.Slider(label="Length Penalty", minimum=-1.0, maximum=2.0, value=1.0, step=0.2)
        repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=5.0, value=1.5, step=0.5)
        max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, value=200, step=1)
        min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, value=1, step=1)
        top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.9, step=0.1)

    with gr.Row():
        clear_button = gr.Button("Clear")
        submit_button = gr.Button("Submit", variant="primary")

    # bot() is a generator, so the chatbot renders the reply as it streams;
    # the chained lambda then clears the input box.
    submit_button.click(
        fn=bot,
        inputs=[chatbot, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode],
        outputs=chatbot
    ).then(
        fn=lambda: "",
        outputs=text_input
    )

    # Reset both the conversation and the uploaded image.
    clear_button.click(lambda: ([], None), outputs=[chatbot, image], queue=False)

    # Example image/prompt pairs; paths are relative to the process CWD —
    # NOTE(review): assumes an ./examples directory ships with the app; verify.
    examples = [
        ["./examples/baklava.png", "How to make this pastry?"],
        ["./examples/bee.png", "Describe this image."]
    ]
    gr.Examples(examples=examples, inputs=[image, text_input])

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)