---
language:
- en
library_name: transformers
tags:
- cross-encoder
- search
- product-search
base_model: cross-encoder/ms-marco-MiniLM-L-12-v2
model-index:
- name: esci-ms-marco-MiniLM-L-12-v2
  results:
  - task:
      type: text-classification
    metrics:
    - type: mrr@10
      value: 91.81
    - type: ndcg@10
      value: 85.46
---

# Model Description

A cross-encoder for product search relevance ranking, fine-tuned from `cross-encoder/ms-marco-MiniLM-L-12-v2` on the Amazon ESCI dataset. Given a search query and a product record, it outputs a relevance score for the pair.

# Usage

## Transformers

```python
from torch import no_grad
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"

# Each query is scored against the document at the same index.
queries = [
    "adidas shoes",
    "adidas shoes",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard sleeveless gown",
]
# Documents are product records serialized as JSON strings.
documents = [
    '{"title": "Nike Air Max", "description": "The best shoes you can get, with air cushion", "brand": "Nike", "color": "black"}',
    '{"title": "Adidas Ultraboost", "description": "The shoes that represent the world", "brand": "Adidas", "color": "white"}',
    '{"title": "Womens sandals", "description": "Sandals: wide width 9", "brand": "Chacos", "color": "blue"}',
    '{"title": "Girls surf backpack", "description": "The best backpack in town", "brand": "Roxy", "color": "pink"}',
    '{"title": "Fresh watermelon", "description": "The best fruit in town, all you can eat", "brand": "Fruitsellers Inc.", "color": "green"}',
    '{"title": "Floral yellow dress with frills and lace", "description": "Brighten up your summers with a gorgeous dress", "brand": "Dressmakers Inc.", "color": "bright yellow"}',
]

model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenize each (query, document) pair as a single joint input.
inputs = tokenizer(
    queries,
    documents,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

model.eval()
with no_grad():
    # Gradients are disabled, so the logits convert to NumPy directly.
    scores = model(**inputs).logits.cpu().numpy()
print(scores)
```
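The documents in the example above are product records serialized as JSON strings. The following is a minimal sketch of producing that format from a Python dict. The helper name `format_document` and the constant `DOCUMENT_FIELDS` are hypothetical, and treating these four fields (in this order) as the expected schema is an assumption based only on the example documents.

```python
import json
from typing import Mapping

# Assumption: field names and order mirror the example documents above.
DOCUMENT_FIELDS = ("title", "description", "brand", "color")


def format_document(product: Mapping[str, str]) -> str:
    """Serialize a product record to a JSON string shaped like the examples.

    Missing fields are filled with an empty string (an illustrative choice).
    """
    record = {field: product.get(field, "") for field in DOCUMENT_FIELDS}
    # ensure_ascii=False keeps non-ASCII product text readable.
    return json.dumps(record, ensure_ascii=False)
```

Strings built this way can be passed as the `documents` argument in the snippets in this section.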

## Sentence Transformers

```python
from sentence_transformers import CrossEncoder

model_name = "lv12/esci-ms-marco-MiniLM-L-12-v2"

# Each query is scored against the document at the same index.
queries = [
    "adidas shoes",
    "adidas shoes",
    "girls sandals",
    "backpacks",
    "shoes",
    "mustard sleeveless gown",
]
# Documents are product records serialized as JSON strings.
documents = [
    '{"title": "Nike Air Max", "description": "The best shoes you can get, with air cushion", "brand": "Nike", "color": "black"}',
    '{"title": "Adidas Ultraboost", "description": "The shoes that represent the world", "brand": "Adidas", "color": "white"}',
    '{"title": "Womens sandals", "description": "Sandals: wide width 9", "brand": "Chacos", "color": "blue"}',
    '{"title": "Girls surf backpack", "description": "The best backpack in town", "brand": "Roxy", "color": "pink"}',
    '{"title": "Fresh watermelon", "description": "The best fruit in town, all you can eat", "brand": "Fruitsellers Inc.", "color": "green"}',
    '{"title": "Floral yellow dress with frills and lace", "description": "Brighten up your summers with a gorgeous dress", "brand": "Dressmakers Inc.", "color": "bright yellow"}',
]

model = CrossEncoder(model_name, max_length=512)
# predict() takes a list of (query, document) pairs.
scores = model.predict(list(zip(queries, documents)))
print(scores)
```

Example output, one score per query-document pair:

```text
[ 1.057739   1.6751697  1.039221   1.5969192 -0.8867093  0.5035825 ]
```
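
The examples above score each query against a single document. In a search setting, the more common pattern is to score many candidate documents against one query and sort them. The following is a minimal sketch of that reranking step. The function name `rank_documents` and its `score_fn` parameter are hypothetical, and any callable that maps a list of `(query, document)` pairs to a sequence of scores can be plugged in.

```python
from typing import Callable, List, Sequence, Tuple

Pair = Tuple[str, str]


def rank_documents(
    query: str,
    documents: Sequence[str],
    score_fn: Callable[[List[Pair]], Sequence[float]],
) -> List[Tuple[str, float]]:
    """Score every document against one query and return them best-first."""
    pairs = [(query, doc) for doc in documents]
    # Convert to plain floats so NumPy scalars and Python floats behave alike.
    scores = [float(s) for s in score_fn(pairs)]
    # Higher score is treated as more relevant.
    return sorted(zip(documents, scores), key=lambda item: item[1], reverse=True)
```

With the Sentence Transformers snippet above, `rank_documents("adidas shoes", documents, model.predict)` would return the candidate documents ordered by predicted relevance to that query.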

# Training

The model was trained with `CrossEntropyLoss` on `<query, document>` pairs, using the relevance `grade` of each pair as the label.

```python
from sentence_transformers import InputExample

# Each example pairs a query with a document; the label is the relevance grade.
train_samples = [
    InputExample(texts=["query 1", "document 1"], label=0.3),
    InputExample(texts=["query 1", "document 2"], label=0.8),
    InputExample(texts=["query 2", "document 2"], label=0.1),
]
```
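
How the `grade` values were derived is not stated here. One plausible route, shown below purely as a sketch, is to map each ESCI judgment (`E`, `S`, `C`, `I` for Exact, Substitute, Complement, Irrelevant) to a numeric grade. The helper name `build_training_triples`, the row layout, and the idea that grades come from such a mapping are all assumptions. The mapping itself is supplied by the caller, since the values used for this model are unknown.

```python
from typing import Dict, Iterable, List, Tuple


def build_training_triples(
    rows: Iterable[Tuple[str, str, str]],
    grade_map: Dict[str, float],
) -> List[Tuple[str, str, float]]:
    """Turn (query, document, esci_label) rows into (query, document, grade).

    Assumption: grades are obtained by looking up the ESCI label in a
    caller-supplied mapping; this is illustrative, not the author's recipe.
    """
    triples = []
    for query, document, esci_label in rows:
        if esci_label not in grade_map:
            raise ValueError(f"No grade defined for label {esci_label!r}")
        triples.append((query, document, float(grade_map[esci_label])))
    return triples
```

Each resulting triple corresponds to one `InputExample(texts=[query, document], label=grade)` in the format shown above.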