Korean Reranker Training on Amazon SageMaker

This guide covers fine-tuning for developing a Korean reranker.

ko-reranker๋Š” BAAI/bge-reranker-larger ๊ธฐ๋ฐ˜ ํ•œ๊ตญ์–ด ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•œ fine-tuned model ์ž…๋‹ˆ๋‹ค.
๋ณด๋‹ค ์ž์„ธํ•œ ์‚ฌํ•ญ์€ korean-reranker-git / AWS Blog, ํ•œ๊ตญ์–ด Reranker๋ฅผ ํ™œ์šฉํ•œ ๊ฒ€์ƒ‰ ์ฆ๊ฐ• ์ƒ์„ฑ(RAG) ์„ฑ๋Šฅ ์˜ฌ๋ฆฌ๊ธฐ์„ ์ฐธ๊ณ ํ•˜์„ธ์š”


0. Features

  • Unlike an embedding model, a reranker takes a question and a document as input and directly outputs a similarity score instead of an embedding.

  • Feeding a question and a passage to the reranker yields a relevance score.

  • Because the reranker is optimized with a cross-entropy loss, the relevance score is not confined to a specific range.
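
Since the raw score is an unbounded logit, it can be convenient to squash it into (0, 1) before thresholding or displaying it. The following is a minimal sketch, assuming a single logit per (question, passage) pair; the `sigmoid` helper and the sample logits are illustrative and not part of the original code.

```python
import math

def sigmoid(logit: float) -> float:
    """Map an unbounded relevance logit to the open interval (0, 1)."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    # For negative inputs, use the algebraically equivalent form to avoid overflow
    z = math.exp(logit)
    return z / (1.0 + z)

# Hypothetical raw logits, chosen only to illustrate the mapping
raw_scores = [-5.2, 0.0, 7.9]
bounded = [sigmoid(s) for s in raw_scores]
print(bounded)
```

The mapping is monotonic, so it changes the scale of the scores but not their ranking.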

1. Usage

  • using Transformers
    import numpy as np
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    def exp_normalize(x):
        # Numerically stable softmax: subtract the max before exponentiating
        b = x.max()
        y = np.exp(x - b)
        return y / y.sum()

    model_path = "Dongjin-kr/ko-reranker"  # Hugging Face Hub model ID
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()

    # Each pair is [query, passage]
    pairs = [["나는 너를 싫어해", "나는 너를 사랑해"],  # "I hate you" / "I love you"
             ["나는 너를 좋아해", "너에 대한 나의 감정은 사랑 일 수도 있어"]]  # "I like you" / "My feelings for you might be love"

    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        # One logit per pair, flattened to a 1-D tensor
        scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
        scores = exp_normalize(scores.numpy())
        print(f'first: {scores[0]}, second: {scores[1]}')
  • using SageMaker
import json
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
    'HF_TASK':'text-classification'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version='py310',
    env=hub,
    role=role, 
)

# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
    initial_instance_count=1, # number of instances
    instance_type='ml.g5.large' # ec2 instance type
)

runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
    {
        "inputs": [
            # "I hate you" / "I love you"
            {"text": "나는 너를 싫어해", "text_pair": "나는 너를 사랑해"},
            # "I like you" / "My feelings for you might be love"
            {"text": "나는 너를 좋아해", "text_pair": "너에 대한 나의 감정은 사랑 일 수도 있어"}
        ]
    }
)

response = runtime_client.invoke_endpoint(
    EndpointName="<endpoint-name>",
    ContentType="application/json",
    Accept="application/json",
    Body=payload
)

## deserialization
out = json.loads(response['Body'].read().decode()) ## for json
print (f'Response: {out}')
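
Once the response is deserialized, the scores can be paired back with the inputs and sorted. The sketch below assumes the endpoint returns one `{"label": ..., "score": ...}` object per input pair, which is the usual shape for a text-classification pipeline but is not confirmed by the source; `sample_body`, its score values, and the `rank_pairs` helper are hypothetical placeholders, not real model output.

```python
import json

# Hypothetical response body; the shape and the values are assumptions
sample_body = json.dumps([
    {"label": "LABEL_0", "score": 0.12},
    {"label": "LABEL_0", "score": 0.87},
])

def rank_pairs(body, pairs):
    """Attach each returned score to its input pair and sort best-first."""
    out = json.loads(body)
    scored = [(pair, item["score"]) for pair, item in zip(pairs, out)]
    return sorted(scored, key=lambda entry: entry[1], reverse=True)

# Placeholder (text, text_pair) inputs in the same order as the payload
pairs = [("q1", "p1"), ("q2", "p2")]
ranked = rank_pairs(sample_body, pairs)
print(ranked)
```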

2. Background

  • The order of contexts affects accuracy (Lost in the Middle, Liu et al., 2023).

  • Why use a reranker

    • With current LLMs, adding more context does not necessarily help; relevant content must be ranked near the top for the model to answer well.
    • The similarity (relevance) score used in semantic search is not precise. (That is, does a higher-ranked result always contain information closer to the question than a lower-ranked one?)
      • Embeddings are specialized for capturing the meaning behind a document.
      • A question and its answer are not semantically identical. (Hypothetical Document Embeddings)
      • Using ANNs (Approximate Nearest Neighbors) carries an accuracy penalty.
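
The points above amount to adding a second-stage sort on top of retrieval: score each retrieved passage against the question and move the most relevant ones to the top. The following is a minimal sketch of that step; `rerank` and `toy_score` are hypothetical names, and `toy_score` is a word-overlap stand-in used only so the example runs without a model. In practice the reranker's relevance score would take its place.

```python
from typing import Callable, List, Tuple

def rerank(query: str,
           passages: List[str],
           score_fn: Callable[[str, str], float],
           top_k: int = 3) -> List[Tuple[str, float]]:
    """Score every (query, passage) pair and return the top_k passages, best first."""
    scored = [(p, score_fn(query, p)) for p in passages]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]

def toy_score(query: str, passage: str) -> float:
    # Hypothetical stand-in for the reranker: number of words shared with the query
    q_words = set(query.lower().split())
    return float(len(q_words & set(passage.lower().split())))

candidates = [
    "Tokyo is the capital of Japan",
    "Seoul is the capital of South Korea",
    "Kimchi is a Korean dish",
]
top = rerank("capital of South Korea", candidates, toy_score, top_k=2)
print(top)
```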

3. Reranker models


4. Dataset

  • msmarco-triplets

    • (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
    • This dataset is in English.
    • It was translated into Korean with Amazon Translate before use.
  • Format

{"query": str, "pos": List[str], "neg": List[str]}
  • query is the question, pos is a list of positive texts, and neg is a list of negative texts. If a query has no negative texts, you can randomly sample some from the entire corpus and use them as negatives.

  • Example

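{"query": "대한민국의 수도는?", "pos": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 한국은 서울이다."], "neg": ["미국의 수도는 워싱턴이고, 일본은 도쿄이며 북한은 평양이다."]}
(In English: query "What is the capital of South Korea?"; pos "The capital of the United States is Washington, Japan's is Tokyo, and South Korea's is Seoul."; neg "The capital of the United States is Washington, Japan's is Tokyo, and North Korea's is Pyongyang.")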
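
A record in this format can be assembled as follows, including the fallback of sampling negatives from the corpus when none are supplied. This is a sketch under stated assumptions: `build_record`, `num_neg`, `seed`, and the placeholder corpus are hypothetical and do not come from the original training code.

```python
import json
import random

def build_record(query, positives, negatives, corpus, num_neg=2, seed=0):
    """Return a {"query", "pos", "neg"} record, sampling negatives if none are given."""
    if not negatives:
        rng = random.Random(seed)
        # Exclude known positives so they are never drawn as negatives
        pool = [text for text in corpus if text not in positives]
        negatives = rng.sample(pool, min(num_neg, len(pool)))
    return {"query": query, "pos": list(positives), "neg": list(negatives)}

# Placeholder corpus for illustration only
corpus = ["passage A", "passage B", "passage C", "passage D"]
record = build_record("example query", ["passage A"], [], corpus)

# One JSON object per line; ensure_ascii=False keeps Korean text readable
line = json.dumps(record, ensure_ascii=False)
print(line)
```

Random negatives are easy to produce but tend to be trivially distinguishable from the positive, so they are best treated as a fallback when no harder negatives are available.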

5. Performance

| Model | has-right-in-contexts | mrr (mean reciprocal rank) |
|---|---|---|
| without-reranker (default) | 0.93 | 0.80 |
| with-reranker (bge-reranker-large) | 0.95 | 0.84 |
| with-reranker (fine-tuned using korean) | 0.96 | 0.87 |
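
For reference, the two metrics can be computed along the following lines. This sketch assumes has-right-in-contexts means the fraction of queries whose correct passage appears anywhere in the returned contexts, which the source does not define explicitly; the function names and the sample rankings are hypothetical and unrelated to the reported numbers.

```python
def reciprocal_rank(ranked_ids, correct_id):
    """Return 1/rank of the correct document, or 0.0 if it was not retrieved."""
    for rank, doc_id in enumerate(ranked_ids, start=1):
        if doc_id == correct_id:
            return 1.0 / rank
    return 0.0

def evaluate(results):
    """results: list of (ranked_ids, correct_id). Returns (hit_rate, mrr)."""
    n = len(results)
    hit_rate = sum(1 for ranked, gold in results if gold in ranked) / n
    mrr = sum(reciprocal_rank(ranked, gold) for ranked, gold in results) / n
    return hit_rate, mrr

# Hypothetical rankings for four queries
results = [
    (["d1", "d2", "d3", "d4"], "d1"),  # correct at rank 1
    (["d4", "d5", "d6", "d7"], "d5"),  # correct at rank 2
    (["d7", "d8", "d9", "d3"], "d0"),  # correct document not retrieved
    (["d2", "d3", "d4", "d1"], "d1"),  # correct at rank 4
]
hit_rate, mrr = evaluate(results)
print(hit_rate, mrr)
```

Reranking reorders the retrieved contexts, so it acts mainly on mrr.
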
  • evaluation set:
./dataset/evaluation/eval_dataset.csv
  • training parameters:
{
    "learning_rate": 5e-6,
    "fp16": True,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "train_group_size": 3,
    "max_len": 512,
    "weight_decay": 0.01,
}
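
The batch-related values above interact, and the arithmetic can be sketched as follows. Two assumptions are made that the source does not state: `num_devices` is a hypothetical parameter for the number of GPUs, and `train_group_size` is read as the number of passages per query (one positive plus `train_group_size - 1` negatives), as in the FlagEmbedding fine-tuning convention.

```python
params = {
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "train_group_size": 3,
}

def queries_per_update(p, num_devices=1):
    """Number of queries that contribute to one optimizer step."""
    return p["per_device_train_batch_size"] * p["gradient_accumulation_steps"] * num_devices

# Assumption: each query is grouped with 1 positive and (train_group_size - 1) negatives
negatives_per_query = params["train_group_size"] - 1
queries = queries_per_update(params)
passages = queries * params["train_group_size"]
print(queries, negatives_per_query, passages)
```

Under these assumptions, a per-device batch size of 1 with 32 accumulation steps behaves like a batch of 32 queries per update on a single device.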

6. Acknowledgement


7. Citation

  • If you find this repository useful, please consider giving it a like ⭐ and citing it.

8. Contributors:

  • Dongjin Jang, Ph.D. (AWS AI/ML Specialist Solutions Architect) | Mail | LinkedIn | Git |

9. License

10. Analytics

  • Hits