# Usage

# Model loading

```python
import torch
from torch import nn
from dataclasses import dataclass
from typing import Optional, List, Union

from transformers import Gemma2PreTrainedModel, Gemma2Model, GemmaTokenizerFast, Cache


@dataclass
class Config:
    gemma_dir: str = '/kaggle/input/v7-dpo-16bit-01234-8bit-all/v7_dpo_16bit_01234_8bit_all'
    max_length: int = 2000
    batch_size: int = 8
    device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


cfg = Config()


class Gemma2ForSequenceClassificationV1(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss), if `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Unlike the stock `Gemma2ForSequenceClassification`, this variant returns the pooled logits
        directly instead of a `SequenceClassifierOutputWithPast`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(hidden_states.device)
            else:
                sequence_lengths = -1

        # Pool the hidden state of the last non-padding (eos) token of each sequence
        hidden_states = hidden_states[
            torch.arange(batch_size, device=hidden_states.device), sequence_lengths]
        pooled_logits = self.score(hidden_states)

        return pooled_logits


tokenizer = GemmaTokenizerFast.from_pretrained(cfg.gemma_dir)

model = Gemma2ForSequenceClassificationV1.from_pretrained(
    cfg.gemma_dir,
    num_labels=3,
    device_map=cfg.device,
    use_cache=False,
)
model.config.pad_token_id = tokenizer.pad_token_id
```
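Before moving on to inference, a quick sanity check can catch a missing pad token early, since the pooling logic in `forward` locates the last non-padding token via `config.pad_token_id`. This is an optional sketch, not part of the original notebook:

```python
# Optional sanity check (not in the original notebook): batched scoring requires a pad token,
# because forward() finds the last non-padding position via config.pad_token_id.
assert tokenizer.pad_token_id is not None, "tokenizer must define a pad token"
assert model.config.pad_token_id == tokenizer.pad_token_id

# Confirm where the weights landed and that caching is disabled for scoring.
print(model.device, model.dtype, model.config.use_cache)
```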
# Inference
|
```python
from transformers.data.data_collator import pad_without_fast_tokenizer_warning


def create_rounds(query: str,
                  answer_a: str,
                  answer_b: str) -> str:
    """Assemble a single classification prompt from the user question and the two candidate answers."""
    prompt = f"""User question:
\"""{query}\"""
Answer A:
\"""{answer_a}\"""
Answer B:
\"""{answer_b}\"""
"""
    return prompt


@torch.no_grad()
@torch.cuda.amp.autocast()
def single_prompt_inference(prompt, model, device, max_length=cfg.max_length):
    """
    Perform inference on a single prompt.

    Args:
        prompt (str): The input prompt for inference.
        model (torch.nn.Module): The model used for inference.
        device (torch.device): The device to run inference on.
        max_length (int): Maximum sequence length for tokenization.

    Uses the module-level `tokenizer` created during model loading.

    Returns:
        dict: Probabilities for "winner_model_a", "winner_model_b", and "tie".
    """
    # Tokenize the input prompt and append the eos token that the classification head pools on
    input_ids = tokenizer(prompt, truncation=True, max_length=max_length)['input_ids']
    input_ids.append(tokenizer.eos_token_id)

    # Pad the inputs (pad_without_fast_tokenizer_warning is a thin wrapper around tokenizer.pad)
    inputs = pad_without_fast_tokenizer_warning(
        tokenizer,
        {"input_ids": [input_ids]},  # wrap in a list to build a batch of size 1
        padding="max_length",
        pad_to_multiple_of=None,
        max_length=max_length,
        return_tensors="pt",
    )

    # Move inputs to the appropriate device
    inputs = inputs.to(device)

    # Run the model; the custom head returns pooled logits of shape (1, 3)
    outputs = model(**inputs)

    # Convert logits to probabilities with softmax
    proba = outputs.softmax(-1).cpu().squeeze()

    return {
        "winner_model_a": proba[0].item(),
        "winner_model_b": proba[1].item(),
        "tie": proba[2].item(),
    }


query = "What is the height of the reassembled blind product?"
answer_a = "You can find all the technical information directly on the product sheet on our site."
answer_b = "The height of the aluminum Venetian blind is 130 cm."
prompt_direct = create_rounds(query, answer_a, answer_b)

single_prompt_inference(prompt_direct, model, cfg.device)
```
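`cfg.batch_size` is defined above but unused by `single_prompt_inference`, which scores one prompt at a time. A batched variant could look like the sketch below; this is an assumption on my part rather than part of the original notebook, and `batched_inference` is a hypothetical helper that reuses `tokenizer`, `model`, `cfg`, `pad_without_fast_tokenizer_warning`, and the eos-append convention from the single-prompt path:

```python
from typing import Dict, List


@torch.no_grad()
@torch.cuda.amp.autocast()
def batched_inference(prompts: List[str], model, device,
                      batch_size: int = cfg.batch_size,
                      max_length: int = cfg.max_length) -> List[Dict[str, float]]:
    """Score prompts in chunks of `batch_size`. Hedged sketch, not from the original notebook."""
    results = []
    for start in range(0, len(prompts), batch_size):
        chunk = prompts[start:start + batch_size]
        # Mirror the single-prompt path: truncate, then append the eos token the head pools on.
        encoded = [
            tokenizer(p, truncation=True, max_length=max_length - 1)["input_ids"] + [tokenizer.eos_token_id]
            for p in chunk
        ]
        # Pad only to the longest sequence in the chunk to keep memory usage lower.
        inputs = pad_without_fast_tokenizer_warning(
            tokenizer,
            {"input_ids": encoded},
            padding="longest",
            pad_to_multiple_of=None,
            return_tensors="pt",
        ).to(device)
        proba = model(**inputs).softmax(-1).cpu()  # shape: (len(chunk), 3)
        results.extend(
            {"winner_model_a": p[0].item(), "winner_model_b": p[1].item(), "tie": p[2].item()}
            for p in proba
        )
    return results


# Example usage with the single prompt built above:
batched_inference([prompt_direct], model, cfg.device)
```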
Credits to @sayoulala on Kaggle for winning the competition https://www.kaggle.com/competitions/lmsys-chatbot-arena and submitting this model.