import torch
import transformers
import warnings
import time
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from PIL import Image
from threading import Thread


transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings("ignore")


device = "cuda"  # or cpu
torch.set_default_device(device)

model_name = "BAAI/Bunny-v1_1-Llama-3-8B-V"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16, # float32 for cpu
    device_map="auto",
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True)


@spaces.GPU
def bot_streaming(message, history):
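    """Stream a reply for one chat turn.

    Uses the image attached to this turn (or the most recent one from history),
    builds a Bunny prompt, runs model.generate in a background thread, and yields
    the partial response as it is decoded.
    """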
    if message["files"]:
        # message["files"][-1] is either a dict with a "path" key or a plain path string
        if isinstance(message["files"][-1], dict):
            image_file = message["files"][-1]["path"]
        else:
            image_file = message["files"][-1]
    else:
        # no image uploaded this turn: fall back to the most recent image from
        # earlier turns, which are stored as (path,) tuples in the history
        image_file = None
        for hist in history:
            if isinstance(hist[0], tuple):
                image_file = hist[0][0]


    prompt = message["text"]
    if image_file is None:
        text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:"
        input_ids = torch.tensor(tokenizer(text).input_ids, dtype=torch.long).unsqueeze(0).to(device)
    else:
        text = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\n{prompt} ASSISTANT:"
        # split around the <image> placeholder and splice in -200, the special image-token
        # index that Bunny's multimodal forward pass replaces with image features
        text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")]
        input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1][1:], dtype=torch.long).unsqueeze(0).to(device)
        
    if image_file is not None:
        image = Image.open(image_file)
        image_tensor = model.process_images([image], model.config).to(dtype=model.dtype, device=device)
    else:
        image_tensor = None

    # TextIteratorStreamer yields decoded text incrementally as generate() produces tokens
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

    thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            top_p=0.7,
            max_new_tokens=512,
            streamer=streamer,
            use_cache=True,
            repetition_penalty=1.08
        ))
    thread.start()
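    # The background thread fills the streamer; iterate over it and yield the growing
    # answer so the Gradio chat window updates as tokens arrive.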

    buffer = ""
    time.sleep(0.5)
    for new_text in streamer:
        if "<|end_of_text|>" in new_text:
            new_text = new_text.split("<|end_of_text|>")[0]
        buffer += new_text

        # generated_text_without_prompt = buffer[len(text_prompt):]
        generated_text_without_prompt = buffer
        # print(generated_text_without_prompt)
        time.sleep(0.06)
        # print(f"new_text: {generated_text_without_prompt}")
        yield generated_text_without_prompt


title_markdown = ("""
# 🐰 Bunny: A family of lightweight multimodal models

[📖 [Technical report](https://arxiv.org/abs/2402.11530)] | [🏠 [Code](https://github.com/BAAI-DCAI/Bunny)] | [🤗 [Bunny-v1.1-Llama-3-8B-V](https://huggingface.co/BAAI/Bunny-v1_1-Llama-3-8B-V)] | [🤗 [Bunny-v1.1-4B](https://huggingface.co/BAAI/Bunny-v1_1-4B)] | [🤗 [Bunny-v1.0-3B](https://huggingface.co/BAAI/Bunny-v1_0-3B)]

""")

chatbot = gr.Chatbot(
    elem_id="chatbot",
    label="Bunny-v1.1-Llama-3-8B-V",
    avatar_images=["./assets/user.png", "./assets/icon.jpg"],
    height=550
)

chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False
)

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(title_markdown)

    gr.ChatInterface(
        fn=bot_streaming,
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot
    )

    gr.Examples(
        examples=[
            {"text": "What is the astronaut holding in his hand?", "files": ["./assets/example_1.png"]},
            {"text": "Why is the image funny?", "files": ["./assets/example_2.png"]},
        ],
        inputs=chat_input
    )


demo.queue(api_open=False)
demo.launch(show_api=False, share=False)