DaniilAlpha committed on
Commit
91c3b11
·
1 Parent(s): 05dc8f9

Update answerer.py

Browse files
Files changed (1) hide show
  1. answerer.py +27 -32
answerer.py CHANGED
@@ -5,19 +5,14 @@ from rwkv.model import RWKV
5
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
6
 
7
  class Answerer:
8
- def __init__(self, repo: str, filename: str, vocab: str, strategy: str, ctx_limit: int):
9
  os.environ["RWKV_JIT_ON"] = "1"
10
  # os.environ["RWKV_CUDA_ON"] = "1"
11
 
12
- self.__model = RWKV(hf_hub_download(repo, filename), strategy=strategy)
13
  self.__pipeline = PIPELINE(self.__model, vocab)
14
  self.ctx_limit = ctx_limit
15
 
16
- __model: RWKV
17
- __pipeline: PIPELINE
18
-
19
- ctx_limit: int
20
-
21
  def __call__(
22
  self,
23
  input: str,
@@ -45,35 +40,35 @@ class Answerer:
45
  current_token = None
46
  state = None
47
  for _ in range(max_output_length_tk):
48
- out, state = self.__model.forward(
49
- [current_token] if current_token else self.__pipeline.encode(input)[-self.ctx_limit:],
50
- state,
51
- )
52
- for token in occurrences:
53
- out[token] -= args.alpha_presence + occurrences[token] * args.alpha_frequency
54
 
55
- current_token = self.__pipeline.sample_logits(
56
- out,
57
- temperature=args.temperature,
58
- top_p=args.top_p,
59
- )
60
- if current_token in args.token_stop: break
61
 
62
- tokens.append(current_token)
63
 
64
- for token in occurrences:
65
- occurrences[token] *= 0.996
66
 
67
- if current_token in occurrences:
68
- occurrences[current_token] += 1
69
- else:
70
- occurrences[current_token] = 1
71
-
72
- tmp = self.__pipeline.decode(tokens)
73
- if "\ufffd" not in tmp:
74
- tokens.clear()
75
- result += tmp
76
- yield result.strip()
77
 
78
  tokens.clear()
79
  occurrences.clear()
 
5
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
6
 
7
  class Answerer:
8
+ def __init__(self, repo: str, model: str, vocab: str, strategy: str, ctx_limit: int):
9
  os.environ["RWKV_JIT_ON"] = "1"
10
  # os.environ["RWKV_CUDA_ON"] = "1"
11
 
12
+ self.__model = RWKV(hf_hub_download(repo, f"{model}.pth"), strategy=strategy)
13
  self.__pipeline = PIPELINE(self.__model, vocab)
14
  self.ctx_limit = ctx_limit
15
 
 
 
 
 
 
16
  def __call__(
17
  self,
18
  input: str,
 
40
  current_token = None
41
  state = None
42
  for _ in range(max_output_length_tk):
43
+ out, state = self.__model.forward(
44
+ [current_token] if current_token else self.__pipeline.encode(input)[-self.ctx_limit:],
45
+ state,
46
+ )
47
+ for token in occurrences:
48
+ out[token] -= args.alpha_presence + occurrences[token] * args.alpha_frequency
49
 
50
+ current_token = self.__pipeline.sample_logits(
51
+ out,
52
+ temperature=args.temperature,
53
+ top_p=args.top_p,
54
+ )
55
+ if current_token in args.token_stop: break
56
 
57
+ tokens.append(current_token)
58
 
59
+ for token in occurrences:
60
+ occurrences[token] *= 0.996
61
 
62
+ if current_token in occurrences:
63
+ occurrences[current_token] += 1
64
+ else:
65
+ occurrences[current_token] = 1
66
+
67
+ tmp = self.__pipeline.decode(tokens)
68
+ if "\ufffd" not in tmp:
69
+ tokens.clear()
70
+ result += tmp
71
+ yield result.strip()
72
 
73
  tokens.clear()
74
  occurrences.clear()