howard-hou commited on
Commit
26f043f
·
verified ·
1 Parent(s): 1f97712

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -151
app.py CHANGED
@@ -8,6 +8,7 @@ import gc
8
  import gradio as gr
9
  import base64
10
  from io import BytesIO
 
11
  import torch
12
  import torch.nn.functional as F
13
  from datetime import datetime
@@ -19,121 +20,45 @@ gpu_h = nvmlDeviceGetHandleByIndex(0)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
  ctx_limit = 3500
22
- ########################## text rwkv ################################################################
23
- from rwkv.utils import PIPELINE, PIPELINE_ARGS
24
- title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
25
- model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
26
- model = RWKV(model=model_path, strategy='cuda fp16')
27
- pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
28
-
29
- def generate_prompt(instruction, input=""):
30
- instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
31
- input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
32
- if input:
33
- return f"""Instruction: {instruction}
34
- Input: {input}
35
- Response:"""
36
  else:
37
- return f"""User: hi
38
- Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
39
- User: {instruction}
40
- Assistant:"""
41
-
42
- def evaluate(
43
- ctx,
44
- token_count=200,
45
- temperature=1.0,
46
- top_p=0.7,
47
- presencePenalty = 0.1,
48
- countPenalty = 0.1,
49
- ):
50
- args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
51
- alpha_frequency = countPenalty,
52
- alpha_presence = presencePenalty,
53
- token_ban = [], # ban the generation of some tokens
54
- token_stop = [0]) # stop generation whenever you see any token here
55
- ctx = ctx.strip()
56
- all_tokens = []
57
- out_last = 0
58
- out_str = ''
59
- occurrence = {}
60
- state = None
61
- for i in range(int(token_count)):
62
- input_ids = pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
63
- out, state = model.forward(tokens=input_ids, state=state)
64
- for n in occurrence:
65
- out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
66
-
67
- token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
68
- if token in args.token_stop:
69
- break
70
- all_tokens += [token]
71
- for xxx in occurrence:
72
- occurrence[xxx] *= 0.996
73
- if token not in occurrence:
74
- occurrence[token] = 1
75
- else:
76
- occurrence[token] += 1
77
-
78
- tmp = pipeline.decode(all_tokens[out_last:])
79
- if '\ufffd' not in tmp:
80
- out_str += tmp
81
- yield out_str.strip()
82
- out_last = i + 1
83
-
84
- gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
85
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
86
- print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
87
- del out
88
- del state
89
- gc.collect()
90
- torch.cuda.empty_cache()
91
- yield out_str.strip()
92
-
93
- examples = [
94
- ["Assistant: Sure! Here is a very detailed plan to create flying pigs:", 333, 1, 0.3, 0, 1],
95
- ["Assistant: Sure! Here are some ideas for FTL drive:", 333, 1, 0.3, 0, 1],
96
- ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", 333, 1, 0.3, 0, 1],
97
- [generate_prompt("Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires."), 333, 1, 0.3, 0, 1],
98
- [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), 333, 1, 0.3, 0, 1],
99
- [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), 333, 1, 0.3, 0, 1],
100
- ["Assistant: Here is a very detailed plan to kill all mosquitoes:", 333, 1, 0.3, 0, 1],
101
- ['''Edward: I am Edward Elric from fullmetal alchemist. I am in the world of full metal alchemist and know nothing of the real world.
102
- Player: Hello Edward. What have you been up to recently?
103
- Edward:''', 333, 1, 0.3, 0, 1],
104
- [generate_prompt("写一篇关于水利工程的流体力学模型的论文,需要详细全面。"), 333, 1, 0.3, 0, 1],
105
- ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然��以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。
106
- 但愿大宇宙能够忽略这个误差。
107
- 程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。
108
- “放心,我在,你们就在!”智子对两位人类朋友说。
109
- 聚变发动机启动了,推进器发出幽幽的蓝光,飞船缓缓地穿过了宇宙之门。
110
- 小宇宙中只剩下漂流瓶和生态球。漂流瓶隐没于黑暗里,在一千米见方的宇宙中,只有生态球里的小太阳发出一点光芒。在这个小小的生命世界中,几只清澈的水球在零重力环境中静静地飘浮着,有一条小鱼从一只水球中蹦出,跃入另一只水球,轻盈地穿游于绿藻之间。在一小块陆地上的草丛中,有一滴露珠从一片草叶上脱离,旋转着飘起,向太空中折射出一缕晶莹的阳光。''', 333, 1, 0.3, 0, 1],
111
- ]
112
-
113
- ########################## visual rwkv ################################################################
114
- visual_title = 'ViusualRWKV-v5'
115
- #rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
116
- #vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
117
- rwkv_remote_path = "rwkv3b-vitl336p14-577token_mix665k_8gpu_rwkv.pth"
118
- vision_remote_path = "rwkv3b-vitl336p14-577token_mix665k_8gpu_visual.pth"
119
  vision_tower_name = 'openai/clip-vit-large-patch14-336'
120
 
121
- model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
122
- visual_rwkv = RWKV(model=model_path, strategy='cuda fp16')
 
123
 
124
  ##########################################################################
125
  from modeling_vision import VisionEncoder, VisionEncoderConfig
126
- config = VisionEncoderConfig(n_embd=visual_rwkv.args.n_embd,
127
  vision_tower_name=vision_tower_name,
128
  grid_size=-1)
129
  visual_encoder = VisionEncoder(config)
130
- vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=vision_remote_path)
131
  vision_state_dict = torch.load(vision_local_path, map_location='cpu')
132
  visual_encoder.load_state_dict(vision_state_dict, strict=False)
133
  image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
134
  visual_encoder = visual_encoder.to(device)
135
  ##########################################################################
136
- def visual_generate_prompt(instruction):
137
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
138
  return f"\n{instruction}\n\nAssistant:"
139
 
@@ -159,10 +84,10 @@ def generate(
159
  for i in range(int(token_count)):
160
  if i == 0:
161
  input_ids = pipeline.encode(ctx)[-ctx_limit:]
162
- out, state = visual_rwkv.forward(tokens=input_ids, state=image_state)
163
  else:
164
  input_ids = [token]
165
- out, state = visual_rwkv.forward(tokens=input_ids, state=state)
166
  for n in occurrence:
167
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
168
 
@@ -195,7 +120,7 @@ def generate(
195
 
196
  ##########################################################################
197
  cur_dir = os.path.dirname(os.path.abspath(__file__))
198
- visual_examples = [
199
  [
200
  f"{cur_dir}/examples_pizza.jpg",
201
  "What are steps to cook it?"
@@ -205,8 +130,12 @@ visual_examples = [
205
  "what is the name of this bird?",
206
  ],
207
  [
208
- f"{cur_dir}/examples_woman_and_dog.png",
209
- "describe this image",
 
 
 
 
210
  ],
211
  ]
212
 
@@ -219,21 +148,22 @@ def pil_image_to_base64(pil_image):
219
  return base64_image
220
 
221
  image_cache = {}
222
- ln0_weight = visual_rwkv.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
223
- ln0_bias = visual_rwkv.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
224
  def compute_image_state(image):
225
  base64_image = pil_image_to_base64(image)
226
  if base64_image in image_cache:
227
  image_state = image_cache[base64_image]
228
  else:
229
- image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values'].to(device)
 
230
  image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
231
  # apply layer norm to image feature, very important
232
  image_features = F.layer_norm(image_features,
233
  (image_features.shape[-1],),
234
  weight=ln0_weight,
235
  bias=ln0_bias)
236
- _, image_state = visual_rwkv.forward(embs=image_features, state=None)
237
  image_cache[base64_image] = image_state
238
  return image_state
239
 
@@ -242,49 +172,25 @@ def chatbot(image, question):
242
  yield "Please upload an image."
243
  return
244
  image_state = compute_image_state(image)
245
- input_text = visual_generate_prompt(question)
246
- for output in generate(input_text, image_state, temperature=1.0):
247
  yield output
248
 
249
-
250
- ##################################################################################################################
251
  with gr.Blocks(title=title) as demo:
252
- gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 - {title}</h1>\n</div>")
253
- with gr.Tab("Raw Generation"):
254
- gr.Markdown(f"This is [RWKV-5 World v2](https://huggingface.co/BlinkDL/rwkv-5-world) with 1.5B params - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports all 100+ world languages and code. And we have [200+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Please try examples first (bottom of page) *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}.")
255
- with gr.Row():
256
- with gr.Column():
257
- prompt = gr.Textbox(lines=2, label="Prompt", value="Assistant: Sure! Here is a very detailed plan to create flying pigs:")
258
- token_count = gr.Slider(10, 333, label="Max Tokens", step=10, value=333)
259
- temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
260
- top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
261
- presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0)
262
- count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=1)
263
- with gr.Column():
264
- with gr.Row():
265
- submit = gr.Button("Submit", variant="primary")
266
- clear = gr.Button("Clear", variant="secondary")
267
- output = gr.Textbox(label="Output", lines=5)
268
- data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
269
- submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
270
- clear.click(lambda: None, [], [output])
271
- data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
272
- with gr.Tab("Visual RWKV"):
273
- with gr.Row():
274
- with gr.Column():
275
- image = gr.Image(type='pil', label="Image")
276
- with gr.Column():
277
- prompt = gr.Textbox(lines=8, label="Prompt",
278
- value="Render a clear and concise summary of the photo.")
279
- with gr.Row():
280
- submit = gr.Button("Submit", variant="primary")
281
- clear = gr.Button("Clear", variant="secondary")
282
- with gr.Column():
283
- output = gr.Textbox(label="Output", lines=10)
284
- data = gr.Dataset(components=[image, prompt], samples=visual_examples, label="Examples", headers=["Image", "Prompt"])
285
- submit.click(chatbot, [image, prompt], [output])
286
- clear.click(lambda: None, [], [output])
287
- data.click(lambda x: x, [data], [image, prompt])
288
 
289
- demo.queue(concurrency_count=1, max_size=10)
290
  demo.launch(share=False)
 
8
  import gradio as gr
9
  import base64
10
  from io import BytesIO
11
+ from pathlib import Path
12
  import torch
13
  import torch.nn.functional as F
14
  from datetime import datetime
 
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
  ctx_limit = 3500
23
+ title = 'ViusualRWKV-v6.0'
24
+ visualrwkv_remote_path = "VisualRWKV-v060-1B6-v1.0-20240612.pth"
25
+ model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-6", filename=visualrwkv_remote_path)
26
+ # convert visualrwkv to RWKV and vision encoder #######################
27
+ output_dir = Path(model_path).parent
28
+ state_dict = torch.load(model_path, map_location="cpu")
29
+ rwkv_state_dict = {}
30
+ visual_state_dict = {}
31
+ for key in state_dict:
32
+ if key.startswith("rwkv"):
33
+ rwkv_state_dict[key[5:]] = state_dict[key].half()
 
 
 
34
  else:
35
+ visual_state_dict[key] = state_dict[key].half()
36
+ print("rwkv state dict has keys: ", len(rwkv_state_dict))
37
+ print("visual state dict has keys: ", len(visual_state_dict))
38
+ # save
39
+ vision_local_path = output_dir / f"visual.pth"
40
+ rwkv_local_path = output_dir / f"rwkv.pth"
41
+ torch.save(rwkv_state_dict, rwkv_local_path)
42
+ torch.save(visual_state_dict, vision_local_path)
43
+ ##########################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  vision_tower_name = 'openai/clip-vit-large-patch14-336'
45
 
46
+ model = RWKV(model=rwkv_local_path, strategy='cuda fp16')
47
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
48
+ pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
49
 
50
  ##########################################################################
51
  from modeling_vision import VisionEncoder, VisionEncoderConfig
52
+ config = VisionEncoderConfig(n_embd=model.args.n_embd,
53
  vision_tower_name=vision_tower_name,
54
  grid_size=-1)
55
  visual_encoder = VisionEncoder(config)
 
56
  vision_state_dict = torch.load(vision_local_path, map_location='cpu')
57
  visual_encoder.load_state_dict(vision_state_dict, strict=False)
58
  image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
59
  visual_encoder = visual_encoder.to(device)
60
  ##########################################################################
61
+ def generate_prompt(instruction):
62
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
63
  return f"\n{instruction}\n\nAssistant:"
64
 
 
84
  for i in range(int(token_count)):
85
  if i == 0:
86
  input_ids = pipeline.encode(ctx)[-ctx_limit:]
87
+ out, state = model.forward(tokens=input_ids, state=image_state)
88
  else:
89
  input_ids = [token]
90
+ out, state = model.forward(tokens=input_ids, state=state)
91
  for n in occurrence:
92
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
93
 
 
120
 
121
  ##########################################################################
122
  cur_dir = os.path.dirname(os.path.abspath(__file__))
123
+ examples = [
124
  [
125
  f"{cur_dir}/examples_pizza.jpg",
126
  "What are steps to cook it?"
 
130
  "what is the name of this bird?",
131
  ],
132
  [
133
+ f"{cur_dir}/examples_extreme_ironing.jpg",
134
+ "What is unusual about this image?",
135
+ ],
136
+ [
137
+ f"{cur_dir}/examples_waterview.jpg",
138
+ "What are the things I should be cautious about when I visit here?",
139
  ],
140
  ]
141
 
 
148
  return base64_image
149
 
150
  image_cache = {}
151
+ ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
152
+ ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
153
  def compute_image_state(image):
154
  base64_image = pil_image_to_base64(image)
155
  if base64_image in image_cache:
156
  image_state = image_cache[base64_image]
157
  else:
158
+ image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
159
+ image = image.to(device)
160
  image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
161
  # apply layer norm to image feature, very important
162
  image_features = F.layer_norm(image_features,
163
  (image_features.shape[-1],),
164
  weight=ln0_weight,
165
  bias=ln0_bias)
166
+ _, image_state = model.forward(embs=image_features, state=None)
167
  image_cache[base64_image] = image_state
168
  return image_state
169
 
 
172
  yield "Please upload an image."
173
  return
174
  image_state = compute_image_state(image)
175
+ input_text = generate_prompt(question)
176
+ for output in generate(input_text, image_state):
177
  yield output
178
 
 
 
179
  with gr.Blocks(title=title) as demo:
180
+ with gr.Row():
181
+ with gr.Column():
182
+ image = gr.Image(type='pil', label="Image")
183
+ with gr.Column():
184
+ prompt = gr.Textbox(lines=8, label="Prompt",
185
+ value="Render a clear and concise summary of the photo.")
186
+ with gr.Row():
187
+ submit = gr.Button("Submit", variant="primary")
188
+ clear = gr.Button("Clear", variant="secondary")
189
+ with gr.Column():
190
+ output = gr.Textbox(label="Output", lines=10)
191
+ data = gr.Dataset(components=[image, prompt], samples=examples, label="Examples", headers=["Image", "Prompt"])
192
+ submit.click(chatbot, [image, prompt], [output])
193
+ clear.click(lambda: None, [], [output])
194
+ data.click(lambda x: x, [data], [image, prompt])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
 
196
  demo.launch(share=False)