File size: 7,010 Bytes
c24672d
 
 
 
 
091e8de
 
c24672d
36d796b
 
 
5861f85
7980dfb
 
 
36d796b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6c7e2d
 
36d796b
 
a6c7e2d
 
 
 
36d796b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6c7e2d
 
36d796b
 
a6c7e2d
 
 
 
36d796b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6c7e2d
 
36d796b
a6c7e2d
 
36d796b
 
 
 
 
 
c24672d
 
 
a6c7e2d
c24672d
 
 
36d796b
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
---
license: mit
language:
- ru
- en
base_model:
- ai-forever/FRED-T5-1.7B
---

---
# Model Card for FRIDA
## FRIDA full-scaled finetuned retrieval model inspired by denoising architecture based on T5
<figure>
  <img src="img.jpg">
</figure>
The FRIDA is a general text embedding model for Russian. The model is based on the encoder part of FRED-T5 (https://huggingface.co/ai-forever/FRED-T5-1.7B). It has been pre-trained on a Russian-English dataset and fine-tuned for improved performance on the target task.

For more model details please refer to our [article](TODO).

## Usage

The model can be used as is with prefixes. It is recommended to use CLS pooling. The choice of prefix and pooling depends on the task. 

We use the following basic rules to choose a prefix:
- `"search_query: "` and `"search_document: "` prefixes are for answer or relevant paragraph retrieval
- `"paraphrase: "` prefix is for symmetric paraphrasing related tasks (STS, paraphrase mining, deduplication)
- `"categorize: "` prefix is for asymmetric matching of document title and body (e.g. news, scientific papers, social posts)
- `"categorize_sentiment: "` prefix is for any tasks that rely on sentiment features (e.g. hate, toxic, emotion)
- `"categorize_topic: "` prefix is ​​intended for tasks where you need to group texts by topic
- `"categorize_entailment: "` prefix is for textual entailment task (NLI)

To better tailor the model to your needs, you can fine-tune it with relevant high-quality Russian and English datasets.

Below are examples of texts encoding using the Transformers and SentenceTransformers libraries.

### Transformers

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, T5EncoderModel


def pool(hidden_state, mask, pooling_method="cls"):
    if pooling_method == "mean":
        s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
        d = mask.sum(axis=1, keepdim=True).float()
        return s / d
    elif pooling_method == "cls":
        return hidden_state[:, 0]

inputs = [
    # 
    "paraphrase: В Ярославской области разрешили работу бань, но без посетителей",
    "categorize_entailment: Женщину доставили в больницу, за ее жизнь сейчас борются врачи.",
    "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
    # 
    "paraphrase: Ярославским баням разрешили работать без посетителей",
    "categorize_entailment: Женщину спасают врачи.",
    "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование."
    ]

tokenizer = AutoTokenizer.from_pretrained("ai-forever/FRIDA")
model = T5EncoderModel.from_pretrained("ai-forever/FRIDA")

tokenized_inputs = tokenizer(inputs, max_length=512, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**tokenized_inputs)
    
embeddings = pool(
    outputs.last_hidden_state, 
    tokenized_inputs["attention_mask"],
    pooling_method="cls" # or try "mean"
)

embeddings = F.normalize(embeddings, p=2, dim=1)
sim_scores = embeddings[:3] @ embeddings[3:].T
print(sim_scores.diag().tolist())
# [0.4796873927116394, 0.9409002065658569, 0.7761015892028809]
```

### SentenceTransformers

```python
from sentence_transformers import SentenceTransformer

inputs = [
    # 
    "paraphrase: В Ярославской области разрешили работу бань, но без посетителей",
    "categorize_entailment: Женщину доставили в больницу, за ее жизнь сейчас борются врачи.",
    "search_query: Сколько программистов нужно, чтобы вкрутить лампочку?",
    # 
    "paraphrase: Ярославским баням разрешили работать без посетителей",
    "categorize_entailment: Женщину спасают врачи.",
    "search_document: Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование."
    ]

# loads model with CLS pooling
model = SentenceTransformer("ai-forever/FRIDA")

# embeddings are normalized by default
embeddings = model.encode(inputs, convert_to_tensor=True)

sim_scores = embeddings[:3] @ embeddings[3:].T
print(sim_scores.diag().tolist())
# [0.47968706488609314, 0.940900444984436, 0.7761018872261047]
```

or using prompts (sentence-transformers>=2.4.0):

```python
from sentence_transformers import SentenceTransformer

# loads model with CLS pooling
model = SentenceTransformer("ai-forever/FRIDA")

paraphrase = model.encode(["В Ярославской области разрешили работу бань, но без посетителей", "Ярославским баням разрешили работать без посетителей"], prompt_name="paraphrase")
print(paraphrase[0] @ paraphrase[1].T) # 0.47968706488609314

categorize_entailment = model.encode(["Женщину доставили в больницу, за ее жизнь сейчас борются врачи.", "Женщину спасают врачи."], prompt_name="categorize_entailment")
print(categorize_entailment[0] @ categorize_entailment[1].T) # 0.940900444984436

query_embedding = model.encode("Сколько программистов нужно, чтобы вкрутить лампочку?", prompt_name="search_query")
document_embedding = model.encode("Чтобы вкрутить лампочку, требуется три программиста: один напишет программу извлечения лампочки, другой — вкручивания лампочки, а третий проведет тестирование.", prompt_name="search_document")
print(query_embedding @ document_embedding.T) # 0.7761018872261047
```

+ # Authors
+ [SaluteDevices](https://sberdevices.ru/) AI for B2C RnD Team.
+ Artem Snegirev: [HF profile](https://huggingface.co/artemsnegirev);
+ Anna Maksimova [HF profile](https://huggingface.co/anpalmak);
+ Aleksandr Abramov: [HF profile](https://huggingface.co/Andrilko), [Github](https://github.com/Ab1992ao), [Kaggle Competitions Master](https://www.kaggle.com/andrilko)


## Citation

```
@misc{TODO
}
```

## Limitations

The model is designed to process texts in Russian, the quality in English is unknown. Maximum input text length is limited to 512 tokens.