import json, time, random, os
import numpy as np
import torch
from torch.nn import functional as F


class PIPELINE_ARGS():
    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, typical_p=1, alpha_frequency=0.2, alpha_presence=0.2, token_ban=None, token_stop=None, chunk_len=256):
        self.temperature = temperature          # softmax temperature, applied after top-p/top-k filtering
        self.top_p = top_p                      # nucleus sampling: keep the smallest token set whose cumulative probability exceeds top_p
        self.top_k = top_k                      # keep only the top_k most probable tokens (0 disables)
        self.typical_p = typical_p              # typical-sampling mass threshold (1 disables)
        self.alpha_frequency = alpha_frequency  # frequency penalty, scaled by each token's occurrence count
        self.alpha_presence = alpha_presence    # presence penalty, applied once per already-seen token
        # None defaults avoid the mutable-default bug of sharing one list across instances.
        self.token_ban = [] if token_ban is None else token_ban    # token ids that may never be sampled
        self.token_stop = [] if token_stop is None else token_stop # token ids that end generation
        self.chunk_len = chunk_len              # feed long prompts to the model in chunks of this length
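
# A minimal sketch of constructing sampling arguments (these particular values
# are illustrative, not defaults recommended by this file):
#
#     args = PIPELINE_ARGS(temperature=1.2, top_p=0.5, top_k=100,
#                          alpha_frequency=0.4, alpha_presence=0.4,
#                          token_ban=[0], token_stop=[])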


class PIPELINE():
    def __init__(self, model, WORD_NAME):
        self.model = model
        if WORD_NAME == 'cl100k_base':
            # OpenAI's cl100k_base vocabulary via tiktoken.
            import tiktoken
            self.tokenizer = tiktoken.get_encoding(WORD_NAME)
        else:
            # Otherwise WORD_NAME is a path to a HuggingFace tokenizers JSON file.
            from tokenizers import Tokenizer
            self.tokenizer = Tokenizer.from_file(WORD_NAME)

    def refine_context(self, context):
        # Normalize a prompt: trim whitespace (including ideographic spaces and
        # stray '\r') from every line, drop empty lines, and prepend a newline.
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        # Empty input already yields '\n' at this point, so no extra fallback is needed.
        return context
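
    # For example (illustrative input): refine_context('  hi \r\n\n  world  ')
    # returns '\nhi\nworld' -- lines trimmed, the empty line dropped, and a
    # single newline prepended.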

    def encode(self, x):
        # tiktoken returns a plain list of ids, while HuggingFace tokenizers
        # returns an Encoding object whose ids live in the .ids attribute.
        if 'tiktoken' in str(type(self.tokenizer)):
            return self.tokenizer.encode(x)
        else:
            return self.tokenizer.encode(x).ids

    def decode(self, x):
        return self.tokenizer.decode(x)
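
    # Round-trip sketch (token ids vary by tokenizer, shown only as an
    # illustration): pipeline.decode(pipeline.encode('Hello')) == 'Hello'.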

    def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0, typical_p=1):
        probs = F.softmax(logits.float(), dim=-1)
        top_k = int(top_k)
        if typical_p < 1:
            # Typical sampling keeps tokens whose surprisal (-log p) is close to
            # the distribution's entropy. The scores must be computed from
            # -log(probs), not from the raw logits.
            log_probs = -torch.log(probs)
            entropy = torch.nansum(log_probs * probs, dim=-1, keepdim=True)
            typical_scores = torch.abs(log_probs - entropy)
            typical_sorted_ids = torch.argsort(typical_scores)
            sorted_typical_scores = typical_scores[typical_sorted_ids]
            typical_sorted_probs = probs[typical_sorted_ids]
            cum_typical_sorted_probs = torch.cumsum(typical_sorted_probs, dim=-1).cpu().numpy()
            typical_cutoff = float(sorted_typical_scores[np.argmax(cum_typical_sorted_probs > typical_p)])
        if probs.device.type == 'cpu':
            probs = probs.numpy()
            sorted_ids = np.argsort(probs)
            sorted_probs = probs[sorted_ids][::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            # Nucleus (top-p) cutoff: the smallest probability still inside the nucleus.
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0  # zero everything outside the top_k
            if typical_p < 1:
                probs[typical_scores.numpy() > typical_cutoff] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            # Renormalize unconditionally: np.random.choice requires p to sum to 1.
            probs = probs / np.sum(probs)
            out = np.random.choice(a=len(probs), p=probs)
            return int(out)
        else:
            sorted_ids = torch.argsort(probs)
            sorted_probs = torch.flip(probs[sorted_ids], dims=(0,))
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if top_k < len(probs) and top_k > 0:
                probs[sorted_ids[:-top_k]] = 0
            if typical_p < 1:
                probs[typical_scores > typical_cutoff] = 0
            if temperature != 1.0:
                probs = probs ** (1.0 / temperature)
            # torch.multinomial renormalizes internally, so no explicit division here.
            out = torch.multinomial(probs, num_samples=1)[0]
            return int(out)
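
    # Worked example of the top-p filter above (illustrative numbers): with
    # probs = [0.5, 0.3, 0.15, 0.05] and top_p = 0.85, the descending cumulative
    # sums are [0.5, 0.8, 0.95, 1.0], so the cutoff is the third probability
    # (0.15) and only the 0.05 token is zeroed out before sampling.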

    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        for i in range(token_count):
            # Feed the full prompt on the first step, then one sampled token at
            # a time; long inputs are forwarded in chunks of args.chunk_len.
            tokens = self.encode(ctx) if i == 0 else [token]
            while len(tokens) > 0:
                out, state = self.model.forward(tokens[:args.chunk_len], state)
                tokens = tokens[args.chunk_len:]

            # Ban tokens outright, then penalize tokens that have already been
            # generated: a flat presence penalty plus a count-scaled frequency penalty.
            for n in args.token_ban:
                out[n] = -float('inf')
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p)
            if token in args.token_stop:
                break
            all_tokens += [token]
            occurrence[token] = occurrence.get(token, 0) + 1

            # Only emit text once it decodes cleanly: '\ufffd' marks an
            # incomplete multi-byte UTF-8 sequence that needs more tokens.
            tmp = self.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                if callback:
                    callback(tmp)
                out_str += tmp
                out_last = i + 1
        return out_str
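
# A minimal end-to-end sketch (hypothetical model path; assumes the `rwkv`
# package, which this file does not import):
#
#     from rwkv.model import RWKV
#     model = RWKV(model='path/to/rwkv-model.pth', strategy='cpu fp32')
#     pipeline = PIPELINE(model, 'cl100k_base')   # or a tokenizers JSON file
#     args = PIPELINE_ARGS(temperature=1.0, top_p=0.85)
#     print(pipeline.generate('\nHello', token_count=50, args=args))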