Model Details
Model Description
This LoRA adapter was trained for the Query Clarification (QC) task by Revolut.
About the QC task
In this task, the input is a raw, unstructured text query. The goal is to analyze this query and rephrase it into a clear and concise question. This rephrased question should be relevant to Revolut's services or context. As the examples below show, a query that is too vague to rephrase is labelled `Unclear Query`.
query | clarified_query |
---|---|
i pay | Unclear Query |
teen | What is a Revolut Junior account? |
move abroad | How do I manage my Revolut account when moving to another country? |
About the LoRA checkpoint
- train loss: 0.3562
- validation loss: 0.3696512281894684
- device: A100 80GB
Inference
Main functions
from typing import Any
import torch
from peft import PeftModel
from pydantic import BaseSettings
from tqdm.auto import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizerBase,
)
class TextGenerationParams(BaseSettings):
    """Default generation settings used by the inference example.

    The example expands an instance with ``.dict()`` and passes the fields
    as keyword arguments to ``generate_response``, which forwards them to
    ``model.generate``.
    """

    # Upper bound on the number of tokens generated per prompt.
    max_new_tokens: int = 64
    # Sampling temperature; only relevant because ``do_sample`` is enabled.
    temperature: float = 0.1
    # Nucleus-sampling probability mass.
    top_p: float = 0.1
    # Sample from the distribution instead of decoding greedily.
    do_sample: bool = True
def generate_response(
    text_input: str | list[str],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int = 1,
    verbose: bool = True,
    **generation_arguments: Any,
) -> list[str]:
    """Generate a completion for every prompt and return only the new text.

    Args:
        text_input: A single prompt or a list of prompts.
        model: Model whose ``generate`` output starts with the prompt tokens
            (the prompt prefix is sliced off before decoding).
        tokenizer: Tokenizer matching ``model``. Prompts are tokenized with
            ``padding=True``, so it must be able to pad a batch.
        batch_size: Number of prompts sent to ``model.generate`` at once.
        verbose: Show a ``tqdm`` progress bar over the batches.
        **generation_arguments: Extra keyword arguments forwarded unchanged
            to ``model.generate``. An explicit ``attention_mask`` passed here
            takes precedence over the one produced by the tokenizer.

    Returns:
        One decoded string per prompt, in input order, with special tokens
        removed and the prompt itself excluded.
    """
    if isinstance(text_input, str):
        text_input = [text_input]

    dataloader = torch.utils.data.DataLoader(text_input, batch_size=batch_size)
    batches = tqdm(dataloader, desc="Text generation") if verbose else dataloader
    device = model.device
    generated_responses: list[str] = []

    # Same behaviour as the deprecated ``torch.cuda.amp.autocast()`` alias.
    with torch.autocast(device_type="cuda"):
        for batch in batches:
            encoded = tokenizer(batch, return_tensors="pt", padding=True).to(device)
            input_ids = encoded["input_ids"]

            # Forward the attention mask so that padding added to align the
            # prompts in a batch is ignored instead of being attended to.
            generate_kwargs = dict(generation_arguments)
            attention_mask = encoded.get("attention_mask")
            if attention_mask is not None:
                generate_kwargs.setdefault("attention_mask", attention_mask)

            outputs = model.generate(input_ids=input_ids, **generate_kwargs)

            # Keep only the tokens produced after the (padded) prompt.
            generated_ids = outputs[:, input_ids.shape[1]:]
            generated_responses.extend(
                tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            )

    return generated_responses
Let's run inference
# Load Models
>>> base_model_name = "unsloth/Meta-Llama-3.1-8B"
>>> peft_model = "VitalyProtasov/Revolut-Query-Clarification-LoRA-LLAMA-8b"
>>> base_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map="auto")
>>> model = PeftModel.from_pretrained(base_model, peft_model, is_trainable=True)
>>> tokenizer = AutoTokenizer.from_pretrained(base_model_name)
>>> model.print_trainable_parameters()
>>> inference_args = TextGenerationParams()
>>> QUERY_CLARIFICATION_PROMPT = """
### Raw query:
{raw_query}
### Clarified query:
{clarified_query}
"""
>>> raw_inputs = ["i pay", "can i have more than 1 under 18 acc"]
>>> formatted_inputs = [
QUERY_CLARIFICATION_PROMPT.format(raw_query=i, clarified_query="") for i in raw_inputs
]
>>> generated_texts = generate_response(
text_input=formatted_inputs,
model=model,
tokenizer=tokenizer,
batch_size=2,
**inference_args.dict(),
)
>>> print(generated_texts)
['Unclear Query\n', 'Can I have multiple Junior accounts with Revolut?\n']
Packages
[tool.poetry.dependencies]
python = "^3.11"
accelerate = "^1.0.1"
bitsandbytes = "^0.44.1"
datasets = "^2.16.1"
peft = "^0.13.2"
torch = "2.2.1"
transformers = "^4.45.2"
trl = "^0.11.4"
- Downloads last month
- 0
Model tree for VitalyProtasov/Revolut-Query-Clarification-LoRA-LLAMA-8b
Base model
unsloth/Meta-Llama-3.1-8B