Korean Reranker Training on Amazon SageMaker
ํ๊ตญ์ด Reranker ๊ฐ๋ฐ์ ์ํ ํ์ธํ๋ ๊ฐ์ด๋๋ฅผ ์ ์ํฉ๋๋ค.
ko-reranker๋ BAAI/bge-reranker-larger ๊ธฐ๋ฐ ํ๊ตญ์ด ๋ฐ์ดํฐ์ ๋ํ fine-tuned model ์
๋๋ค.
๋ณด๋ค ์์ธํ ์ฌํญ์ korean-reranker-git / AWS Blog, ํ๊ตญ์ด Reranker๋ฅผ ํ์ฉํ ๊ฒ์ ์ฆ๊ฐ ์์ฑ(RAG) ์ฑ๋ฅ ์ฌ๋ฆฌ๊ธฐ์ ์ฐธ๊ณ ํ์ธ์
0. Features
Reranker๋ ์๋ฒ ๋ฉ ๋ชจ๋ธ๊ณผ ๋ฌ๋ฆฌ ์ง๋ฌธ๊ณผ ๋ฌธ์๋ฅผ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ฉฐ ์๋ฒ ๋ฉ ๋์ ์ ์ฌ๋๋ฅผ ์ง์ ์ถ๋ ฅํฉ๋๋ค.
Reranker์ ์ง๋ฌธ๊ณผ ๊ตฌ์ ์ ์ ๋ ฅํ๋ฉด ์ฐ๊ด์ฑ ์ ์๋ฅผ ์ป์ ์ ์์ต๋๋ค.
Reranker๋ CrossEntropy loss๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ต์ ํ๋๋ฏ๋ก ๊ด๋ จ์ฑ ์ ์๊ฐ ํน์ ๋ฒ์์ ๊ตญํ๋์ง ์์ต๋๋ค.
1.Usage
- using Transformers
def exp_normalize(x):
b = x.max()
y = np.exp(x - b)
return y / y.sum()
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()
pairs = [["๋๋ ๋๋ฅผ ์ซ์ดํด", "๋๋ ๋๋ฅผ ์ฌ๋ํด"], \
["๋๋ ๋๋ฅผ ์ข์ํด", "๋์ ๋ํ ๋์ ๊ฐ์ ์ ์ฌ๋ ์ผ ์๋ ์์ด"]]
with torch.no_grad():
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
scores = exp_normalize(scores.numpy())
print (f'first: {scores[0]}, second: {scores[1]}')
- using SageMaker
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel
try:
role = sagemaker.get_execution_role()
except ValueError:
iam = boto3.client('iam')
role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
# Hub Model configuration. https://huggingface.co/models
hub = {
'HF_MODEL_ID':'Dongjin-kr/ko-reranker',
'HF_TASK':'text-classification'
}
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
transformers_version='4.28.1',
pytorch_version='2.0.0',
py_version='py310',
env=hub,
role=role,
)
# deploy model to SageMaker Inference
predictor = huggingface_model.deploy(
initial_instance_count=1, # number of instances
instance_type='ml.g5.large' # ec2 instance type
)
runtime_client = boto3.Session().client('sagemaker-runtime')
payload = json.dumps(
{
"inputs": [
{"text": "๋๋ ๋๋ฅผ ์ซ์ดํด", "text_pair": "๋๋ ๋๋ฅผ ์ฌ๋ํด"},
{"text": "๋๋ ๋๋ฅผ ์ข์ํด", "text_pair": "๋์ ๋ํ ๋์ ๊ฐ์ ์ ์ฌ๋ ์ผ ์๋ ์์ด"}
]
}
)
response = runtime_client.invoke_endpoint(
EndpointName="<endpoint-name>",
ContentType="application/json",
Accept="application/json",
Body=payload
)
## deserialization
out = json.loads(response['Body'].read().decode()) ## for json
print (f'Response: {out}')
2. Backgound
์ปจํ์คํธ ์์๊ฐ ์ ํ๋์ ์ํฅ ์ค๋ค(Lost in Middle, Liu et al., 2023)
Reranker ์ฌ์ฉํด์ผ ํ๋ ์ด์
- ํ์ฌ LLM์ context ๋ง์ด ๋ฃ๋๋ค๊ณ ์ข์๊ฑฐ ์๋, relevantํ๊ฒ ์์์ ์์ด์ผ ์ ๋ต์ ์ ๋งํด์ค๋ค
- Semantic search์์ ์ฌ์ฉํ๋ similarity(relevant) score๊ฐ ์ ๊ตํ์ง ์๋ค. (์ฆ, ์์ ๋ญ์ปค๋ฉด ํ์ ๋ญ์ปค๋ณด๋ค ํญ์ ๋ ์ง๋ฌธ์ ์ ์ฌํ ์ ๋ณด๊ฐ ๋ง์?)
- Embedding์ meaning behind document๋ฅผ ๊ฐ์ง๋ ๊ฒ์ ํนํ๋์ด ์๋ค.
- ์ง๋ฌธ๊ณผ ์ ๋ต์ด ์๋ฏธ์ ๊ฐ์๊ฑด ์๋๋ค. (Hypothetical Document Embeddings)
- ANNs(Approximate Nearest Neighbors) ์ฌ์ฉ์ ๋ฐ๋ฅธ ํจ๋ํฐ
3. Reranker models
[Cohere] Reranker
[BAAI] bge-reranker-large
[BAAI] bge-reranker-base
4. Dataset
msmarco-triplets
- (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
- ํด๋น ๋ฐ์ดํฐ ์ ์ ์๋ฌธ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
- Amazon Translate ๊ธฐ๋ฐ์ผ๋ก ๋ฒ์ญํ์ฌ ํ์ฉํ์์ต๋๋ค.
Format
{"query": str, "pos": List[str], "neg": List[str]}
Query๋ ์ง๋ฌธ์ด๊ณ , pos๋ ๊ธ์ ํ ์คํธ ๋ชฉ๋ก, neg๋ ๋ถ์ ํ ์คํธ ๋ชฉ๋ก์ ๋๋ค. ์ฟผ๋ฆฌ์ ๋ํ ๋ถ์ ํ ์คํธ๊ฐ ์๋ ๊ฒฝ์ฐ ์ ์ฒด ๋ง๋ญ์น์์ ์ผ๋ถ๋ฅผ ๋ฌด์์๋ก ์ถ์ถํ์ฌ ๋ถ์ ํ ์คํธ๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค.
Example
{"query": "๋ํ๋ฏผ๊ตญ์ ์๋๋?", "pos": ["๋ฏธ๊ตญ์ ์๋๋ ์์ฑํด์ด๊ณ , ์ผ๋ณธ์ ๋์ฟ์ด๋ฉฐ ํ๊ตญ์ ์์ธ์ด๋ค."], "neg": ["๋ฏธ๊ตญ์ ์๋๋ ์์ฑํด์ด๊ณ , ์ผ๋ณธ์ ๋์ฟ์ด๋ฉฐ ๋ถํ์ ํ์์ด๋ค."]}
5. Performance
Model | has-right-in-contexts | mrr (mean reciprocal rank) |
---|---|---|
without-reranker (default) | 0.93 | 0.80 |
with-reranker (bge-reranker-large) | 0.95 | 0.84 |
with-reranker (fine-tuned using korean) | 0.96 | 0.87 |
- evaluation set:
./dataset/evaluation/eval_dataset.csv
- training parameters:
{
"learning_rate": 5e-6,
"fp16": True,
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 32,
"train_group_size": 3,
"max_len": 512,
"weight_decay": 0.01,
}
6. Acknowledgement
- Part of the code is developed based on FlagEmbedding and KoSimCSE-SageMaker.
7. Citation
- If you find this repository useful, please consider giving a like โญ and citation
8. Contributors:
9. License
- FlagEmbedding is licensed under the MIT License.
10. Analytics
- Downloads last month
- 4,797
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.