# Snapshot of a Hugging Face Spaces app file (commit 5ee61de, 2,405 bytes).
from typing import Dict, Generator, List
import os, gc
from huggingface_hub import hf_hub_download
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
### settings ###
###
os.environ["RWKV_JIT_ON"] = "1"
# os.environ["RWKV_CUDA_ON"] = "1" # if "1" then use CUDA kernel for seq mode (much faster)
class Answerer:
    """Streaming text generator backed by an RWKV language model.

    Downloads model weights from the Hugging Face Hub, wraps them in the
    rwkv tokenizer/sampling pipeline, and — when called — yields the
    accumulated decoded text as tokens are sampled one by one.
    """

    __model: RWKV        # loaded RWKV model (name-mangled to _Answerer__model)
    __pipeline: PIPELINE # tokenizer + logits-sampling helper around __model
    ctx_limit: int       # prompt is truncated to its last ctx_limit tokens

    def __init__(self, repo: str, filename: str, vocab: str, strategy: str, ctx_limit: int):
        """Download *filename* from Hub repo *repo* and load the model.

        vocab:     vocabulary name passed to PIPELINE
        strategy:  rwkv device/precision strategy string (e.g. "cpu fp32")
        ctx_limit: maximum number of prompt tokens fed to the model
        """
        os.environ["RWKV_JIT_ON"] = "1"
        # os.environ["RWKV_CUDA_ON"] = "1"  # CUDA kernel for seq mode (much faster)
        self.__model = RWKV(hf_hub_download(repo, filename), strategy=strategy)
        self.__pipeline = PIPELINE(self.__model, vocab)
        self.ctx_limit = ctx_limit

    def __call__(
        self,
        input: str,
        max_output_length_tk: int,
        chaos: float = .1,           # sampling temperature
        repetitiveness: float = .3,  # top-p nucleus threshold
        diversity: float = 0,        # presence penalty (alpha_presence)
        _count_penalty: float = 1,   # frequency penalty (alpha_frequency)
    ) -> Generator[str, None, None]:
        """Generate up to *max_output_length_tk* tokens for *input*.

        Yields the running, stripped output string each time a complete
        (replacement-character-free) chunk decodes, and yields the final
        text once more before finishing.  Token id 0 stops generation.
        """
        args = PIPELINE_ARGS(
            temperature=chaos,
            top_p=repetitiveness,
            alpha_frequency=_count_penalty,
            alpha_presence=diversity,
            token_ban = [],
            token_stop = [0],  # token 0 terminates the stream
        )

        input = input.strip()

        result: str = ""
        occurrences: Dict[int, int] = {}  # token id -> decayed usage count
        tokens: List[int] = []            # not-yet-decoded token buffer
        current_token = None
        state = None
        for _ in range(max_output_length_tk):
            # First step feeds the (truncated) prompt; subsequent steps feed
            # only the last sampled token, carrying the recurrent state.
            # BUGFIX: explicit None check — `if current_token` would treat a
            # sampled token id of 0 as "no token yet" and re-feed the prompt.
            out, state = self.__model.forward(
                [current_token] if current_token is not None else self.__pipeline.encode(input)[-self.ctx_limit:],
                state,
            )
            # Apply presence + frequency penalties to already-seen tokens.
            for token in occurrences:
                out[token] -= args.alpha_presence + occurrences[token] * args.alpha_frequency
            current_token = self.__pipeline.sample_logits(
                out,
                temperature=args.temperature,
                top_p=args.top_p,
            )
            if current_token in args.token_stop: break
            tokens.append(current_token)
            # Exponentially decay old counts so the penalty fades with distance.
            for token in occurrences:
                occurrences[token] *= 0.996
            if current_token in occurrences:
                occurrences[current_token] += 1
            else:
                occurrences[current_token] = 1
            tmp = self.__pipeline.decode(tokens)
            if "\ufffd" not in tmp:  # only flush complete UTF-8 sequences
                tokens.clear()
                result += tmp
                yield result.strip()
        # Cleanup: release references eagerly (the generator object may be
        # kept alive by the caller).  NOTE: the original `del out, tmp` here
        # raised NameError when the loop never ran or stopped on the very
        # first sampled token, before `tmp` was bound — removed.
        tokens.clear()
        occurrences.clear()
        del occurrences, tokens, current_token, state
        gc.collect()
        yield result.strip()