# RWKV text-generation demo (HuggingFace Space)
from typing import Dict, Generator, List
import os, gc
from huggingface_hub import hf_hub_download
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
### settings ###
###
# Enable TorchScript JIT compilation of the RWKV model (must be set
# before rwkv.model is first used).
os.environ["RWKV_JIT_ON"] = "1"
# os.environ["RWKV_CUDA_ON"] = "1"  # if "1" then use CUDA kernel for seq mode (much faster)
class Answerer:
    """Streams text completions from a HuggingFace-hosted RWKV checkpoint.

    ``__call__`` is a generator: it yields the accumulated output string
    after every chunk that decodes to complete UTF-8, so callers can render
    partial output as it is produced.
    """

    # Assigned in __init__; annotated at class level for readability.
    __model: RWKV
    __pipeline: PIPELINE
    ctx_limit: int  # prompt is truncated to its last `ctx_limit` tokens

    def __init__(self, repo: str, filename: str, vocab: str, strategy: str, ctx_limit: int):
        """Download ``filename`` from the HF repo ``repo`` and load the model.

        vocab:     tokenizer vocabulary name handed to rwkv's PIPELINE
        strategy:  rwkv device/precision strategy string (e.g. "cpu fp32")
        ctx_limit: maximum number of prompt tokens fed to the model
        """
        os.environ["RWKV_JIT_ON"] = "1"
        # os.environ["RWKV_CUDA_ON"] = "1"  # CUDA kernel for seq mode (much faster)
        self.__model = RWKV(hf_hub_download(repo, filename), strategy=strategy)
        self.__pipeline = PIPELINE(self.__model, vocab)
        self.ctx_limit = ctx_limit

    def __call__(
        self,
        input: str,  # NOTE: shadows the builtin; name kept for caller compatibility
        max_output_length_tk: int,
        chaos = .1,
        repetitiveness = .3,
        diversity = 0,
        _count_penalty = 1,
    ) -> Generator[str, None, None]:
        """Yield progressively longer completions of ``input``.

        chaos:           sampling temperature
        repetitiveness:  top-p nucleus threshold
        diversity:       presence penalty (alpha_presence)
        _count_penalty:  frequency penalty (alpha_frequency)

        Yields the stripped accumulated text after each complete-UTF-8
        chunk, and once more after generation finishes.
        """
        args = PIPELINE_ARGS(
            temperature=chaos,
            top_p=repetitiveness,
            alpha_frequency=_count_penalty,
            alpha_presence=diversity,
            token_ban = [],
            token_stop = [0],  # token id 0 terminates generation
        )

        input = input.strip()
        result: str = ""
        occurrences: Dict[int, int] = {}
        tokens: List[int] = []
        current_token = None
        state = None

        for _ in range(max_output_length_tk):
            # First step feeds the (truncated) prompt; subsequent steps feed
            # only the last sampled token.  Compare against None explicitly —
            # the original truthiness test would misread token id 0 as
            # "no token sampled yet".
            out, state = self.__model.forward(
                self.__pipeline.encode(input)[-self.ctx_limit:] if current_token is None else [current_token],
                state,
            )

            # Apply presence + frequency penalties to already-emitted tokens.
            for token in occurrences:
                out[token] -= args.alpha_presence + occurrences[token] * args.alpha_frequency

            current_token = self.__pipeline.sample_logits(
                out,
                temperature=args.temperature,
                top_p=args.top_p,
            )
            if current_token in args.token_stop:
                break
            tokens.append(current_token)

            # Decay the penalty counts so old repetitions matter less over time.
            for token in occurrences:
                occurrences[token] *= 0.996
            occurrences[current_token] = occurrences.get(current_token, 0) + 1

            chunk = self.__pipeline.decode(tokens)
            if "\ufffd" not in chunk:  # emit only once the bytes form complete UTF-8
                tokens.clear()
                result += chunk
                yield result.strip()

        # Drop the working buffers and model state before the final yield.
        # (The original `del out, tmp` raised NameError whenever the loop
        # body never ran — e.g. max_output_length_tk == 0 — or broke on a
        # stop token before the first decode.)
        del occurrences, tokens, current_token, state
        gc.collect()
        yield result.strip()