import gradio as gr
import os, gc
import torch
from transformers import CLIPImageProcessor
from huggingface_hub import hf_hub_download

ctx_limit = 3500
num_image_embeddings = 4096
title = 'VisualRWKV-v5'
rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
vision_tower_name = 'openai/clip-vit-large-patch14-336'

os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)

from modeling_vision import VisionEncoder, VisionEncoderConfig
from modeling_rwkv import RWKV
model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
model = RWKV(model=model_path, strategy='cpu fp32')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

##########################################################################
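# Vision encoder: maps CLIP ViT-L/14-336 image features into the RWKV embedding space (n_embd)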
config = VisionEncoderConfig(n_embd=model.args.n_embd, 
                             vision_tower_name=vision_tower_name, 
                             grid_size=-1)
visual_encoder = VisionEncoder(config)
vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=vision_remote_path)
vision_state_dict = torch.load(vision_local_path, map_location='cpu')
visual_encoder.load_state_dict(vision_state_dict)
image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
##########################################################################
def generate_prompt(instruction):
    instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
    return f"\n{instruction}\n\nAssistant:"

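# Autoregressive decoding: on the first step the prompt is embedded and the image
# features are prepended; afterwards each sampled token is fed back one at a time,
# with presence/frequency penalties applied to discourage repetition.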
def generate(
    ctx,
    image_features,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty = 0.1,
    countPenalty = 0.1,
):
    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                     alpha_frequency = countPenalty,
                     alpha_presence = presencePenalty,
                     token_ban = [], # ban the generation of some tokens
                     token_stop = [0]) # stop generation whenever you see any token here
    ctx = ctx.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    for i in range(int(token_count)):
        if i == 0:
            # first step: embed the prompt and prepend the image features
            input_ids = pipeline.encode(ctx)
            text_embs = model.w['emb.weight'][input_ids]
            input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
            out, state = model.forward(embs=input_embs, state=None)
        else:
            # subsequent steps: feed back the previously sampled token
            input_ids = [token]
            out, state = model.forward(input_ids, state)
        # apply presence/frequency penalties to tokens generated so far
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens += [token]
        # decay existing penalty counts, then record the new token
        for xxx in occurrence:
            occurrence[xxx] *= 0.996
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1
        
        tmp = pipeline.decode(all_tokens[out_last:])
        # only emit once the bytes decode cleanly (no partial UTF-8 sequence)
        if '\ufffd' not in tmp:
            out_str += tmp
            yield out_str.strip()
            out_last = i + 1

    del out
    del state
    gc.collect()
    yield out_str.strip()


##########################################################################
cur_dir = os.path.dirname(os.path.abspath(__file__))
examples = [
    [
        f"{cur_dir}/examples_extreme_ironing.jpg",
        "What is unusual about this image?",
    ],
    [
        f"{cur_dir}/examples_waterview.jpg",
        "What are the things I should be cautious about when I visit here?",
    ]
]
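# Gradio callback: preprocess the image with the CLIP processor, encode it into
# image features, build the prompt, and stream the generated answer.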
def chatbot(image, question):
    image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
    image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
    input_text = generate_prompt(question)
    for output in generate(input_text, image_features):
        yield output

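# UI layout: image input, prompt box with submit/clear buttons, streaming text
# output, and clickable example image/prompt pairs.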
with gr.Blocks(title=title) as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', label="Image")
        with gr.Column():
            prompt = gr.Textbox(lines=5, label="Prompt",
                placeholder="Please upload an image and ask a question.")
            with gr.Row():
                submit = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear", variant="secondary") 
        with gr.Column():
            output = gr.Textbox(label="Output", lines=7)
    data = gr.Dataset(components=[image, prompt], samples=examples, label="Examples", headers=["Image", "Prompt"])
    submit.click(chatbot, [image, prompt], [output])
    clear.click(lambda: None, [], [output])
    data.click(lambda x: x, [data], [image, prompt])

demo.queue(max_size=10)
demo.launch(share=False)