from __future__ import annotations

import subprocess
import sys
import time

import spaces  # ZeroGPU: must be imported before torch initializes CUDA

import gradio as gr
import torch
from PIL import Image

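# Cobra's Mamba kernels (mamba-ssm, causal-conv1d) compile at install time and
# need packaging and ninja in place first, hence these ordered runtime installs.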
for pkg in ("packaging", "ninja", "mamba-ssm", "causal-conv1d"):
    subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])

from cobra import load
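# Download and initialize the Cobra-3B vision-language model.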
vlm = load("cobra+3b")

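# Prefer GPU + bfloat16 when CUDA is available; fall back to CPU + float32.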
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32
vlm.to(DEVICE, dtype=DTYPE)

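# A single shared prompt builder carries the multi-turn conversation state;
# it is reset inside bot_streaming whenever a new chat starts.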
prompt_builder = vlm.get_prompt_builder()

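# On ZeroGPU Spaces, @spaces.GPU attaches a GPU for the duration of each call.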
@spaces.GPU
def bot_streaming(message, history, temperature, top_k, max_new_tokens):
    # New conversation: reset the shared prompt builder.
    if not history:
        prompt_builder.prompt, prompt_builder.turn_count = "", 0

    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        # No image uploaded this turn: fall back to the most recent image
        # from earlier turns, which Gradio stores inside tuples.
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    if image is None:
        # Guard against an unbound `image` when no image exists anywhere in the chat.
        yield "Please upload an image to chat about."
        return

    image = Image.open(image).convert("RGB")

    prompt_builder.add_turn(role="human", message=message["text"])
    prompt_text = prompt_builder.get_prompt()

    # Generate from the VLM.
    with torch.no_grad():
        generated_text = vlm.generate(
            image,
            prompt_text,
            cg=True,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
        )
    prompt_builder.add_turn(role="gpt", message=generated_text)

    # vlm.generate returns the full reply at once, so replay it character by
    # character to give the chat UI a streamed appearance.
    buffer = ""
    for char in generated_text:
        buffer += char
        time.sleep(0.01)
        yield buffer
    
    
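# Multimodal chat UI: the sliders map one-to-one onto bot_streaming's
# sampling arguments (temperature, top_k, max_new_tokens).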
demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=[
        gr.Slider(0, 1, value=0.2, label="Temperature"),
        gr.Slider(1, 3, value=1, step=1, label="Top k"),
        gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
    ],
    title="Cobra",
    description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it.",
    stop_btn="Stop Generation",
    multimodal=True,
    examples=[[{"text": "Describe this image", "files": ["./cobra.png"]}]],
)
demo.launch(debug=True)