File size: 2,182 Bytes
b650cfe
65df198
b650cfe
7e35601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import random
import torch

def getLengthParam(text: str, tokenizer) -> str:
    tokens_count = len(tokenizer.encode(text))
    if tokens_count <= 15:
        len_param = '1'
    elif tokens_count <= 50:
        len_param = '2'
    elif tokens_count <= 256:
        len_param = '3'
    else:
        len_param = '-'
    return len_param

# Эта функция вычисляет длину ожидаемого ответа на основе инпута
def calcAnswerLengthByProbability(lengthId):

  # Вспомогательная функция, для работы с вероятностями
  # На вход подаем список веротностей для длинного ответа (3), среднего(2), короткого 1 
  def getLenght(probList):
    rndNum = random.randrange(start=0, stop=100, step=1)
    if 0 <= rndNum <= probList[0]:
      return 3
    elif probList[0] < rndNum <= probList[1]:
      return 2
    else:
      return 1

  return {
            lengthId == '3' or lengthId == '-': getLenght([60, 90]), # до 60 - 3, от 60 до 90 2, остальное - 1
            lengthId == '2': getLenght([25, 75]), # до 25 - 3, от 25 до 75 - 2, остальное - 2
            lengthId == '1': getLenght([20, 50]), # до 20 - 3, от 20 до 50 - 2, остальное - 1
    }[True]

# Функция для обрезки контекста
# tensor - входной тензор
# size - сколько ПОСЛЕДНИХ ответов нужно оставить
def cropContext(tensor, size):
  # переводим в размерность, удобную для работы
  tensor = tensor[-1]
  # Список, содержащий начала предложений
  beginList = []

  for i, item in enumerate(tensor):
    if (i < len(tensor) - 5 and item == 96 and tensor[i + 2] == 96 and tensor[i + 4] == 96):
      beginList.append(i)

  if (len(beginList) < size):
    return torch.unsqueeze(tensor, 0)

  neededIndex = beginList[-size]
  # Возвращаем в нужном нам формате (добавляем одну размерность)
  return torch.unsqueeze(tensor[neededIndex:], 0)