Spaces:
Sleeping
Sleeping
File size: 4,450 Bytes
c5b3aef c98b207 5cd56f1 c98b207 be961e6 c98b207 09399fd c98b207 09399fd c98b207 cbb017b c98b207 09399fd c98b207 5adecab 2692054 90b9de8 50517e2 c98b207 6cb9c18 332c0b2 c98b207 2692054 6904764 90b9de8 67d3fd3 6cb9c18 4234c1d 67d3fd3 6cb9c18 573848c 6cb9c18 90b9de8 6904764 999df98 c98b207 3090ac5 c98b207 09399fd c98b207 6904764 09399fd c98b207 09399fd be961e6 09399fd c98b207 cf7a112 e7455bb 60e7596 a927087 60e7596 e7455bb c98b207 0d0766f c98b207 a927087 c98b207 a927087 c98b207 bb45d22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import subprocess
# Install flash-attn at startup, before the model libraries are imported.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells flash-attn's
# build to skip compiling its CUDA kernels -- confirm against the package docs.
import os

subprocess.run(
    'pip install flash-attn --no-build-isolation',
    # Extend the inherited environment rather than replacing it: passing a
    # bare dict as `env` would drop PATH, HOME, proxy and virtualenv settings
    # from the child process.
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)
from threading import Thread
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
import os
import time
from huggingface_hub import hf_hub_download
# Runtime configuration, read from the process environment.
# NOTE(review): huggingface_hub may read HF_HUB_ENABLE_HF_TRANSFER only when
# it is first imported, which happens before this assignment -- confirm that
# setting it here still has an effect.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = os.environ.get("MODEL_ID")
if not MODEL_ID:
    # Fail with an explicit message instead of the opaque AttributeError that
    # `None.split` would raise on the next line.
    raise RuntimeError(
        "The MODEL_ID environment variable must be set to the model "
        "repository id to load."
    )
# Last path component of the repository id, shown in the page header.
MODEL_NAME = MODEL_ID.split("/")[-1]
TITLE = "<h1><center>VL-Chatbox</center></h1>"
DESCRIPTION = "<h3><center>MODEL: " + MODEL_NAME + "</center></h3>"
# Stylesheet handed to gr.Blocks; styles the "duplicate" button.
# The empty first and last entries reproduce the leading and trailing
# newlines of the stylesheet text.
CSS = "\n".join(
    (
        "",
        ".duplicate-button {",
        "margin: auto !important;",
        "color: white !important;",
        "background: black !important;",
        "border-radius: 100vh !important;",
        "}",
        "",
    )
)
# Load the checkpoint named by MODEL_ID with float16 weights, then move the
# model to device index 0.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
model = model.to(0)

# Processor for the same checkpoint; its tokenizer supplies the
# end-of-sequence token id that is later passed to generation.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
eos_token_id = processor.tokenizer.eos_token_id
@spaces.GPU(queue=False)
def stream_chat(message, history: list, temperature: float, max_new_tokens: int):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: Mapping with a ``"text"`` entry and a ``"files"`` list. When
            files are present, the last one is opened as the image.
        history: Earlier turns as ``(prompt, answer)`` pairs. When the new
            message carries no file, the image is opened from
            ``history[0][0][0]`` (presumably the path of the first upload --
            confirm against the Gradio history format).
        temperature: Sampling temperature; ``0`` disables sampling.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The reply text accumulated so far, growing with each streamed chunk.

    Raises:
        gr.Error: If the message has no file and there is no history to take
            an image from.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')
    conversation = []
    if message["files"]:
        # A fresh upload builds a single-turn prompt; history is not replayed.
        image = Image.open(message["files"][-1])
        conversation.append({"role": "user", "content": f"<|image_1|>\n{message['text']}"})
    else:
        if not history:
            raise gr.Error("Please upload an image first.")
        image = Image.open(history[0][0][0])
        for prompt, answer in history:
            if answer is None:
                # A turn without an answer is replayed as the image
                # placeholder followed by an empty assistant turn.
                conversation.extend([{"role": "user", "content":"<|image_1|>"},{"role": "assistant", "content": ""}])
            else:
                conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
        conversation.append({"role": "user", "content": message['text']})
    print(f"Conversation is -\n{conversation}")

    inputs = processor.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs_ids = processor(inputs, image, return_tensors="pt").to(0)

    streamer = TextIteratorStreamer(
        processor,
        skip_special_tokens=True,
        skip_prompt=True,
        clean_up_tokenization_spaces=False,
    )
    generate_kwargs = dict(
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=eos_token_id,
    )
    if temperature == 0:
        # With sampling off, do not forward a zero temperature to generate().
        generate_kwargs["do_sample"] = False
    else:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
    generate_kwargs = {**inputs_ids, **generate_kwargs}

    # Generation runs in a background thread so the streamer can be consumed
    # here while text is still being produced.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
# Conversation display, 450 units tall.
chatbot = gr.Chatbot(height=450)

# Message box for text plus image uploads; the label is hidden.
chat_input = gr.MultimodalTextbox(
    placeholder="Enter message or upload file...",
    file_types=["image"],
    show_label=False,
    interactive=True,
)
# Example prompts offered in the UI. Each row holds one value for the
# multimodal textbox: a question plus the image file it refers to.
EXAMPLES = [
    [{"text": "What is on the desk?", "files": ["./laptop.jpg"]}],
    [{"text": "Where is it?", "files": ["./hotel.jpg"]}],
    [{"text": "Can you describe this image?", "files": ["./spacecat.png"]}],
]
# Page layout: header, duplicate button, the chat interface with its
# collapsible generation parameters, and the example prompts.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",
    )

    # Created with render=False and handed to ChatInterface below. Their
    # order matches stream_chat's extra parameters: temperature first, then
    # max_new_tokens.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        label="Temperature",
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.8,
        render=False,
    )
    max_new_tokens_slider = gr.Slider(
        label="Max new tokens",
        minimum=128,
        maximum=4096,
        step=1,
        value=1024,
        render=False,
    )

    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_new_tokens_slider],
    )
    gr.Examples(examples=EXAMPLES, inputs=[chat_input])
# Start the app only when this file is run as a script. The queue is created
# with api_open=False and the app is launched with show_api=False and
# share=False.
if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)