# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

from typing import Optional

import torch
from torch.nn import functional as F
from tqdm import tqdm

# Compact tensor printing for the debug output in the samplers below.
torch.set_printoptions(precision=2, threshold=10)

# Top-k filtering: mask every logit strictly below the k-th largest one to -inf,
# so softmax assigns those entries zero probability.
def cutoff_topk_logits(logits: torch.FloatTensor, k: Optional[int]) -> torch.FloatTensor:
    if k is None:
        return logits
    else:
        v, _ = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[:, [-1]]] = -float('Inf')
        return out

# Nucleus (top-p) filtering: zero out the tail of the distribution once the
# cumulative probability of the sorted entries reaches p, then renormalize.
def cutoff_topp_probs(probs: torch.FloatTensor, p: Optional[float]) -> torch.FloatTensor:
    if p is None:
        return probs
    else:
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)

        sorted_idx_remove_cond = cum_probs >= p
        # Shift right so the first entry crossing the threshold is still kept.
        sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
        sorted_idx_remove_cond[..., 0] = 0

        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
        probs = probs.masked_fill(indices_to_remove, 0.0)
        norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
        return norm_probs

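
# Illustrative sketch (not part of the original module): how the two cutoff
# helpers combine into top-k followed by nucleus filtering. The toy logits
# below are assumptions chosen purely for demonstration.
def _demo_cutoff_filters() -> None:
    logits = torch.tensor([[4.0, 3.0, 1.0, 0.5, -2.0]])  # one batch row, 5-way toy vocab
    kept = cutoff_topk_logits(logits, k=3)               # entries below the 3rd-largest logit -> -inf
    probs = F.softmax(kept, dim=-1)                      # masked entries get zero probability
    probs = cutoff_topp_probs(probs, p=0.9)              # drop the tail past 0.9 cumulative mass, renormalize
    idx = torch.multinomial(probs, num_samples=1)        # sample from the truncated distribution
    print(kept, probs, idx, sep='\n')
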
# Index grids used as positional-embedding inputs: '1d' returns one (B, N)
# range per row; '2d' returns a (row, column) pair of (B, H, W) grids.
def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
    device = inputs.device
    if mode == '1d':
        B, N = inputs.shape
        xs_pos = torch.arange(N, device=device).repeat((B, 1))
    elif mode == '2d':
        B, H, W = inputs.shape
        xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
        xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
        xs_pos = (xs_pos_h, xs_pos_w)  # note: a tuple of tensors, not a single tensor
    else:
        raise ValueError('%s positional encoding invalid' % mode)
    return xs_pos

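
# Illustrative sketch (assumption: shapes only, no model needed): '1d' mode
# yields one index grid of shape (B, N); '2d' mode yields a (row, column)
# pair of index grids, each of shape (B, H, W).
def _demo_positional_encoding() -> None:
    tokens = torch.zeros(2, 5, dtype=torch.long)   # B=2, N=5
    print(get_positional_encoding(tokens, mode='1d').shape)  # torch.Size([2, 5])
    grid = torch.zeros(2, 3, 4, dtype=torch.long)  # B=2, H=3, W=4
    pos_h, pos_w = get_positional_encoding(grid, mode='2d')
    print(pos_h.shape, pos_w.shape)                # torch.Size([2, 3, 4]) twice
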
# Autoregressive sampling of image codes given text tokens, with incremental
# decoding: each step feeds only the newest code token plus the cached
# key/value states (`past`) from all previous steps.
def sampling(model: torch.nn.Module,
             tokens: torch.LongTensor,
             top_k: Optional[int] = None,
             top_p: Optional[float] = None,
             softmax_temperature: float = 1.0,
             is_tqdm: bool = True,
             use_fp16: bool = True,
             max_seq_len: int = 256,
             prompt: Optional[torch.Tensor] = None,
             pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
    code = None
    past = None

    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')

    for cnt, h in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            code_ = code.clone().detach()
            pos_enc_code_ = get_positional_encoding(code_, mode='1d')
            # Only the most recently sampled token is fed; earlier ones live in `past`.
            code_ = code_[:, cnt-1].unsqueeze(-1)
            pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)

        logits, present = model.sampling(images=code_,
                                         texts=tokens,
                                         pos_images=pos_enc_code_,
                                         pos_texts=pos_enc_tokens,
                                         use_fp16=use_fp16,
                                         past=past,
                                         prompt=prompt,
                                         pos_prompt=pos_prompt)
        logits = logits.to(dtype=torch.float32)
        logits = logits / softmax_temperature

        present = torch.stack(present).clone().detach()
        if past is None:
            past = [present]
        else:
            past.append(present)

        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)

        idx = torch.multinomial(probs, num_samples=1).clone().detach()
        code = idx if code is None else torch.cat([code, idx], dim=1)

    del past
    return code

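
# Illustrative call (hedged): `model` is assumed to be this repo's DALL-E-style
# transformer exposing `model.sampling(...)`, and `tokens` the tokenized text
# prompt as a LongTensor of ids. With max_seq_len=256 the returned `code` holds
# 256 discrete image-code indices per sample, e.g. a 16x16 latent grid to be
# decoded into pixels by the VQ decoder.
#
#   code = sampling(model, tokens, top_k=256, softmax_temperature=1.0,
#                   max_seq_len=256, is_tqdm=True)
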
# Variant of `sampling` that starts from a precomputed key/value cache (`past`)
# for a prefix instead of building the cache from scratch. The per-step prints
# below are debugging output.
def sampling_prefix(model: torch.nn.Module,
                    tokens: torch.LongTensor,
                    past: torch.FloatTensor,
                    top_k: Optional[int] = None,
                    top_p: Optional[float] = None,
                    softmax_temperature: float = 1.0,
                    is_tqdm: bool = True,
                    use_fp16: bool = True,
                    max_seq_len: int = 256,
                    labels: Optional[torch.Tensor] = None) -> torch.LongTensor:
    code = None

    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')

    if past is not None:
        past = [past]

    for cnt, h in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            code_ = code.clone().detach()
            pos_enc_code_ = get_positional_encoding(code_, mode='1d')
            code_ = code_[:, cnt-1].unsqueeze(-1)
            pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)

        logits, present = model.sampling(images=code_,
                                         texts=tokens,
                                         pos_images=pos_enc_code_,
                                         pos_texts=pos_enc_tokens,
                                         use_fp16=use_fp16,
                                         past=past)
        logits = logits.to(dtype=torch.float32)
        logits = logits / softmax_temperature

        present = torch.stack(present).clone().detach()
        if past is None:
            past = [present]
        else:
            past.append(present)

        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)

        print(torch.topk(probs, 5, dim=-1))  # debug: inspect the 5 most likely codes
        if labels is not None:
            print(labels[cnt])  # debug: compare against the reference code

        idx = torch.multinomial(probs, num_samples=1).clone().detach()
        code = idx if code is None else torch.cat([code, idx], dim=1)

    del past
    return code

# NOTE: this variant appears unfinished: only the first step (cnt == 0) is
# implemented and later iterations fall through to `pass`, so at most one
# token is sampled and the prefix cache is never extended.
def sampling_prefix_new(model: torch.nn.Module,
                        tokens: torch.LongTensor,
                        past: torch.FloatTensor,
                        top_k: Optional[int] = None,
                        top_p: Optional[float] = None,
                        softmax_temperature: float = 1.0,
                        is_tqdm: bool = True,
                        use_fp16: bool = True,
                        max_seq_len: int = 256) -> torch.LongTensor:
    code = None

    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')

    if past is not None:
        past = [past]

    for cnt, h in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            code_ = code.clone().detach()
            pos_enc_code_ = get_positional_encoding(code_, mode='1d')

        if cnt == 0:
            logits, present = model.sampling(images=code_,
                                             texts=tokens,
                                             pos_images=pos_enc_code_,
                                             pos_texts=pos_enc_tokens,
                                             use_fp16=use_fp16,
                                             past=past)
            logits = logits.to(dtype=torch.float32)
            logits = logits / softmax_temperature

            present = torch.stack(present).clone().detach()
            if past is None:
                past = [present]

            logits = cutoff_topk_logits(logits, top_k)
            probs = F.softmax(logits, dim=-1)
            probs = cutoff_topp_probs(probs, top_p)

            idx = torch.multinomial(probs, num_samples=1).clone().detach()
            code = idx if code is None else torch.cat([code, idx], dim=1)
        else:
            pass

    del past
    return code

# Like `sampling`, but additionally conditions on a source image: the source
# codes are embedded once up front and attended to via cross-attention at the
# given layers inside `model.sampling_with_context`.
def sampling_conditional(model: torch.nn.Module,
                         cross_attention_idxs,
                         cross_attention_layers,
                         tokens: torch.LongTensor,
                         src_codes: torch.FloatTensor,
                         top_k: Optional[int] = None,
                         top_p: Optional[float] = None,
                         softmax_temperature: float = 1.0,
                         is_tqdm: bool = True,
                         use_fp16: bool = True,
                         max_seq_len: int = 256,
                         prompt: Optional[torch.Tensor] = None,
                         pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
    code = None
    past = None

    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')

    # Embed the source-image codes (token + positional embeddings) once.
    src_pos_tokens = get_positional_encoding(src_codes, mode='1d')
    src_tokens = model.tok_emb_img(src_codes)
    src_tokens = src_tokens + model.pos_emb_img(src_pos_tokens)

    for cnt, h in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            code_ = code.clone().detach()
            pos_enc_code_ = get_positional_encoding(code_, mode='1d')
            code_ = code_[:, cnt-1].unsqueeze(-1)
            pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)

        logits, present = model.sampling_with_context(images=code_,
                                                      cross_attention_idxs=cross_attention_idxs,
                                                      cross_attention_layers=cross_attention_layers,
                                                      texts=tokens,
                                                      pos_images=pos_enc_code_,
                                                      pos_texts=pos_enc_tokens,
                                                      source_image=src_tokens,
                                                      use_fp16=use_fp16,
                                                      past=past,
                                                      prompt=prompt,
                                                      pos_prompt=pos_prompt)
        logits = logits.to(dtype=torch.float32)
        logits = logits / softmax_temperature

        present = torch.stack(present).clone().detach()
        if past is None:
            past = [present]
        else:
            past.append(present)

        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)

        idx = torch.multinomial(probs, num_samples=1).clone().detach()
        code = idx if code is None else torch.cat([code, idx], dim=1)

    del past
    return code

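
# Illustrative call (hedged): compared with `sampling`, this variant also
# embeds the source-image codes (`src_codes`) and feeds them through
# cross-attention at the layers listed in `cross_attention_layers`, so
# generation is conditioned on a reference image as well as the text tokens.
#
#   code = sampling_conditional(model, cross_attention_idxs, cross_attention_layers,
#                               tokens, src_codes, top_k=256)
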
# Unconditional iGPT-style sampling: generation starts from a learned
# start-of-sequence embedding (`sos`) instead of text tokens.
def sampling_igpt(model: torch.nn.Module,
                  sos: torch.FloatTensor,
                  top_k: Optional[int] = None,
                  top_p: Optional[float] = None,
                  softmax_temperature: float = 1.0,
                  is_tqdm: bool = True,
                  use_fp16: bool = True,
                  max_seq_len: int = 256) -> torch.LongTensor:
    code = None
    past = None

    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)

    for cnt, h in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            code_ = code.clone().detach()
            pos_enc_code_ = get_positional_encoding(code_, mode='1d')
            code_ = code_[:, cnt-1].unsqueeze(-1)
            pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)

        logits, present = model.sampling(sos=sos,
                                         codes=code_,
                                         pos_codes=pos_enc_code_,
                                         use_fp16=use_fp16,
                                         past=past)
        logits = logits.to(dtype=torch.float32)
        logits = logits / softmax_temperature

        present = torch.stack(present).clone().detach()
        if past is None:
            past = [present]
        else:
            past.append(present)

        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)

        idx = torch.multinomial(probs, num_samples=1).clone().detach()
        code = idx if code is None else torch.cat([code, idx], dim=1)

    del past
    return code

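
# Minimal runnable check of the model-free helpers above (an illustrative
# sketch, not part of the original module); the sampling routines themselves
# require a trained model and are exercised elsewhere in the repo.
if __name__ == '__main__':
    toy_logits = torch.randn(4, 16)  # B=4 rows over a 16-way toy vocab
    filtered = cutoff_topk_logits(toy_logits, k=5)
    toy_probs = cutoff_topp_probs(F.softmax(filtered, dim=-1), p=0.95)
    assert torch.allclose(toy_probs.sum(dim=-1), torch.ones(4))  # still a distribution
    print(torch.multinomial(toy_probs, num_samples=1).squeeze(-1))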