Spaces:
Sleeping
Sleeping
File size: 4,994 Bytes
69698e1 786e086 7758cb9 69698e1 7758cb9 786e086 7758cb9 69698e1 786e086 69698e1 d889050 786e086 7758cb9 69698e1 786e086 7758cb9 69698e1 7758cb9 69698e1 7758cb9 69698e1 7758cb9 69698e1 786e086 69698e1 7758cb9 786e086 7758cb9 786e086 69698e1 4e40a03 69698e1 6c82cee 69698e1 6c82cee 69698e1 cbf04ef 7758cb9 786e086 7758cb9 786e086 7758cb9 ec6a8d2 1d251b2 ec6a8d2 cbf04ef 1d251b2 ec6a8d2 cbf04ef ec6a8d2 d889050 c5fa72e ec6a8d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import gradio as gr
import os, gc
import torch
from transformers import CLIPImageProcessor
from huggingface_hub import hf_hub_download
# Upper bound on the number of input embeddings (image features + prompt
# tokens) that generate() keeps for its first forward pass.
ctx_limit = 3500
# NOTE(review): not referenced anywhere in the visible part of this file.
num_image_embeddings = 4096
# Page title handed to gr.Blocks (spelling fixed: was 'ViusualRWKV-v5').
title = 'VisualRWKV-v5'
# Checkpoint filenames inside the Hub repo that is passed to hf_hub_download.
rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
# Name used both for the vision encoder config and for the CLIP image processor.
vision_tower_name = 'openai/clip-vit-large-patch14-336'
# These are set here, ahead of the model-module imports further down.
# NOTE(review): presumably the RWKV code reads them at import time - confirm
# before moving these lines.
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '0' # if '1' then use CUDA kernel for seq mode (much faster)
# Project-local model code, imported only after the RWKV_* environment
# variables above have been set.
# NOTE(review): presumably these modules read RWKV_JIT_ON / RWKV_CUDA_ON at
# import time - confirm before reordering.
from modeling_vision import VisionEncoder, VisionEncoderConfig
from modeling_rwkv import RWKV
# Fetch the language-model checkpoint from the Hub repo and load it with the
# 'cpu fp32' strategy string.
model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
model = RWKV(model=model_path, strategy='cpu fp32')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# Pipeline object that generate() uses for encode, decode and sample_logits.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
##########################################################################
# Vision encoder. It is configured with the language model's n_embd;
# generate() concatenates its output with the text embeddings along dim 0,
# so the two widths have to agree.
# NOTE(review): the meaning of grid_size=-1 is defined in modeling_vision and
# is not visible from this file.
config = VisionEncoderConfig(n_embd=model.args.n_embd,
vision_tower_name=vision_tower_name,
grid_size=-1)
visual_encoder = VisionEncoder(config)
vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=vision_remote_path)
# NOTE(review): torch.load unpickles the downloaded file; consider
# weights_only=True if the installed torch version supports it.
vision_state_dict = torch.load(vision_local_path, map_location='cpu')
visual_encoder.load_state_dict(vision_state_dict)
# Image preprocessor loaded from the same name as the vision tower.
image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
##########################################################################
def generate_prompt(instruction):
    """Wrap a user instruction in the prompt template fed to the model.

    The instruction is stripped, Windows line endings are normalised to
    ``\\n`` and every run of consecutive newlines is collapsed to a single
    newline, so the instruction itself never contains a blank line.

    Args:
        instruction: Raw text typed by the user.

    Returns:
        The string ``"\\n{instruction}\\n\\nAssistant:"``.
    """
    instruction = instruction.strip().replace('\r\n', '\n')
    # One .replace('\n\n', '\n') pass only halves a run of newlines, so three
    # or more in a row would still leave a blank line behind. Repeat until
    # none remain: the template below uses a blank line of its own right
    # before "Assistant:" (presumably as the turn separator).
    while '\n\n' in instruction:
        instruction = instruction.replace('\n\n', '\n')
    return f"\n{instruction}\n\nAssistant:"
def generate(
    ctx,
    image_features,
    token_count=200,
    temperature=1.0,
    top_p=0.7,
    presencePenalty = 0.1,
    countPenalty = 0.1,
):
    """Stream a sampled completion for ``ctx`` conditioned on ``image_features``.

    The first forward pass feeds the image features followed by the prompt's
    token embeddings; every later pass feeds the single token sampled in the
    previous step together with the state returned by the previous pass.

    Args:
        ctx: Prompt text; surrounding whitespace is stripped before encoding.
        image_features: Tensor concatenated in front of the prompt embeddings
            along dim 0 (assumed to be [L, D] with D equal to the embedding
            width - TODO confirm against the vision encoder).
        token_count: Maximum number of tokens to sample. Zero or a negative
            value samples nothing and yields a single empty string.
        temperature: Sampling temperature; values below 0.2 are raised to 0.2.
        top_p: Passed to ``pipeline.sample_logits`` as ``top_p``.
        presencePenalty: Flat amount subtracted from the logit of every token
            that has already been generated.
        countPenalty: Amount subtracted per (decayed) occurrence of a token.

    Yields:
        The accumulated, stripped output text each time newly decoded text is
        appended, and once more after generation has finished.
    """
    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                         alpha_frequency = countPenalty,
                         alpha_presence = presencePenalty,
                         token_ban = [], # ban the generation of some tokens
                         token_stop = [0]) # stop generation whenever you see any token here
    ctx = ctx.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    # Pre-bind both names so the cleanup after the loop is valid even when the
    # loop body never runs (token_count <= 0); otherwise `del out` would raise
    # UnboundLocalError.
    out = state = None
    for i in range(int(token_count)):
        if i == 0:
            # Prefill: image features first, then the prompt token embeddings.
            input_ids = pipeline.encode(ctx)
            text_embs = model.w['emb.weight'][input_ids]
            # Only the last ctx_limit rows are kept, so an over-long input
            # loses rows from the front (the image side) first.
            input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
            out, state = model.forward(embs=input_embs, state=None)
        else:
            input_ids = [token]
            out, state = model.forward(input_ids, state)
        # Penalise tokens that were already produced.
        for n in occurrence:
            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
        if token in args.token_stop:
            break
        all_tokens.append(token)
        # Decay the existing counts a little before recording the new token.
        for seen in occurrence:
            occurrence[seen] *= 0.996
        occurrence[token] = occurrence.get(token, 0) + 1
        tmp = pipeline.decode(all_tokens[out_last:])
        # Text containing U+FFFD is held back and retried with more tokens
        # (presumably an incomplete multi-byte character - verify).
        if '\ufffd' not in tmp:
            out_str += tmp
            yield out_str.strip()
            out_last = i + 1
    del out
    del state
    gc.collect()
    yield out_str.strip()
##########################################################################
# Directory containing this script; the example images are looked up in it.
cur_dir = os.path.dirname(os.path.abspath(__file__))
# Example rows for the UI: [absolute image path, question].
examples = [
    [f"{cur_dir}/{image_name}", question]
    for image_name, question in (
        ("examples_extreme_ironing.jpg", "What is unusual about this image?"),
        ("examples_waterview.jpg", "What are the things I should be cautious about when I visit here?"),
    )
]
def chatbot(image, question):
    """Answer ``question`` about ``image``, streaming the reply.

    Args:
        image: PIL image from the Gradio image component, or ``None`` when
            nothing has been uploaded.
        question: Text from the prompt box.

    Yields:
        The growing answer text produced by ``generate``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Submitting without an upload passes None; report that clearly
        # instead of failing with AttributeError on image.convert below.
        raise gr.Error("Please upload an image before submitting.")
    pixel_values = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
    image_features = visual_encoder.encode_images(pixel_values.unsqueeze(0)).squeeze(0) # [L, D]
    input_text = generate_prompt(question)
    yield from generate(input_text, image_features)
# Build the Gradio UI and wire the callbacks.
# NOTE(review): this copy of the file had lost its indentation; the container
# nesting below is a conventional reconstruction - confirm it against the
# deployed layout.
with gr.Blocks(title=title) as demo:
    with gr.Row():
        with gr.Column():
            # type='pil': the callback receives the upload as a PIL image.
            image = gr.Image(type='pil', label="Image")
        with gr.Column():
            prompt = gr.Textbox(lines=5, label="Prompt",
                value="Please upload an image and ask a question.")
            with gr.Row():
                submit = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear", variant="secondary")
        with gr.Column():
            output = gr.Textbox(label="Output", lines=7)
    # Example rows, each an [image path, question] pair from `examples`.
    data = gr.Dataset(components=[image, prompt], samples=examples, label="Examples", headers=["Image", "Prompt"])
    # Submit sends the image and prompt to chatbot and shows what it yields.
    submit.click(chatbot, [image, prompt], [output])
    # Clear resets only the output box; image and prompt are left untouched.
    clear.click(lambda: None, [], [output])
    # Clicking an example copies its sample into the image and prompt inputs.
    data.click(lambda x: x, [data], [image, prompt])
# Enable the request queue with max_size=10, then start the app without a
# public share link. (A stray trailing "|" left over from page extraction was
# removed from the launch line; it made the statement a syntax error.)
demo.queue(max_size=10)
demo.launch(share=False)