Spaces:
Runtime error
Runtime error
File size: 6,617 Bytes
fc28316 69698e1 794ada2 786e086 c25fbe0 7758cb9 69698e1 9d6c917 69698e1 786e086 7758cb9 69698e1 786e086 9d6c917 69698e1 d889050 fc28316 786e086 7758cb9 69698e1 786e086 7758cb9 9d6c917 69698e1 7758cb9 69698e1 b0d85ba 69698e1 7758cb9 69698e1 b0d85ba ed80f94 fbe0b0a 6a6acff 43e9368 69698e1 8c28418 69698e1 7758cb9 b0d85ba 7758cb9 7fa18f3 69698e1 9d6c917 69698e1 9d6c917 69698e1 4e40a03 69698e1 ed80f94 4a3a6ea ed80f94 4a3a6ea 69698e1 6c82cee 69698e1 6c82cee 69698e1 4a3a6ea 69698e1 c25fbe0 794ada2 b0d85ba 794ada2 b0d85ba 794ada2 b0d85ba 794ada2 cbf04ef 7fa18f3 b0d85ba 7758cb9 b0d85ba 7758cb9 ec6a8d2 ed80f94 c25fbe0 ec6a8d2 cbf04ef ed80f94 ec6a8d2 9d6c917 ec6a8d2 d889050 58c6fa7 ec6a8d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import os
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
from modeling_rwkv import RWKV
import gc
import gradio as gr
import base64
from io import BytesIO
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor
from huggingface_hub import hf_hub_download
from pynvml import *
# Initialise NVML and keep a handle to GPU 0; generate() queries this handle
# to log VRAM usage after each completion.
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)
# Maximum number of prompt tokens passed to the model; generate() keeps only
# the last `ctx_limit` tokens of a longer prompt.
ctx_limit = 3500
# Page title handed to gr.Blocks (spelling fixed: was 'ViusualRWKV-v5').
title = 'VisualRWKV-v5'
# Checkpoint file names inside the "howard-hou/visualrwkv-5" hub repository.
rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
vision_tower_name = 'openai/clip-vit-large-patch14-336'
# Download the language-model weights and load them with the
# 'cuda fp16' strategy string.
model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
model = RWKV(model=model_path, strategy='cuda fp16')
from rwkv.utils import PIPELINE, PIPELINE_ARGS
# Tokenizer / sampling helper bound to the loaded model.
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
##########################################################################
from modeling_vision import VisionEncoder, VisionEncoderConfig
# Build the vision encoder with the same embedding width as the language model.
config = VisionEncoderConfig(
    n_embd=model.args.n_embd,
    vision_tower_name=vision_tower_name,
    grid_size=-1,
)
visual_encoder = VisionEncoder(config)

# Fetch the vision checkpoint from the hub and load it on the CPU first.
# NOTE(review): torch.load unpickles the file; this is only safe while the
# hub repository is trusted -- consider `weights_only=True` if the installed
# torch version and the checkpoint contents allow it.
vision_local_path = hf_hub_download(
    repo_id="howard-hou/visualrwkv-5",
    filename=vision_remote_path,
)
vision_state_dict = torch.load(vision_local_path, map_location='cpu')
visual_encoder.load_state_dict(vision_state_dict)

# Image preprocessor matching the vision tower named above.
image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)

# Move the encoder to the GPU when one is present.
if torch.cuda.is_available():
    visual_encoder = visual_encoder.cuda()
##########################################################################
def generate_prompt(instruction):
    """Wrap a user instruction in the prompt template expected by the model.

    The instruction is stripped, Windows line endings are normalised to
    ``\\n`` and every run of consecutive newlines is collapsed to a single
    newline, so the only blank line in the result is the one the template
    itself places before ``Assistant:``.

    Args:
        instruction: Raw text typed by the user.

    Returns:
        The string ``"\\n<instruction>\\n\\nAssistant:"``.
    """
    instruction = instruction.strip().replace('\r\n', '\n')
    # A single replace() pass only halves each run of newlines ("\n\n\n"
    # would still leave "\n\n"), so repeat until no blank line remains.
    while '\n\n' in instruction:
        instruction = instruction.replace('\n\n', '\n')
    return f"\n{instruction}\n\nAssistant:"
def generate(
    ctx,
    image_state,
    token_count=200,
    temperature=0.2,
    top_p=0.3,
    presencePenalty = 0.0,
    countPenalty = 1.0,
):
    """Stream a sampled completion for ``ctx``, starting from ``image_state``.

    Args:
        ctx: Prompt text; it is stripped and tokenised, and only the last
            ``ctx_limit`` tokens are fed to the model.
        image_state: Model state passed to the first forward call.
        token_count: Maximum number of tokens to sample.
        temperature: Sampling temperature; values below 0.2 are raised to 0.2.
        top_p: Nucleus-sampling threshold.
        presencePenalty: Constant subtracted from the logit of every token
            that has already been generated.
        countPenalty: Multiplied by a token's (decayed) occurrence count and
            subtracted from its logit.

    Yields:
        The stripped text generated so far, each time new text is appended,
        and once more with the final text after generation ends.
    """
    # Imported locally: the file has no explicit `datetime` import, so the
    # timestamp below would otherwise depend on the `pynvml` wildcard import
    # happening to export that name.
    from datetime import datetime

    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
                         alpha_frequency = countPenalty,
                         alpha_presence = presencePenalty,
                         token_ban = [], # ban the generation of some tokens
                         token_stop = [0, 261]) # stop generation whenever you see any token here
    ctx = ctx.strip()
    all_tokens = []
    out_last = 0
    out_str = ''
    occurrence = {}
    # Pre-bind so the cleanup below is valid even when the loop never runs
    # (token_count == 0) or the first forward call raises.
    out = state = None
    try:
        for i in range(int(token_count)):
            if i == 0:
                # First step: feed the (tail of the) whole prompt on top of
                # the image state.
                # NOTE(review): if model.forward updates the `state` list it
                # is given in place, the caller's cached image state would be
                # altered here -- confirm against modeling_rwkv.
                input_ids = pipeline.encode(ctx)[-ctx_limit:]
                out, state = model.forward(tokens=input_ids, state=image_state)
            else:
                # Later steps: feed only the token sampled last time.
                input_ids = [token]
                out, state = model.forward(tokens=input_ids, state=state)
            # Penalise tokens that were already generated.
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
            token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
            if token in args.token_stop:
                break
            all_tokens += [token]
            # Decay old counts so the penalty fades for tokens seen long ago.
            for xxx in occurrence:
                occurrence[xxx] *= 0.996
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1
            # Decode only the tokens not yet emitted; hold the text back while
            # it contains U+FFFD (presumably an incomplete multi-byte
            # character) and retry once more tokens are available.
            tmp = pipeline.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                out_str += tmp
                yield out_str.strip()
                out_last = i + 1
    finally:
        # Runs on normal completion, on an error, and when the consumer
        # closes the generator early, so GPU memory is always released.
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
        out = state = None
        gc.collect()
        torch.cuda.empty_cache()
    yield out_str.strip()
##########################################################################
# Directory that contains this script; the example images are referenced
# relative to it.
_script_path = os.path.abspath(__file__)
cur_dir = os.path.dirname(_script_path)

# Rows for the "Examples" dataset: [image path, prompt text].
examples = [
    [f"{cur_dir}/examples_pizza.jpg", "What are steps to cook it?"],
    [f"{cur_dir}/examples_bluejay.jpg", "what is the name of this bird?"],
    [f"{cur_dir}/examples_extreme_ironing.jpg", "What is unusual about this image?"],
    [
        f"{cur_dir}/examples_waterview.jpg",
        "What are the things I should be cautious about when I visit here?",
    ],
]
def pil_image_to_base64(pil_image):
    """Return the JPEG encoding of ``pil_image`` as a base64 text string.

    Args:
        pil_image: Image object exposing ``mode``, ``convert()`` and
            ``save()`` (a PIL image in this application).

    Returns:
        The base64-encoded JPEG bytes, decoded to ``str`` as UTF-8.
    """
    # Pillow's JPEG writer only accepts these modes; anything else (RGBA, P,
    # LA, ...) makes save() raise OSError, so convert such images to RGB
    # first. Images already in a writable mode are encoded unchanged.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if pil_image.mode not in jpeg_modes:
        pil_image = pil_image.convert("RGB")
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    # b64encode returns bytes; decode so callers get a plain string.
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
# Maps the base64 JPEG encoding of an image to the model state computed for it.
image_cache = {}
# Upper bound on cached states. Without a bound every distinct upload would
# keep its state alive for the lifetime of the process.
IMAGE_CACHE_MAX_ENTRIES = 32


def compute_image_state(image):
    """Return the model state obtained by feeding ``image`` to the model.

    Results are memoised in ``image_cache`` keyed by the image's base64 JPEG
    encoding; once the cache holds ``IMAGE_CACHE_MAX_ENTRIES`` entries the
    oldest entry is dropped before a new one is stored.

    Args:
        image: PIL image supplied by the UI.

    Returns:
        The state returned by ``model.forward`` for the image embeddings.
    """
    base64_image = pil_image_to_base64(image)
    try:
        return image_cache[base64_image]
    except KeyError:
        pass

    image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
    image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
    # apply layer norm to image feature, very important
    image_features = F.layer_norm(image_features,
                                  (image_features.shape[-1],),
                                  weight=model.w['blocks.0.ln0.weight'],
                                  bias=model.w['blocks.0.ln0.bias'])
    _, image_state = model.forward(embs=image_features, state=None)

    # dicts iterate in insertion order, so the first key is the oldest entry.
    while len(image_cache) >= IMAGE_CACHE_MAX_ENTRIES:
        image_cache.pop(next(iter(image_cache)))
    image_cache[base64_image] = image_state
    return image_state
def chatbot(image, question):
    """Answer ``question`` about ``image``, streaming the reply.

    Yields a single reminder message when no image was supplied; otherwise
    yields every partial output produced by ``generate`` for the prompt
    built from ``question``.
    """
    if image is not None:
        state = compute_image_state(image)
        yield from generate(generate_prompt(question), state)
    else:
        yield "Please upload an image."
# Gradio UI: image + prompt inputs, Submit/Clear buttons, streamed output and
# a clickable list of examples.
# NOTE(review): the nesting of the layout blocks below was reconstructed from
# a copy of this file whose indentation had been flattened -- confirm the
# intended row/column arrangement.
with gr.Blocks(title=title) as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', label="Image")
        with gr.Column():
            prompt = gr.Textbox(lines=8, label="Prompt",
                                value="Render a clear and concise summary of the photo.")
            with gr.Row():
                submit = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear", variant="secondary")
        with gr.Column():
            output = gr.Textbox(label="Output", lines=10)
    data = gr.Dataset(components=[image, prompt], samples=examples, label="Examples", headers=["Image", "Prompt"])
    # Submit streams chatbot() output into the output box.
    submit.click(chatbot, [image, prompt], [output])
    # Clear resets only the output box.
    clear.click(lambda: None, [], [output])
    # Clicking an example copies its image and prompt into the inputs.
    data.click(lambda x: x, [data], [image, prompt])

# Allow at most 10 queued requests, then start the server without a public
# share link.
demo.queue(max_size=10)
demo.launch(share=False)