# Scraped Hugging Face Space page header (not part of the program):
# Spaces:
# Runtime error
# Runtime error
# File size: 4,033 Bytes
# c5b3aef c98b207 5cd56f1 c98b207 be961e6 c98b207 09399fd c98b207 09399fd c98b207 cbb017b c98b207 09399fd c98b207 5adecab 2692054 c98b207 2692054 c98b207 3193581 09399fd c98b207 3193581 c98b207 2692054 09399fd 3193581 09399fd d18aa37 5cd56f1 c98b207 09399fd c98b207 09399fd c98b207 09399fd be961e6 09399fd c98b207 cf7a112 e7455bb c98b207 e7455bb c98b207 1f4086f c98b207 0d0766f c98b207 bb45d22 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# Install flash-attn at start-up, before torch/transformers are imported below.
# NOTE(review): the variable name suggests FLASH_ATTENTION_SKIP_CUDA_BUILD makes
# the package skip compiling CUDA kernels during install — confirm against the
# flash-attn setup script.
import os
import subprocess
import sys

subprocess.run(
    # Run pip through the current interpreter so the package lands in the same
    # environment this app runs in (not whichever `pip` a shell PATH finds).
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    # Extend, rather than replace, the inherited environment: passing only the
    # single variable would strip PATH, HOME, proxy settings, etc. from pip.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    # Best-effort, as before: a failed install does not abort start-up.
    check=False,
)
from threading import Thread
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
import os
import time
from huggingface_hub import hf_hub_download
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# NOTE(review): huggingface_hub and transformers are imported above this line;
# if they read HF_HUB_ENABLE_HF_TRANSFER at import time, this assignment has no
# effect. Confirm before relying on hf_transfer downloads.

# NOTE(review): HF_TOKEN is not referenced anywhere else in this file.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Model repository id, e.g. "<org>/<name>"; required.
MODEL_ID = os.environ.get("MODEL_ID")
if not MODEL_ID:
    # Fail early with an actionable message instead of an AttributeError on None.split().
    raise RuntimeError("The MODEL_ID environment variable must be set to a Hugging Face model id.")
# Last path component of the repo id, shown in the page header.
MODEL_NAME = MODEL_ID.split("/")[-1]

TITLE = "<h1><center>VL-Chatbox</center></h1>"
DESCRIPTION = f"<h3><center>MODEL: {MODEL_NAME}</center></h3>"

# Styling for the DuplicateButton (matched via elem_classes="duplicate-button").
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Load the model once at import time in float16 and move it to CUDA device 0.
# trust_remote_code allows the model repo to supply its own model/processor code.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
).to(0)
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
# Stop token handed to model.generate() in stream_chat.
eos_token_id = processor.tokenizer.eos_token_id
def _history_to_conversation(history):
    """Convert ``gr.ChatInterface`` history into chat-template messages.

    NOTE(review): assumes gradio's tuple-style multimodal history, in which an
    uploaded file appears as its own ``((path,), None)`` entry and an
    image-only message as ``(None, reply)`` — confirm against the installed
    gradio version. Upload entries are not turned into messages; instead the
    newest upload's path is returned together with the index of the user
    message that followed it, so the caller can attach the image there.

    Args:
        history: Sequence of ``(user, assistant)`` pairs.

    Returns:
        ``(conversation, image_path, image_turn)``. ``image_path`` and
        ``image_turn`` are both ``None`` when no upload has a following turn.
    """
    conversation = []
    image_path = None
    image_turn = None
    awaiting_turn = False  # an upload was seen but its user turn is not placed yet
    for prompt, answer in history:
        if isinstance(prompt, (tuple, list)):
            image_path = prompt[0]
            awaiting_turn = True
            continue
        if awaiting_turn:
            image_turn = len(conversation)
            awaiting_turn = False
        # Normalise None to "" — the chat template presumably expects strings.
        conversation.append({"role": "user", "content": prompt or ""})
        conversation.append({"role": "assistant", "content": answer or ""})
    if awaiting_turn:
        # The newest upload has no turn to attach to; don't pair it with an older one.
        return conversation, None, None
    return conversation, image_path, image_turn


@spaces.GPU(queue=False)
def stream_chat(message, history: list, temperature: float, max_new_tokens: int):
    """Generate a reply for one chat turn and stream it back incrementally.

    Args:
        message: Multimodal message dict with a ``"text"`` string and a
            ``"files"`` list of uploaded file paths.
        history: Previous ``(user, assistant)`` pairs from ``gr.ChatInterface``.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The reply text accumulated so far, after each chunk from the streamer.

    Raises:
        gr.Error: If the conversation has no history and no image was uploaded.
    """
    print(message)
    conversation, history_image, image_turn = _history_to_conversation(history)

    if message["files"]:
        image = Image.open(message["files"][-1])
        conversation.append({"role": "user", "content": f"<|image_1|>\n{message['text']}"})
    else:
        if len(history) == 0:
            # gr.Error is an exception class; it must be raised to reach the UI.
            raise gr.Error("Please upload an image first.")
        image = None
        if history_image is not None and image_turn is not None:
            # Re-send the earlier upload so follow-up questions can still see it,
            # tagging the user turn that accompanied it with the image placeholder.
            image = Image.open(history_image)
            tagged = conversation[image_turn]
            tagged["content"] = f"<|image_1|>\n{tagged['content']}"
        conversation.append({"role": "user", "content": message['text']})

    prompt = processor.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = processor(prompt, images=image, return_tensors="pt").to(0)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    generate_kwargs = {
        **inputs,
        # Without the streamer, generate() never feeds the loop below and it blocks.
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature != 0,
        "eos_token_id": eos_token_id,
    }
    if temperature != 0:
        # Only meaningful when sampling; omitted for greedy decoding.
        generate_kwargs["temperature"] = temperature

    # Generation runs in a background thread while this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
# Transcript display and the multimodal (text + image) input box; both are
# passed to gr.ChatInterface in the Blocks layout below.
chatbot = gr.Chatbot(height=450)
chat_input = gr.MultimodalTextbox(
    placeholder="Enter message or upload file...",
    file_types=["image"],
    show_label=False,
    interactive=True,
)
# Page layout: header, duplicate button, and the streaming chat interface.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        examples=[
            [{"text": "What is on the desk?", "files": ["./laptop.jpg"]}],
            [{"text": "Where is it?", "files": ["./hotel.jpg"]}],
            [{"text": "Can you describe this image?", "files": ["./spacecat.png"]}],
        ],
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        # Order must match stream_chat's trailing parameters: temperature, max_new_tokens.
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
        ],
    )
if __name__ == "__main__":
    # Serve through the request queue; api_open=False and show_api=False keep
    # gradio's programmatic API closed and its "Use via API" page hidden.
    demo.queue(api_open=False).launch(show_api=False, share=False)