This generation model is based on sberbank-ai/rugpt3small_based_on_gpt2. It was trained on a large corpus of dialog data and can be used to build generative conversational agents.
The model was trained with a context size of 3.
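The usage example further down formats the context with alternating speaker tags, `@@ПЕРВЫЙ@@` ("first") and `@@ВТОРОЙ@@` ("second"): three utterances, then a trailing `@@ВТОРОЙ@@` that asks the model for a reply. Assuming "context size 3" means the three most recent utterances, a longer dialog history could be trimmed and tagged as in the sketch below. `build_prompt` is a hypothetical helper. Assigning tags backwards from the newest turn, so that the newest user message is always `@@ПЕРВЫЙ@@`, is an assumption read off the example, not a documented rule.

```python
FIRST = "@@ПЕРВЫЙ@@"   # "first" speaker tag, as in the usage example
SECOND = "@@ВТОРОЙ@@"  # "second" speaker tag; the prompt ends with it


def build_prompt(turns: list[str], context_size: int = 3) -> str:
    """Hypothetical helper: tag the last `context_size` turns for the model.

    Assumption: tags are assigned backwards from the newest turn, so the newest
    turn gets FIRST and the model is asked to answer as SECOND.
    """
    if context_size < 1:
        raise ValueError("context_size must be at least 1")
    recent = turns[-context_size:]  # keep only the most recent turns
    n = len(recent)
    parts = []
    for i, text in enumerate(recent):
        k = n - 1 - i  # distance from the newest turn: 0 for the newest
        tag = FIRST if k % 2 == 0 else SECOND
        parts.append(f"{tag} {text.strip()}")
    parts.append(SECOND)  # ask the model to produce the next reply
    return " ".join(parts)


print(build_prompt(["привет", "привет", "как дела?"]))
```

For the three-turn history above this reproduces the prompt string used in the usage example; the result can be passed to `tokenizer(..., return_tensors='pt')` in the same way.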
On a private validation set, we calculated the metrics introduced in this paper:
- Sensibleness: crowdsourcers were asked whether the model's response makes sense given the context
- Specificity: crowdsourcers were asked whether the model's response is specific to the given context; in other words, we don't want the model to give generic, boring responses
- SSA: the average of the two metrics above (Sensibleness Specificity Average)
model | sensibleness | specificity | SSA |
---|---|---|---|
tinkoff-ai/ruDialoGPT-small | 0.64 | 0.5 | 0.57 |
tinkoff-ai/ruDialoGPT-medium | 0.78 | 0.69 | 0.735 |
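As a quick check of the SSA definition, the sketch below recomputes the SSA column from the sensibleness and specificity columns of the table. The `rate` helper and its 0/1 labels are hypothetical. They illustrate one plausible way per-response crowdsourced judgments become the fractions in the table, assuming each metric is simply the share of responses that raters marked positive.

```python
import math


def rate(labels: list[int]) -> float:
    """Share of responses that raters marked positive (1) rather than negative (0)."""
    return sum(labels) / len(labels)


def ssa(sensibleness: float, specificity: float) -> float:
    """Sensibleness Specificity Average: the plain mean of the two rates."""
    return (sensibleness + specificity) / 2


# (sensibleness, specificity, reported SSA), copied from the table above.
reported = {
    "tinkoff-ai/ruDialoGPT-small": (0.64, 0.5, 0.57),
    "tinkoff-ai/ruDialoGPT-medium": (0.78, 0.69, 0.735),
}

for name, (sens, spec, table_ssa) in reported.items():
    computed = ssa(sens, spec)
    # Fails if the table's SSA column is not the mean of the other two columns.
    assert math.isclose(computed, table_ssa), name
    print(f"{name}: SSA = {computed:.3f}")

# Hypothetical labels for four responses: three judged sensible, one not.
print(rate([1, 1, 0, 1]))  # 0.75
```

Both rows agree: (0.64 + 0.5) / 2 = 0.57 and (0.78 + 0.69) / 2 = 0.735.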
How to use:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained('tinkoff-ai/ruDialoGPT-small')
model = AutoModelForCausalLM.from_pretrained('tinkoff-ai/ruDialoGPT-small')

# Dialog turns are prefixed with alternating speaker tags (@@ПЕРВЫЙ@@ = "first",
# @@ВТОРОЙ@@ = "second"); the prompt ends with the tag of the speaker whose
# reply the model should generate.
inputs = tokenizer('@@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@ привет @@ПЕРВЫЙ@@ как дела? @@ВТОРОЙ@@', return_tensors='pt')

generated_token_ids = model.generate(
    **inputs,
    top_k=10,
    top_p=0.95,
    num_beams=3,
    num_return_sequences=3,  # return three candidate replies
    do_sample=True,
    no_repeat_ngram_size=2,
    temperature=1.2,
    repetition_penalty=1.2,
    length_penalty=1.0,
    eos_token_id=50257,
    max_new_tokens=40,
)

# Each decoded string contains the dialog context followed by the generated reply.
context_with_response = [tokenizer.decode(sample_token_ids) for sample_token_ids in generated_token_ids]
print(context_with_response)
```
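Each decoded string repeats the whole context before the generated reply. Below is a minimal sketch for pulling out just the reply. It assumes the decoded text keeps the prompt's speaker tags in the same order, and that the reply ends at the next speaker tag if the model produced one. `extract_response` is a hypothetical helper, not part of `transformers`, and the sample reply text is made up. If the end-of-sequence token decodes to some other visible text, that may need stripping as well.

```python
import re

# Matches either speaker tag used in the prompt format.
TAG_RE = re.compile(r"@@(?:ПЕРВЫЙ|ВТОРОЙ)@@")


def extract_response(prompt: str, decoded: str) -> str:
    """Hypothetical helper: return the reply that follows `prompt` in `decoded`.

    Assumes `decoded` starts with the same sequence of speaker tags as `prompt`,
    so the reply begins right after the tag that closes the prompt and ends at
    the next speaker tag, or at the end of the string if there is none.
    """
    n_prompt_tags = len(TAG_RE.findall(prompt))
    if n_prompt_tags == 0:
        raise ValueError("prompt contains no speaker tags")
    tags = list(TAG_RE.finditer(decoded))
    if len(tags) < n_prompt_tags:
        raise ValueError("decoded text does not contain the whole prompt")
    start = tags[n_prompt_tags - 1].end()
    end = tags[n_prompt_tags].start() if len(tags) > n_prompt_tags else len(decoded)
    return decoded[start:end].strip()


prompt = "@@ПЕРВЫЙ@@ привет @@ВТОРОЙ@@ привет @@ПЕРВЫЙ@@ как дела? @@ВТОРОЙ@@"
# A made-up decoded sample standing in for one entry of context_with_response.
sample = prompt + " нормально, а у тебя? @@ПЕРВЫЙ@@"
print(extract_response(prompt, sample))  # нормально, а у тебя?
```

With the real model output, `[extract_response(prompt, text) for text in context_with_response]` would give one reply per returned sequence. Matching on tag positions rather than on the exact prompt string tolerates whitespace differences introduced by decoding.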