Tonic committed on
Commit
8c00703
·
verified ·
1 Parent(s): fb15fc9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -49
app.py CHANGED
@@ -7,9 +7,6 @@ from copy import deepcopy
7
  import requests
8
  import os.path
9
  from tqdm import tqdm
10
- import json
11
- from dataclasses import dataclass
12
- from typing import Optional, List
13
 
14
  # Set environment variables
15
  os.environ['RWKV_JIT_ON'] = '1'
@@ -22,28 +19,9 @@ MODELS = {
22
  "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
23
  }
24
 
25
- # Model configurations
26
- MODEL_CONFIGS = {
27
- "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth": {
28
- "n_layer": 12,
29
- "n_embd": 768,
30
- "ctx_len": 4096
31
- },
32
- "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth": {
33
- "n_layer": 24,
34
- "n_embd": 1024,
35
- "ctx_len": 4096
36
- }
37
- }
38
-
39
- @dataclass
40
- class ModelArgs:
41
- n_layer: int
42
- n_embd: int
43
- ctx_len: int
44
- vocab_size: int = 65536
45
- n_head: int = 16 # Number of attention heads
46
- n_att: int = 1024 # Attention dimension
47
 
48
  def download_file(url, filename):
49
  """Generic file downloader with progress bar"""
@@ -69,44 +47,28 @@ def download_model(model_name):
69
  url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
70
  download_file(url, model_name)
71
 
72
- class CustomPipeline(PIPELINE):
73
- def __init__(self, model, vocab_file):
74
- super().__init__(model, vocab_file)
75
- self.model_args = None
76
-
77
- def set_model_args(self, args: ModelArgs):
78
- self.model_args = args
79
 
80
  class ModelManager:
81
  def __init__(self):
82
  self.current_model = None
83
  self.current_model_name = None
84
  self.pipeline = None
 
85
 
86
  def load_model(self, model_choice):
87
  model_file = MODELS[model_choice]
88
  if model_file != self.current_model_name:
89
  download_model(model_file)
90
-
91
- # Get model configuration
92
- config = MODEL_CONFIGS[model_file]
93
- model_args = ModelArgs(
94
- n_layer=config['n_layer'],
95
- n_embd=config['n_embd'],
96
- ctx_len=config['ctx_len']
97
- )
98
-
99
- # Initialize model with args
100
  self.current_model = RWKV(
101
  model=model_file,
102
  strategy='cpu fp32'
103
  )
104
-
105
- # Initialize custom pipeline
106
- self.pipeline = CustomPipeline(self.current_model, "20B_tokenizer.json")
107
- self.pipeline.set_model_args(model_args)
108
  self.current_model_name = model_file
109
-
110
  return self.pipeline
111
 
112
  model_manager = ModelManager()
@@ -143,8 +105,7 @@ def generate_response(
143
  alpha_decay=alpha_decay,
144
  token_ban=[],
145
  token_stop=[],
146
- chunk_len=256,
147
- model_args=pipeline.model_args # Pass model args to pipeline
148
  )
149
 
150
  # Generate response
 
7
  import requests
8
  import os.path
9
  from tqdm import tqdm
 
 
 
10
 
11
  # Set environment variables
12
  os.environ['RWKV_JIT_ON'] = '1'
 
19
  "0.4B (Larger)": "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
20
  }
21
 
22
+ # Download vocab file if not present
23
+ VOCAB_FILE = "rwkv_vocab_v20230424.txt"
24
+ VOCAB_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/v2/rwkv_vocab_v20230424.txt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def download_file(url, filename):
27
  """Generic file downloader with progress bar"""
 
47
  url = f"https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/{model_name}"
48
  download_file(url, model_name)
49
 
50
+ def ensure_vocab():
51
+ """Ensure vocab file is present"""
52
+ if not os.path.exists(VOCAB_FILE):
53
+ download_file(VOCAB_URL, VOCAB_FILE)
 
 
 
54
 
55
  class ModelManager:
56
  def __init__(self):
57
  self.current_model = None
58
  self.current_model_name = None
59
  self.pipeline = None
60
+ ensure_vocab()
61
 
62
  def load_model(self, model_choice):
63
  model_file = MODELS[model_choice]
64
  if model_file != self.current_model_name:
65
  download_model(model_file)
 
 
 
 
 
 
 
 
 
 
66
  self.current_model = RWKV(
67
  model=model_file,
68
  strategy='cpu fp32'
69
  )
70
+ self.pipeline = PIPELINE(self.current_model, VOCAB_FILE)
 
 
 
71
  self.current_model_name = model_file
 
72
  return self.pipeline
73
 
74
  model_manager = ModelManager()
 
105
  alpha_decay=alpha_decay,
106
  token_ban=[],
107
  token_stop=[],
108
+ chunk_len=256
 
109
  )
110
 
111
  # Generate response