from .configuration_keeper import KeeperConfig
import torch
from transformers import (
AutoTokenizer,
AutoConfig,
AutoModel,
PreTrainedModel,
PretrainedConfig,
AutoModelForCausalLM,
BitsAndBytesConfig
)
from typing import Dict
import numpy as np
from einops import rearrange
class KeeperModelForCausalLM(PreTrainedModel):
"""
    ColBERT model from: https://arxiv.org/pdf/2004.12832.pdf
    We use a dot product per term instead of cosine similarity (slightly better)
"""
config_class = KeeperConfig
base_model_prefix = "keeper_model"
def __init__(self, cfg, n_cands=8, update_both=False) -> None:
super().__init__(cfg)
self.bert = None
self.llm = None
if cfg:
print("Initializing KeeperModelForCausalLM from cfg")
            # Initialize the retriever encoder and the quantized causal LM from the configuration
self.bert = AutoModel.from_pretrained(cfg.retriever_config['_name_or_path'])
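            # Quantization settings for the causal LM: 4-bit NF4 weights with bfloat16 compute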
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
self.llm = AutoModelForCausalLM.from_pretrained(
cfg.model_config['_name_or_path'],
device_map=cfg.device_map,
torch_dtype=torch.bfloat16,
quantization_config=bnb_config
)
            # Store kwargs for future serialization and loading
            # self.init_kwargs = {'cfg': cfg}
print("Initialization complete")
else:
            # If cfg is not provided, this is handled in the from_pretrained method
            print("Initializing KeeperModelForCausalLM without cfg")
self.n_cands = n_cands
self.update_both = update_both
print(f"Model n_cands: {self.n_cands}")
def _load_from_state_dict(self, state_dict, *args, **kwargs):
super()._load_from_state_dict(state_dict, *args, **kwargs)
        # Move the cached document and prompt tensors to the GPU when it is available
if torch.cuda.is_available():
device = torch.device('cuda')
if "document_retriever_text" in state_dict:
self.document_retriever_text = state_dict["document_retriever_text"].to(device)
if "document_retriever_mask" in state_dict:
self.document_retriever_mask = state_dict["document_retriever_mask"].to(device)
if "document_retriever_type" in state_dict:
self.document_retriever_type = state_dict["document_retriever_type"].to(device)
if "document_model_text" in state_dict:
self.document_model_text = state_dict["document_model_text"].to(device)
if "prompt_left" in state_dict:
self.prompt_left = state_dict["prompt_left"].to(device)
if "prompt_right" in state_dict:
self.prompt_right = state_dict["prompt_right"].to(device)
if "respuesta" in state_dict:
self.respuesta = state_dict["respuesta"].to(device)
else:
# Optionally handle the case where CUDA is not available
print("CUDA is not available. Tensors will remain on CPU.")
    def generate(self, query: Dict[str, torch.LongTensor], k: int = 3, max_new_tokens=256, repetition_penalty=1.15, temperature=0.1, do_sample=True, **kwargs):
        query_model = {key: value.to("cuda") for key, value in query['tokens_model'].items()}
        # Retrieve the top-k documents (as LLM token ids) and assemble the full prompt:
        # left prompt | retrieved context | right prompt | user question | answer turn
        topk_texts = self.document_extractor(query, k)
        concatenated_texts = torch.cat(topk_texts, dim=0)
        T = torch.cat((self.prompt_left, concatenated_texts.unsqueeze(0), self.prompt_right, query_model['input_ids'], self.respuesta), dim=1)
        prompt_length = T.shape[1]
        outputs = self.llm.generate(input_ids=T, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty, temperature=temperature, do_sample=do_sample)
        # Return only the newly generated tokens, without the prompt
        return outputs[0][prompt_length:].unsqueeze(0)
    def document_extractor(self, query: Dict[str, torch.LongTensor], k_val: int = 3, **kwargs):
        query_retriever = {key: value.to("cuda") for key, value in query['tokens_retriever'].items()}
        query_vecs = self.forward_representation(query_retriever)
        doc_dic = {'input_ids': self.document_retriever_text, 'attention_mask': self.document_retriever_mask, 'token_type_ids': self.document_retriever_type}
        document_vecs = self.forward_representation(doc_dic, sequence_type="doc")
        # Late-interaction (ColBERT-style) scores between the query and every stored document;
        # use the CUDA copy of the query mask so all tensors share a device
        self.score = self.forward_aggregation(query_vecs, query_retriever["attention_mask"], document_vecs, self.document_retriever_mask)
        k_val = min(k_val, self.score.numel())
        topk_scores, topk_indices = torch.topk(self.score, k_val)
        # Return the LLM token ids of the top-k documents
        return [self.document_model_text[i, :] for i in topk_indices[0].tolist()]
    def forward_representation(self,
                               tokens,
                               max_seq_len=128,
                               sequence_type=None) -> torch.Tensor:
        if sequence_type == "doc":
            if self.update_both:
                with torch.no_grad():
                    vecs = self.bert(**tokens)[0]
            else:
                with torch.no_grad():
                    vecs = self.bert(**tokens)[0]  # assuming a distilbert-style encoder here
        else:
            with torch.no_grad():
                vecs = self.bert(**tokens)[0]
        # The retriever encoder is kept frozen (no_grad) in every branch for now
        # vecs = self.compressor(vecs)
        return vecs
    def forward_aggregation(self, query_vecs, query_mask, document_vecs, document_mask):
        # query_vecs:    B x N x D
        # document_vecs: (B * k) x M x D
        _bsz = query_vecs.shape[0]
        n_cands = document_vecs.shape[0] // _bsz
        # Duplicate each query so it can be matched against its k candidate documents
        query_vecs_dup = query_vecs.repeat_interleave(n_cands, dim=0).contiguous()
        # Token-level dot products: (B * k) x N x M
        score = torch.bmm(query_vecs_dup, document_vecs.transpose(1, 2))
        exp_mask = document_mask.bool().unsqueeze(1).expand(-1, score.shape[1], -1)
        score[~exp_mask] = -10000
        # Max pooling over the document dimension (MaxSim), then zero out padded query tokens
        score = score.max(-1).values
        query_mask_dup = query_mask.repeat_interleave(n_cands, dim=0).contiguous()
        score[~(query_mask_dup.bool())] = 0
        # Sum over query tokens and reshape to B x k
        score = rearrange(score.sum(-1), '(b n) -> b n', n=n_cands)
        return score
def prompt(self, left_p = None, right_p = None):
if left_p is None:
left_p = """ <bos><start_of_turn>user
Eres un experto en cultura paraguaya que responde de forma clara, amable y concisa.
Segun el siguiente contexto:
-------------------------------
"""
if right_p is None:
right_p = """
-------------------------------
- Solamente puedes responder usando el contexto de arriba, si no se encuentra en el contexto mencionar: 'No tengo informacion sobre eso'.
- Si encuentras la respuesta puedes copiarla.
- Debes responder solamente en Espanol.
Pregunta: """
return left_p, right_p
def save_docs(self, docs: list, tokenizer, max_seq_len=128):
        # Tokenize the prompt
prompt_left, prompt_right = self.prompt()
prompt_left_output = tokenizer.encode(prompt_left)
prompt_right_output = tokenizer.encode(prompt_right)
        # Tokenize the documents
doc_outputs = tokenizer.encode(docs, max_length=max_seq_len, padding='max_length', truncation=True)
        # Move the tensors to CUDA (## optimize: tensors that will never be used on the GPU are stored there as well)
doc_outputs = {k: v.to("cuda") for k, v in doc_outputs.items()}
prompt_left_output = {k: v.to("cuda") for k, v in prompt_left_output.items()}
prompt_right_output = {k: v.to("cuda") for k, v in prompt_right_output.items()}
        # Tokenize the answer turn
resp = tokenizer.encode("""
Respuesta: <end_of_turn>
<start_of_turn>model """)
resp_model = {k: v.to("cuda") for k, v in resp['tokens_model'].items()}
        # Update the buffers with the document tensors
self.document_retriever_text = doc_outputs['tokens_retriever']['input_ids']
self.document_retriever_mask = doc_outputs['tokens_retriever']['attention_mask']
self.document_retriever_type = doc_outputs['tokens_retriever']['token_type_ids']
self.document_model_text = doc_outputs['tokens_model']['input_ids']
# self.document_model_mask = key_outputs['tokens_model']['attention_mask']
# self.document_model_type = key_outputs['tokens_model']['token_type_ids']
self.prompt_left = prompt_left_output['tokens_model']['input_ids']
self.prompt_right = prompt_right_output['tokens_model']['input_ids']
        self.respuesta = resp_model['input_ids']
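

if __name__ == "__main__":
    # Minimal sketch: a shape check of the ColBERT-style aggregation on random tensors.
    # `forward_aggregation` does not touch `self`, so it can be exercised without loading
    # any pretrained weights. Run as a module (python -m <package>.modeling_keeper) because
    # of the relative import at the top of this file.
    B, N, M, D, k = 2, 4, 6, 8, 3  # batch, query length, doc length, dim, candidates per query
    query_vecs = torch.randn(B, N, D)
    query_mask = torch.ones(B, N)
    document_vecs = torch.randn(B * k, M, D)
    document_mask = torch.ones(B * k, M)
    scores = KeeperModelForCausalLM.forward_aggregation(
        None, query_vecs, query_mask, document_vecs, document_mask
    )
    print(scores.shape)  # expected: torch.Size([2, 3]) -- one score per (query, candidate) pair

    # Hypothetical end-to-end usage (names such as KeeperTokenizer and the repo id are
    # assumptions, not defined in this file):
    #   tokenizer = KeeperTokenizer.from_pretrained("<repo_id>")
    #   model = KeeperModelForCausalLM.from_pretrained("<repo_id>")
    #   model.save_docs(["doc 1 ...", "doc 2 ..."], tokenizer)
    #   query = tokenizer.encode("¿Cuál es la capital de Paraguay?")
    #   answer_ids = model.generate(query, k=3)
    #   print(tokenizer.decode(answer_ids[0]))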