howard-hou committed on
Commit
cd5a657
·
1 Parent(s): 964ee2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -33
app.py CHANGED
@@ -19,15 +19,105 @@ gpu_h = nvmlDeviceGetHandleByIndex(0)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
  ctx_limit = 3500
22
- title = 'ViusualRWKV-v5'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
24
  vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
25
  vision_tower_name = 'openai/clip-vit-large-patch14-336'
26
 
27
  model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
28
- model = RWKV(model=model_path, strategy='cuda fp16')
29
- from rwkv.utils import PIPELINE, PIPELINE_ARGS
30
- pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
31
 
32
  ##########################################################################
33
  from modeling_vision import VisionEncoder, VisionEncoderConfig
@@ -41,7 +131,7 @@ visual_encoder.load_state_dict(vision_state_dict)
41
  image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
42
  visual_encoder = visual_encoder.to(device)
43
  ##########################################################################
44
- def generate_prompt(instruction):
45
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
46
  return f"\n{instruction}\n\nAssistant:"
47
 
@@ -67,10 +157,10 @@ def generate(
67
  for i in range(int(token_count)):
68
  if i == 0:
69
  input_ids = pipeline.encode(ctx)[-ctx_limit:]
70
- out, state = model.forward(tokens=input_ids, state=image_state)
71
  else:
72
  input_ids = [token]
73
- out, state = model.forward(tokens=input_ids, state=state)
74
  for n in occurrence:
75
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
76
 
@@ -103,7 +193,7 @@ def generate(
103
 
104
  ##########################################################################
105
  cur_dir = os.path.dirname(os.path.abspath(__file__))
106
- examples = [
107
  [
108
  f"{cur_dir}/examples_pizza.jpg",
109
  "What are steps to cook it?"
@@ -131,22 +221,19 @@ def pil_image_to_base64(pil_image):
131
  return base64_image
132
 
133
  image_cache = {}
134
- ln0_weight = model.w['blocks.0.ln0.weight'].to(torch.float32).to(device)
135
- ln0_bias = model.w['blocks.0.ln0.bias'].to(torch.float32).to(device)
136
  def compute_image_state(image):
137
  base64_image = pil_image_to_base64(image)
138
  if base64_image in image_cache:
139
  image_state = image_cache[base64_image]
140
  else:
141
- image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
142
- image = image.to(device)
143
  image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
144
  # apply layer norm to image feature, very important
145
  image_features = F.layer_norm(image_features,
146
  (image_features.shape[-1],),
147
- weight=ln0_weight,
148
- bias=ln0_bias)
149
- _, image_state = model.forward(embs=image_features, state=None)
150
  image_cache[base64_image] = image_state
151
  return image_state
152
 
@@ -155,26 +242,49 @@ def chatbot(image, question):
155
  yield "Please upload an image."
156
  return
157
  image_state = compute_image_state(image)
158
- input_text = generate_prompt(question)
159
  for output in generate(input_text, image_state):
160
  yield output
161
 
 
 
162
  with gr.Blocks(title=title) as demo:
163
- with gr.Row():
164
- with gr.Column():
165
- image = gr.Image(type='pil', label="Image")
166
- with gr.Column():
167
- prompt = gr.Textbox(lines=8, label="Prompt",
168
- value="Render a clear and concise summary of the photo.")
169
- with gr.Row():
170
- submit = gr.Button("Submit", variant="primary")
171
- clear = gr.Button("Clear", variant="secondary")
172
- with gr.Column():
173
- output = gr.Textbox(label="Output", lines=10)
174
- data = gr.Dataset(components=[image, prompt], samples=examples, label="Examples", headers=["Image", "Prompt"])
175
- submit.click(chatbot, [image, prompt], [output])
176
- clear.click(lambda: None, [], [output])
177
- data.click(lambda x: x, [data], [image, prompt])
178
-
179
- demo.queue(max_size=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  demo.launch(share=False)
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
  ctx_limit = 3500
22
+ ########################## text rwkv ################################################################
23
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
24
+ from rwkv.model import RWKV
25
+ title = "RWKV-5-World-1B5-v2-20231025-ctx4096"
26
+ model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.pth")
27
+ model = RWKV(model=model_path, strategy='cuda fp16')
28
+ pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
29
+
30
+ def generate_prompt(instruction, input=""):
31
+ instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
32
+ input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
33
+ if input:
34
+ return f"""Instruction: {instruction}
35
+ Input: {input}
36
+ Response:"""
37
+ else:
38
+ return f"""User: hi
39
+ Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
40
+ User: {instruction}
41
+ Assistant:"""
42
+
43
+ def evaluate(
44
+ ctx,
45
+ token_count=200,
46
+ temperature=1.0,
47
+ top_p=0.7,
48
+ presencePenalty = 0.1,
49
+ countPenalty = 0.1,
50
+ ):
51
+ args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
52
+ alpha_frequency = countPenalty,
53
+ alpha_presence = presencePenalty,
54
+ token_ban = [], # ban the generation of some tokens
55
+ token_stop = [0]) # stop generation whenever you see any token here
56
+ ctx = ctx.strip()
57
+ all_tokens = []
58
+ out_last = 0
59
+ out_str = ''
60
+ occurrence = {}
61
+ state = None
62
+ for i in range(int(token_count)):
63
+ out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
64
+ for n in occurrence:
65
+ out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
66
+
67
+ token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
68
+ if token in args.token_stop:
69
+ break
70
+ all_tokens += [token]
71
+ for xxx in occurrence:
72
+ occurrence[xxx] *= 0.996
73
+ if token not in occurrence:
74
+ occurrence[token] = 1
75
+ else:
76
+ occurrence[token] += 1
77
+
78
+ tmp = pipeline.decode(all_tokens[out_last:])
79
+ if '\ufffd' not in tmp:
80
+ out_str += tmp
81
+ yield out_str.strip()
82
+ out_last = i + 1
83
+
84
+ gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
85
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
86
+ print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
87
+ del out
88
+ del state
89
+ gc.collect()
90
+ torch.cuda.empty_cache()
91
+ yield out_str.strip()
92
+
93
+ examples = [
94
+ ["Assistant: Sure! Here is a very detailed plan to create flying pigs:", 333, 1, 0.3, 0, 1],
95
+ ["Assistant: Sure! Here are some ideas for FTL drive:", 333, 1, 0.3, 0, 1],
96
+ ["A few light taps upon the pane made her turn to the window. It had begun to snow again.", 333, 1, 0.3, 0, 1],
97
+ [generate_prompt("Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires."), 333, 1, 0.3, 0, 1],
98
+ [generate_prompt("東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。"), 333, 1, 0.3, 0, 1],
99
+ [generate_prompt("Write a story using the following information.", "A man named Alex chops a tree down."), 333, 1, 0.3, 0, 1],
100
+ ["Assistant: Here is a very detailed plan to kill all mosquitoes:", 333, 1, 0.3, 0, 1],
101
+ ['''Edward: I am Edward Elric from fullmetal alchemist. I am in the world of full metal alchemist and know nothing of the real world.
102
+ Player: Hello Edward. What have you been up to recently?
103
+ Edward:''', 333, 1, 0.3, 0, 1],
104
+ [generate_prompt("写一篇关于水利工程的流体力学模型的论文,需要详细全面。"), 333, 1, 0.3, 0, 1],
105
+ ['''“当然可以,大宇宙不会因为这五公斤就不坍缩了。”关一帆说,他还有一个没说出来的想法:也许大宇宙真的会因为相差一个原子的质量而由封闭转为开放。大自然的精巧有时超出想象,比如生命的诞生,就需要各项宇宙参数在几亿亿分之一精度上的精确配合。但程心仍然可以留下她的生态球,因为在那无数文明创造的无数小宇宙中,肯定有相当一部分不响应回归运动的号召,所以,大宇宙最终被夺走的质量至少有几亿吨,甚至可能是几亿亿亿吨。
106
+ 但愿大宇宙能够忽略这个误差。
107
+ 程心和关一帆进入了飞船,智子最后也进来了。她早就不再穿那身华丽的和服了,她现在身着迷彩服,再次成为一名轻捷精悍的战士,她的身上佩带着许多武器和生存装备,最引人注目的是那把插在背后的武士刀。
108
+ “放心,我在,你们就在!”智子对两位人类朋友说。
109
+ 聚变发动机启动了,推进器发出幽幽的蓝光,飞船缓缓地穿过了宇宙之门。
110
+ 小宇宙中只剩下漂流瓶和生态球。漂流瓶隐没于黑暗里,在一千米见方的宇宙中,只有生态球里的小太阳发出一点光芒。在这个小小的生命世界中,几只清澈的水球在零重力环境中静静地飘浮着,有一条小鱼从一只水球中蹦出,跃入另一只水球,轻盈地穿游于绿藻之间。在一小块陆地上的草丛中,有一滴露珠从一片草叶上脱离,旋转着飘起,向太空中折射出一缕晶莹的阳光。''', 333, 1, 0.3, 0, 1],
111
+ ]
112
+
113
+ ########################## visual rwkv ################################################################
114
+ visual_title = 'ViusualRWKV-v5'
115
  rwkv_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_rwkv.pth"
116
  vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
117
  vision_tower_name = 'openai/clip-vit-large-patch14-336'
118
 
119
  model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
120
+ visual_rwkv = RWKV(model=model_path, strategy='cuda fp16')
 
 
121
 
122
  ##########################################################################
123
  from modeling_vision import VisionEncoder, VisionEncoderConfig
 
131
  image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
132
  visual_encoder = visual_encoder.to(device)
133
  ##########################################################################
134
+ def visual_generate_prompt(instruction):
135
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
136
  return f"\n{instruction}\n\nAssistant:"
137
 
 
157
  for i in range(int(token_count)):
158
  if i == 0:
159
  input_ids = pipeline.encode(ctx)[-ctx_limit:]
160
+ out, state = visual_rwkv.forward(tokens=input_ids, state=image_state)
161
  else:
162
  input_ids = [token]
163
+ out, state = visual_rwkv.forward(tokens=input_ids, state=state)
164
  for n in occurrence:
165
  out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
166
 
 
193
 
194
  ##########################################################################
195
  cur_dir = os.path.dirname(os.path.abspath(__file__))
196
+ visual_examples = [
197
  [
198
  f"{cur_dir}/examples_pizza.jpg",
199
  "What are steps to cook it?"
 
221
  return base64_image
222
 
223
  image_cache = {}
 
 
224
  def compute_image_state(image):
225
  base64_image = pil_image_to_base64(image)
226
  if base64_image in image_cache:
227
  image_state = image_cache[base64_image]
228
  else:
229
+ image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values'].to(device)
 
230
  image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
231
  # apply layer norm to image feature, very important
232
  image_features = F.layer_norm(image_features,
233
  (image_features.shape[-1],),
234
+ weight=model.w['blocks.0.ln0.weight'],
235
+ bias=model.w['blocks.0.ln0.bias'])
236
+ _, image_state = visual_rwkv.forward(embs=image_features, state=None)
237
  image_cache[base64_image] = image_state
238
  return image_state
239
 
 
242
  yield "Please upload an image."
243
  return
244
  image_state = compute_image_state(image)
245
+ input_text = visual_generate_prompt(question)
246
  for output in generate(input_text, image_state):
247
  yield output
248
 
249
+
250
+ ##################################################################################################################
251
  with gr.Blocks(title=title) as demo:
252
+ gr.HTML(f"<div style=\"text-align: center;\">\n<h1>RWKV-5 World v2 - {title}</h1>\n</div>")
253
+ with gr.Tab("Raw Generation"):
254
+ gr.Markdown(f"This is [RWKV-5 World v2](https://huggingface.co/BlinkDL/rwkv-5-world) with 1.5B params - a 100% attention-free RNN [RWKV-LM](https://github.com/BlinkDL/RWKV-LM). Supports all 100+ world languages and code. And we have [200+ Github RWKV projects](https://github.com/search?o=desc&p=1&q=rwkv&s=updated&type=Repositories). *** Please try examples first (bottom of page) *** (edit them to use your question). Demo limited to ctxlen {ctx_limit}.")
255
+ with gr.Row():
256
+ with gr.Column():
257
+ prompt = gr.Textbox(lines=2, label="Prompt", value="Assistant: Sure! Here is a very detailed plan to create flying pigs:")
258
+ token_count = gr.Slider(10, 333, label="Max Tokens", step=10, value=333)
259
+ temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
260
+ top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.3)
261
+ presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0)
262
+ count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=1)
263
+ with gr.Column():
264
+ with gr.Row():
265
+ submit = gr.Button("Submit", variant="primary")
266
+ clear = gr.Button("Clear", variant="secondary")
267
+ output = gr.Textbox(label="Output", lines=5)
268
+ data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Instructions", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
269
+ submit.click(evaluate, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
270
+ clear.click(lambda: None, [], [output])
271
+ data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
272
+ with gr.Tab("Visual RWKV"):
273
+ with gr.Row():
274
+ with gr.Column():
275
+ image = gr.Image(type='pil', label="Image")
276
+ with gr.Column():
277
+ prompt = gr.Textbox(lines=8, label="Prompt",
278
+ value="Render a clear and concise summary of the photo.")
279
+ with gr.Row():
280
+ submit = gr.Button("Submit", variant="primary")
281
+ clear = gr.Button("Clear", variant="secondary")
282
+ with gr.Column():
283
+ output = gr.Textbox(label="Output", lines=10)
284
+ data = gr.Dataset(components=[image, prompt], samples=visual_examples, label="Examples", headers=["Image", "Prompt"])
285
+ submit.click(chatbot, [image, prompt], [output])
286
+ clear.click(lambda: None, [], [output])
287
+ data.click(lambda x: x, [data], [image, prompt])
288
+
289
+ demo.queue(concurrency_count=1, max_size=10)
290
  demo.launch(share=False)