File size: 2,182 Bytes
b650cfe 65df198 b650cfe 7e35601 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import random
import torch
def getLengthParam(text: str, tokenizer) -> str:
tokens_count = len(tokenizer.encode(text))
if tokens_count <= 15:
len_param = '1'
elif tokens_count <= 50:
len_param = '2'
elif tokens_count <= 256:
len_param = '3'
else:
len_param = '-'
return len_param
# Эта функция вычисляет длину ожидаемого ответа на основе инпута
def calcAnswerLengthByProbability(lengthId):
# Вспомогательная функция, для работы с вероятностями
# На вход подаем список веротностей для длинного ответа (3), среднего(2), короткого 1
def getLenght(probList):
rndNum = random.randrange(start=0, stop=100, step=1)
if 0 <= rndNum <= probList[0]:
return 3
elif probList[0] < rndNum <= probList[1]:
return 2
else:
return 1
return {
lengthId == '3' or lengthId == '-': getLenght([60, 90]), # до 60 - 3, от 60 до 90 2, остальное - 1
lengthId == '2': getLenght([25, 75]), # до 25 - 3, от 25 до 75 - 2, остальное - 2
lengthId == '1': getLenght([20, 50]), # до 20 - 3, от 20 до 50 - 2, остальное - 1
}[True]
# Функция для обрезки контекста
# tensor - входной тензор
# size - сколько ПОСЛЕДНИХ ответов нужно оставить
def cropContext(tensor, size):
# переводим в размерность, удобную для работы
tensor = tensor[-1]
# Список, содержащий начала предложений
beginList = []
for i, item in enumerate(tensor):
if (i < len(tensor) - 5 and item == 96 and tensor[i + 2] == 96 and tensor[i + 4] == 96):
beginList.append(i)
if (len(beginList) < size):
return torch.unsqueeze(tensor, 0)
neededIndex = beginList[-size]
# Возвращаем в нужном нам формате (добавляем одну размерность)
return torch.unsqueeze(tensor[neededIndex:], 0) |