import random

import torch


def getLengthParam(text: str, tokenizer) -> str:
    # Map the token count of the user's message to a coarse length bucket:
    # '1' = short, '2' = medium, '3' = long, '-' = longer than 256 tokens.
    tokens_count = len(tokenizer.encode(text))
    if tokens_count <= 15:
        len_param = '1'
    elif tokens_count <= 50:
        len_param = '2'
    elif tokens_count <= 256:
        len_param = '3'
    else:
        len_param = '-'
    return len_param
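

# Usage sketch (not part of the original snippet): getLengthParam only needs an
# object exposing .encode(), so any Hugging Face tokenizer can be passed in; the
# stub below is a hypothetical stand-in that keeps the example self-contained.
class _StubTokenizer:
    def encode(self, text):
        return text.split()  # crude whitespace "tokens", for illustration only


assert getLengthParam("How are you?", _StubTokenizer()) == '1'           # 3 tokens  -> bucket '1'
assert getLengthParam(" ".join(["word"] * 40), _StubTokenizer()) == '2'  # 40 tokens -> bucket '2'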


def calcAnswerLengthByProbability(lengthId):
    # Randomly pick a target answer length (1 = short, 2 = medium, 3 = long),
    # with probabilities biased by the length bucket of the user's message:
    # the longer the message, the more likely a long answer.

    def getLength(probList):
        # probList = [p3, p2]: return 3 with ~p3%, 2 with ~(p2 - p3)%, 1 otherwise.
        rndNum = random.randrange(start=0, stop=100, step=1)
        if 0 <= rndNum <= probList[0]:
            return 3
        elif probList[0] < rndNum <= probList[1]:
            return 2
        else:
            return 1

    if lengthId == '3' or lengthId == '-':
        return getLength([60, 90])
    elif lengthId == '2':
        return getLength([25, 75])
    else:
        return getLength([20, 50])
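

# Sanity-check sketch (illustrative only): sample the distribution for each bucket.
# Long user messages ('3' / '-') are biased toward answer length 3, short ones ('1')
# toward length 1.
from collections import Counter

for bucket in ('1', '2', '3', '-'):
    sample = Counter(calcAnswerLengthByProbability(bucket) for _ in range(1000))
    print(bucket, dict(sample))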


def cropContext(tensor, size):
    # Trim the dialogue history to the last `size` turns so the context
    # stays within the model's input limit.
    tensor = tensor[-1]  # drop the batch dimension, keep the token sequence

    # Three separator tokens (id 96) at positions i, i + 2 and i + 4 mark the
    # beginning of a dialogue turn in this markup; collect every start index.
    beginList = []
    for i, item in enumerate(tensor):
        if (i < len(tensor) - 5 and item == 96 and tensor[i + 2] == 96 and tensor[i + 4] == 96):
            beginList.append(i)

    # Fewer turns than requested: return the whole context unchanged.
    if len(beginList) < size:
        return torch.unsqueeze(tensor, 0)

    # Cut everything before the start of the size-th turn from the end.
    neededIndex = beginList[-size]
    return torch.unsqueeze(tensor[neededIndex:], 0)
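

# Usage sketch (illustrative only): a fake token sequence in which every turn
# starts with the three-separator pattern (token id 96 at i, i + 2, i + 4).
# Keeping size=2 drops the first turn and preserves the last two.
example = torch.tensor([[96, 0, 96, 1, 96, 11, 12,
                         96, 1, 96, 2, 96, 21, 22, 23,
                         96, 0, 96, 1, 96, 31]])
print(cropContext(example, 2))  # tensor([[96, 1, 96, 2, 96, 21, 22, 23, 96, 0, 96, 1, 96, 31]])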