Spaces:
Runtime error
Runtime error
File size: 2,603 Bytes
746855d b2f8664 746855d 7fcce9b 746855d de7186d 746855d 9cc868e 746855d b2f8664 746855d b2f8664 746855d b2f8664 746855d b2f8664 746855d b2f8664 746855d b2f8664 9ca6512 746855d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
from __future__ import annotations
import spaces
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
import hashlib
import os
from transformers import AutoModel, AutoProcessor
import torch
import sys
import subprocess
from PIL import Image
import time
# Runtime dependencies installed with pip at startup, one invocation per
# package and in this order, before `cobra` is imported below.
# NOTE(review): declaring these in requirements.txt would avoid reinstalling
# them on every restart.
for _requirement in ("packaging", "ninja", "mamba-ssm", "causal-conv1d"):
    subprocess.check_call([sys.executable, "-m", "pip", "install", _requirement])

from cobra import load

# Load the pretrained "cobra+3b" VLM and move it to the GPU when available.
vlm = load("cobra+3b")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32  # same dtype on both CPU and GPU
vlm.enable_mixed_precision_training = False
vlm.to(DEVICE, dtype=DTYPE)

# Prompt builder created once from the loaded model.
prompt_builder = vlm.get_prompt_builder()
def _upload_path(entry):
    """Return the filesystem path for one entry of ``message["files"]``.

    The original code read ``entry["path"]``; plain path strings are accepted
    as well. NOTE(review): which form arrives depends on the Gradio version.
    """
    return entry["path"] if isinstance(entry, dict) else entry


def _resolve_image_path(message, history):
    """Return the image path for this turn, or ``None`` if there is none.

    The newest file attached to *message* wins. Otherwise the most recent
    image found in *history* is used; image uploads appear there as entries
    whose first element is a tuple holding the path.
    """
    if message["files"]:
        return _upload_path(message["files"][-1])
    image_path = None
    for hist in history:
        if isinstance(hist[0], tuple):
            image_path = hist[0][0]  # keep overwriting so the last one wins
    return image_path


def _build_prompt(history, text):
    """Build the VLM prompt for *text*, replaying earlier chat turns.

    A fresh builder is created on every call, so no prompt state is shared
    between sessions. An interrupted generation cannot leave the human/gpt
    turns out of step, because history is the only source of truth.
    """
    builder = vlm.get_prompt_builder()
    for hist in history:
        user_part, bot_part = hist[0], hist[1]
        # Skip image-upload entries, and skip entries without a text reply
        # (e.g. a stopped generation), so turns are always added in pairs.
        if isinstance(user_part, tuple) or not isinstance(bot_part, str):
            continue
        # A text-less turn (presumably stored as None) was originally sent
        # as an empty human message; replay it the same way.
        builder.add_turn(role="human", message=user_part or "")
        builder.add_turn(role="gpt", message=bot_part)
    builder.add_turn(role="human", message=text)
    return builder.get_prompt()


@spaces.GPU
def bot_streaming(message, history, temperature, top_k, max_new_tokens):
    """Answer one chat turn about an image with the Cobra VLM.

    Args:
        message: multimodal textbox payload with ``"text"`` and ``"files"``.
        history: earlier chat turns as passed by ``gr.ChatInterface``.
        temperature: sampling temperature forwarded to ``vlm.generate``.
        top_k: top-k sampling cutoff forwarded to ``vlm.generate``.
        max_new_tokens: generation length limit forwarded to ``vlm.generate``.

    Yields:
        The complete generated reply, as a single chunk (``vlm.generate``
        returns the whole text at once).

    Raises:
        gr.Error: if neither this message nor the history contains an image.
    """
    image_path = _resolve_image_path(message, history)
    if image_path is None:
        # Without this check, the code below fails with an unbound variable.
        raise gr.Error("Please upload an image to chat about.")
    # Close the file handle once the RGB copy has been made.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    prompt_text = _build_prompt(history, message["text"])

    with torch.no_grad():
        generated_text = vlm.generate(
            image,
            prompt_text,
            cg=True,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
        )
    yield generated_text
# Chat UI: a multimodal textbox (text plus image upload) and three sampling
# controls, passed to bot_streaming after (message, history) in this order.
demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=[
        # bot_streaming always samples (do_sample=True), and a temperature
        # of 0 is not a usable sampling temperature (logits are divided by
        # it), so the slider starts just above 0.
        # NOTE(review): confirm cobra's generate has no special case for 0.
        gr.Slider(0.01, 1, value=0.2, label="Temperature"),
        gr.Slider(1, 3, value=1, step=1, label="Top k"),
        gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
    ],
    title="Cobra",
    description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it.",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)