import json, time, random, os

import numpy as np
import torch
from torch.nn import functional as F

time_slot = {}
time_ref = time.time_ns()

def record_time(name):
    # Store the smallest elapsed time (in seconds, relative to time_ref)
    # observed for each named checkpoint.
    if name not in time_slot:
        time_slot[name] = 1e20
    tt = (time.time_ns() - time_ref) / 1e9
    if tt < time_slot[name]:
        time_slot[name] = tt

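# A minimal usage sketch (illustrative only; the checkpoint names are made up):
# each call stores the smallest elapsed time since time_ref seen under that
# name, so repeated runs converge on the best-case timing of each checkpoint.
#
#     record_time('embed')      # hypothetical checkpoint name
#     record_time('head')       # hypothetical checkpoint name
#     print(time_slot)          # fastest observed times, in seconds
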
class TOKENIZER():
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        if 'list' in str(type(WORD_NAME)):
            # Word mode: WORD_NAME holds HuggingFace tokenizer files. If both
            # entries are identical, it is a single tokenizer_file; otherwise
            # a (vocab, merges) pair for GPT2TokenizerFast.
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast
                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast
                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            # Char mode: WORD_NAME is the stem of a UTF-16 JSON file mapping
            # token ids to characters.
            self.charMode = True
            with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        # Normalize a prompt: trim each line (including ideographic spaces and
        # stray '\r'), drop empty lines, and prefix the result with a newline.
        # An empty input therefore collapses to just '\n'.
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        return context

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        lastChar = int(x[-1])

        probs = F.softmax(out, dim=-1)

        # In char mode, a separate top-p applies right after a newline so that
        # paragraph breaks can be sampled more (or less) aggressively.
        if self.charMode and self.itos[lastChar] == '\n':
            top_p = top_p_newline
        else:
            top_p = top_p_usual

        if os.environ["RWKV_RUN_DEVICE"] == "cpu":
            probs = probs.numpy()
            sorted_probs = np.sort(probs)[::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            # Smallest probability whose inclusion pushes the cumulative mass past top_p.
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                probs = np.power(probs, 1.0 / temperature)  # ndarray has no .pow method
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return out
        else:
            sorted_probs = torch.sort(probs, descending=True)[0]
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                probs = probs.pow(1.0 / temperature)
            # torch.multinomial normalizes its weights internally, so no
            # explicit renormalization is needed here.
            out = torch.multinomial(probs, num_samples=1)[0]
            return out

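# A minimal usage sketch (illustrative only; 'vocab' and the model call are
# assumptions, not part of this file). Char mode loads 'vocab.json' (an
# id -> character table), and sample_logits draws the next token id from a
# model's output logits for the sequence x:
#
#     tokenizer = TOKENIZER('vocab')                        # char mode
#     context = tokenizer.refine_context('Hello\nworld')
#     x = [tokenizer.stoi.get(c, tokenizer.UNKNOWN_CHAR) for c in context]
#     # out = model(x)                                      # hypothetical logits
#     # token = tokenizer.sample_logits(out, x, ctx_len=1024, temperature=1.0,
#     #                                 top_p_usual=0.8, top_p_newline=0.9)
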
def MaybeIsPrime(number):
    # Probabilistic primality check: accept only if both tests agree.
    return FermatPrimalityTest(number) and MillerRabinPrimalityTest(number)

def FermatPrimalityTest(number):
    # Fermat test with 3 random bases: for prime n, a^(n-1) ≡ 1 (mod n)
    # holds for any base a not divisible by n.
    if number > 1:
        for _ in range(3):
            randomNumber = random.randint(2, number) - 1
            if pow(randomNumber, number - 1, number) != 1:
                return False
        return True
    else:
        return False

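# Note: Carmichael numbers (561 is the smallest) satisfy the Fermat condition
# for every base coprime to them, so the Fermat test alone can be fooled;
# MaybeIsPrime therefore also requires the Miller-Rabin test to pass.
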
def MillerRabinPrimalityTest(number):
    if number == 2:
        return True
    elif number == 1 or number % 2 == 0:
        return False

    # Factor number - 1 as 2**timesTwoDividNumber * oddPartOfNumber,
    # with oddPartOfNumber odd.
    oddPartOfNumber = number - 1
    timesTwoDividNumber = 0
    while oddPartOfNumber % 2 == 0:
        oddPartOfNumber = oddPartOfNumber // 2
        timesTwoDividNumber = timesTwoDividNumber + 1

    for _ in range(3):
        # Pick a random base in [2, number - 1].
        while True:
            randomNumber = random.randint(2, number) - 1
            if randomNumber != 0 and randomNumber != 1:
                break

        randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)

        # For a prime, the squaring chain must hit 1 immediately or pass
        # through number - 1 at some step.
        if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
            iterationNumber = 1

            while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
                randomNumberWithPower = pow(randomNumberWithPower, 2, number)
                iterationNumber = iterationNumber + 1
            if randomNumberWithPower != (number - 1):
                return False

    return True

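# A minimal usage sketch (illustrative only): with 3 random bases per test the
# verdict is probabilistic, but small inputs behave as expected in practice.
#
#     print(MaybeIsPrime(97))    # True  (97 is prime)
#     print(MaybeIsPrime(91))    # False (91 = 7 * 13)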