Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# This software may be used and distributed according to the terms of the GNU General Public License version 3. | |
from typing import List | |
import torch | |
from llama.tokenizer import Tokenizer | |
from llama.model import Transformer | |
class LLaMA:
    """Batch text generator wrapping a `Transformer` and its `Tokenizer`.

    An optional `vision_model` converts images into tokens that are handed to
    the transformer together with the text (see `generate`).
    """

    def __init__(self, model: Transformer, tokenizer: Tokenizer, vision_model=None):
        self.model = model
        self.tokenizer = tokenizer
        # Optional image encoder; only used by `generate` when `imgs` is given.
        self.vision_model = vision_model

    def generate(
        self,
        prompts: List[str],
        imgs=None,
        max_gen_len: int = 512,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        """Generate a continuation for every prompt in one batch.

        Args:
            prompts: Text prompts. The batch size is ``len(prompts)`` and must
                not exceed ``self.model.params.max_batch_size``.
            imgs: Optional image input, forwarded to ``self.vision_model``.
                It is only used when a vision model was supplied. In that case
                the resulting tokens are passed to the transformer and the
                transformer is called in ``'caption'`` mode instead of
                ``'instruct'``.
            max_gen_len: Maximum number of new tokens per prompt. The total
                sequence length is additionally capped by
                ``params.max_seq_len``.
            temperature: Softmax temperature. A value ``<= 0`` selects greedy
                (argmax) decoding.
            top_p: Nucleus-sampling threshold, used when ``temperature > 0``.

        Returns:
            One decoded string per prompt. It contains only the generated text
            (prompt tokens are excluded) and is truncated at the first EOS
            token. An empty ``prompts`` list yields ``[]``.

        Raises:
            AssertionError: If the batch is larger than
                ``params.max_batch_size``.
        """
        bsz = len(prompts)
        if bsz == 0:
            # Nothing to do; the original code crashed on min() of an empty list.
            return []
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        mode = 'instruct'
        vision_tokens = None
        if imgs is not None and self.vision_model is not None:
            # NOTE(review): no no_grad/inference_mode context is visible here.
            # If the vision model's parameters require grad, autograd state is
            # retained for this call. Confirm this is intended.
            vision_tokens = self.vision_model(imgs)
            mode = 'caption'

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        prompt_lens = [len(t) for t in prompt_tokens]
        min_prompt_size = min(prompt_lens)
        max_prompt_size = max(prompt_lens)
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        # One (bsz, total_len) buffer, right-padded with pad_id after each prompt.
        # NOTE(review): a prompt longer than total_len (i.e. longer than
        # max_seq_len) fails on the slice assignment below.
        tokens = torch.full(
            (bsz, total_len), self.tokenizer.pad_id, dtype=torch.long, device="cuda"
        )
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
        input_text_mask = tokens != self.tokenizer.pad_id
        # Per row: has a *generated* (non-prompt) position produced EOS yet?
        eos_reached = torch.zeros(bsz, dtype=torch.bool, device=tokens.device)

        # Decoding starts at the shortest prompt. Rows with longer prompts keep
        # their own prompt tokens until they run out (see torch.where below).
        start_pos = min_prompt_size
        prev_pos = 0
        for cur_pos in range(start_pos, total_len):
            # Only the tokens since the previous step are fed in, together with
            # their start offset. This presumably relies on a cache inside the
            # model; TODO confirm against Transformer.forward.
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos, vision_tokens, mode)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1)
            # Only replace the token if that row's prompt has already been consumed.
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
            )
            # Once every row has generated EOS, later positions can't change the
            # decoded output: decoding cuts at the first EOS, or at the
            # max_gen_len slice, which then lies entirely before this point.
            if bool(eos_reached.all()):
                break

        decoded = []
        for i, t in enumerate(tokens.tolist()):
            # Keep only the generated part, at most max_gen_len tokens.
            t = t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len]
            # Cut at the first EOS token, if any.
            try:
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))
        return decoded
def sample_top_p(probs, p):
    """Pick one token id per row with nucleus (top-p) sampling.

    Candidates are ranked by probability. A candidate is discarded when the
    probability mass ranked strictly above it already exceeds ``p``. The
    survivors are renormalised and one of them is drawn with
    ``torch.multinomial``.

    Args:
        probs: Probability rows, e.g. shape (batch, vocab) as produced by a
            softmax over the last dimension.
        p: Cumulative-probability threshold.

    Returns:
        Token indices into the last dimension of ``probs``. For a 2-D input
        of shape (batch, vocab) the result has shape (batch, 1).
    """
    sorted_probs, sorted_ids = torch.sort(probs, dim=-1, descending=True)
    # Mass accumulated before each candidate. For the top candidate this is 0,
    # so with p >= 0 it is never discarded.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    kept = sorted_probs.masked_fill(mass_before > p, 0.0)
    kept = kept / kept.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(kept, num_samples=1)
    # Map positions in the sorted order back to vocabulary ids.
    return sorted_ids.gather(-1, choice)