import json

import numpy as np
import torch as th

def compute_logp(args, model, x, input_ids):
    # Per-token loss of input_ids under a Gaussian kernel with fixed sigma
    # centered on each word embedding: softmax over -||x - e_w||^2 / (2 sigma^2).
    word_emb = model.weight
    sigma = 0.1
    if args.model_arch == '1d-unet':
        x = x.permute(0, 2, 1)

    bsz, seqlen, dim = x.shape

    x_flat = x.reshape(-1, x.size(-1)).unsqueeze(0)  # (1, bsz*seqlen, dim)
    word_emb_flat = word_emb.unsqueeze(1)            # (vocab, 1, dim)
    diff = (x_flat - word_emb_flat) ** 2             # (vocab, bsz*seqlen, dim)

    logp_expanded = -diff.sum(dim=-1) / (2 * sigma ** 2)  # (vocab, bsz*seqlen)
    logp_expanded = logp_expanded.permute((1, 0))         # (bsz*seqlen, vocab)

    ce = th.nn.CrossEntropyLoss(reduction='none')
    loss = ce(logp_expanded, input_ids.view(-1)).view(bsz, seqlen)

    return loss

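# Illustrative usage sketch (not from the original file): a toy embedding table
# and random ids exercise compute_logp; 'transformer' is just a stand-in arch
# name so the '1d-unet' permute branch is skipped.
def _demo_compute_logp():
    import argparse
    toy_args = argparse.Namespace(model_arch='transformer')
    toy_emb = th.nn.Embedding(10, 4)          # vocab=10, dim=4
    input_ids = th.randint(0, 10, (2, 5))     # bsz=2, seqlen=5
    x = toy_emb(input_ids)                    # (2, 5, 4)
    loss = compute_logp(toy_args, toy_emb, x, input_ids)
    print(loss.shape)  # torch.Size([2, 5])
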
def get_weights(model, args):
    # Return a frozen embedding table whose weight matrix is used for rounding.
    if hasattr(model, 'transformer'):
        # Project the transformer's input embeddings down to the diffusion
        # dimension and wrap them in a fresh embedding table.
        input_embs = model.transformer.wte
        down_proj = model.down_proj
        down_proj_emb = down_proj(input_embs.weight)
        model = th.nn.Embedding(down_proj_emb.size(0), down_proj_emb.size(1))
        model.weight.data = down_proj_emb * args.emb_scale_factor
    elif hasattr(model, 'weight'):
        pass  # already an embedding-like module; use its weights as-is
    else:
        raise NotImplementedError

    model.weight.requires_grad = False
    return model

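# Illustrative usage sketch (assumptions: a bare nn.Embedding already exposes
# .weight, so this exercises the pass-through branch; emb_scale_factor is only
# consumed by the transformer branch).
def _demo_get_weights():
    import argparse
    toy_args = argparse.Namespace(emb_scale_factor=1.0)
    frozen = get_weights(th.nn.Embedding(10, 4), toy_args)
    print(frozen.weight.requires_grad)  # False
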
def denoised_fn_round(args, model, text_emb, t):
    # Snap each denoised embedding onto its nearest word embedding ("rounding").
    # For early diffusion steps (t above the threshold) the embedding is passed
    # through unchanged; rounding only kicks in late in the reverse process.
    thresh_t = 350
    if thresh_t is not None and t[0] > thresh_t:
        return text_emb

    # Here `model` is expected to be the embedding weight matrix itself, a
    # (vocab, dim) tensor: see th.mm(...) and model[rounded_tokens] below.
    down_proj_emb = model
    old_shape = text_emb.shape   # (bsz, seqlen, dim)
    old_device = text_emb.device

    def get_efficient_knn(down_proj_emb, text_emb, dist='l2'):
        # Nearest word embedding per position, computed without materializing
        # the full (vocab, bsz*seqlen, dim) difference tensor by expanding
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b.
        if dist == 'l2':
            emb_norm = (down_proj_emb ** 2).sum(-1).view(-1, 1)  # (vocab, 1)
            text_emb_t = th.transpose(text_emb.view(-1, text_emb.size(-1)), 0, 1)  # (dim, bsz*seqlen)
            arr_norm = (text_emb ** 2).sum(-1).view(-1, 1)  # (bsz*seqlen, 1)
            dist = emb_norm + arr_norm.transpose(0, 1) - 2.0 * th.mm(down_proj_emb, text_emb_t)  # (vocab, bsz*seqlen)
            dist = th.clamp(dist, 0.0, np.inf)  # clip tiny negatives from floating-point error
            topk_out = th.topk(-dist, k=1, dim=0)  # closest vocab entry per position
        return topk_out.values, topk_out.indices

    dist = 'l2'
    if len(text_emb.shape) > 2:
        text_emb = text_emb.reshape(-1, text_emb.size(-1))

    val, indices = get_efficient_knn(
        down_proj_emb, text_emb.to(down_proj_emb.device), dist=dist)
    rounded_tokens = indices[0]  # (bsz*seqlen,) nearest vocab ids
    new_embeds = model[rounded_tokens].view(old_shape).to(old_device)

    return new_embeds

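# Illustrative usage sketch: round slightly-perturbed embeddings back onto the
# table. Passing the raw (vocab, dim) weight tensor matches how the function
# uses th.mm and model[rounded_tokens]; args is unused, and t=100 is below the
# hard-coded thresh_t=350 so rounding actually runs. All names are made up.
def _demo_denoised_fn_round():
    emb_weight = th.nn.Embedding(10, 4).weight.data      # (vocab, dim)
    noisy = emb_weight[th.randint(0, 10, (2, 5))] + 0.01 * th.randn(2, 5, 4)
    t = th.full((2,), 100)                               # per-sample timestep
    rounded = denoised_fn_round(None, emb_weight, noisy, t)
    print(rounded.shape)  # torch.Size([2, 5, 4])
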
def load_results(json_path, load_dict):
    # Note: despite the name, this *writes* load_dict to json_path as JSON.
    with open(json_path, 'w') as f:
        json.dump(load_dict, f, indent=2)

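# Illustrative usage sketch (throwaway path and keys): write a small results
# dict, then read it back to show the file contents.
def _demo_load_results():
    load_results('demo_results.json', {'bleu': 0.0, 'num_samples': 16})
    with open('demo_results.json') as f:
        print(f.read())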