Model Details

Model Description

This LoRA adapter is trained for the Query Clarification (QC) task by Revolut.

About QC task

In this task, the input is a raw, unstructured text query. The goal is to rephrase it as a clear, concise question that is relevant to Revolut's services or context. When the intent cannot be recovered from the input, the expected output is the label "Unclear Query", as in the first example below.

| query | clarified_query |
|---|---|
| i pay | Unclear Query |
| teen | What is a Revolut Junior account? |
| move abroad | How do I manage my Revolut account when moving to another country? |
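The input/output pairs above can double as a small smoke test for any clarifier. The sketch below is illustrative only: `exact_match_rate` and `lookup_clarifier` are hypothetical names, and exact string match is an assumed metric, not one stated by this model card.

```python
from typing import Callable, Iterable, Tuple

# Example pairs copied from the table above.
EXAMPLES = [
    ("i pay", "Unclear Query"),
    ("teen", "What is a Revolut Junior account?"),
    ("move abroad", "How do I manage my Revolut account when moving to another country?"),
]


def exact_match_rate(
    clarify: Callable[[str], str],
    examples: Iterable[Tuple[str, str]],
) -> float:
    """Fraction of examples where the clarifier output equals the expected text."""
    examples = list(examples)
    hits = sum(clarify(query).strip() == expected for query, expected in examples)
    return hits / len(examples)


# Stand-in clarifier for illustration only: a dictionary lookup, not the model.
_LOOKUP = dict(EXAMPLES)


def lookup_clarifier(query: str) -> str:
    return _LOOKUP.get(query, "Unclear Query")


rate = exact_match_rate(lookup_clarifier, EXAMPLES)
print(rate)
```

In practice the stand-in would be replaced by a function that formats the prompt, calls the model, and returns the generated text.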

About LoRA checkpoint

  • train loss: 0.3562
  • validation loss: 0.3696512281894684
  • device: A100 80GB

Inference

Main functions

from typing import Any

import torch
from peft import PeftModel
from pydantic import BaseModel
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)


class TextGenerationParams(BaseModel):
    """Generation settings passed through to `model.generate`."""

    max_new_tokens: int = 64
    temperature: float = 0.1
    top_p: float = 0.1
    do_sample: bool = True


def generate_response(
    text_input: str | list[str],
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int = 1,
    verbose: bool = True,
    **generation_arguments: Any
) -> list[str]:
    if isinstance(text_input, str):
        text_input = [text_input]

    dataloader = torch.utils.data.DataLoader(text_input, batch_size=batch_size)
    device = model.device

    generated_responses = []
    with torch.cuda.amp.autocast():
        if verbose:
            iter_data = tqdm(dataloader, desc="Text generation")
        else:
            iter_data = dataloader
        for batch in iter_data:
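            # Tokenize with padding and keep the attention mask so padded positions are ignored.
            inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device)
            outputs = model.generate(**inputs, **generation_arguments)
            # Drop the prompt tokens and keep only the newly generated continuation.
            generated_ids = outputs[:, inputs["input_ids"].shape[1] :]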
            response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            generated_responses.extend(response)

    return generated_responses

Let's run inference

# Load Models
>>> base_model_name = "unsloth/Meta-Llama-3.1-8B"
>>> peft_model = "VitalyProtasov/Revolut-Query-Clarification-LoRA-LLAMA-8b"

>>> base_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map="auto")
>>> model = PeftModel.from_pretrained(base_model, peft_model, is_trainable=True)
>>> tokenizer = AutoTokenizer.from_pretrained(base_model_name)

>>> model.print_trainable_parameters()
>>> model = model.eval()  # switch off dropout (including LoRA dropout) for generation

>>> inference_args = TextGenerationParams()

>>> QUERY_CLARIFICATION_PROMPT = """
    ### Raw query:
    {raw_query}
    
    ### Clarified query:
    {clarified_query}
    """
>>> raw_inputs = ["i pay", "can i have more than 1 under 18 acc"]
>>> formatted_inputs = [
    QUERY_CLARIFICATION_PROMPT.format(raw_query=i, clarified_query="") for i in raw_inputs
    ]

>>> generated_texts = generate_response(
    text_input=formatted_inputs,
    model=model,
    tokenizer=tokenizer,
    batch_size=2,
    **inference_args.dict(),
    )

>>> print(generated_texts)
['Unclear Query\n', 'Can I have multiple Junior accounts with Revolut?\n']
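
The raw generations above end with a newline and use the literal "Unclear Query" label for inputs the model cannot interpret. A minimal post-processing sketch follows; the helper name `postprocess_clarification` and the choice to map unclear inputs to `None` are assumptions for illustration, not part of the original code.

```python
from typing import Optional

UNCLEAR_LABEL = "Unclear Query"  # label seen in the example outputs above


def postprocess_clarification(generated_text: str) -> Optional[str]:
    """Return the cleaned clarified query, or None if the input was flagged as unclear."""
    # Keep only the first non-empty line, in case the model continues past the answer.
    lines = [line.strip() for line in generated_text.splitlines() if line.strip()]
    if not lines:
        return None
    first = lines[0]
    if first.casefold() == UNCLEAR_LABEL.casefold():
        return None
    return first


generated_texts = ["Unclear Query\n", "Can I have multiple Junior accounts with Revolut?\n"]
cleaned = [postprocess_clarification(text) for text in generated_texts]
print(cleaned)
# [None, 'Can I have multiple Junior accounts with Revolut?']
```

Returning `None` lets downstream code branch explicitly, for example by asking the user to rephrase instead of searching on an unclear query.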

Packages

[tool.poetry.dependencies]
python = "^3.11"
accelerate = "^1.0.1"
bitsandbytes = "^0.44.1"
datasets = "^2.16.1"
peft = "^0.13.2"
torch = "2.2.1"
transformers = "^4.45.2"
trl = "^0.11.4"
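
To compare a local environment against the list above, the installed versions can be printed with the standard library. This is a minimal sketch: `installed_versions` is a hypothetical helper, and it only reports versions without checking them against the constraints.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Package names taken from the dependency list above.
PACKAGES = ["accelerate", "bitsandbytes", "datasets", "peft", "torch", "transformers", "trl"]


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each package name to its installed version, or None if it is not installed."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = installed_versions(PACKAGES)
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```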