"""Utilities for mapping token-level span predictions back to character
spans of the source text, and for running batched model inference.

Pipeline: ``inference_fn`` -> ``get_char_probs`` -> ``get_results`` ->
``get_predictions`` / ``get_text``.
"""

import itertools

import numpy as np
import torch

try:
    from tqdm.auto import tqdm
except ImportError:  # fall back to a no-op wrapper when tqdm is unavailable
    def tqdm(iterable, *args, **kwargs):
        return iterable


def get_char_probs(texts, predictions, tokenizer):
    """Project token-level probabilities onto character positions.

    For each (text, prediction) pair, re-tokenizes the text to recover the
    token -> character offset mapping and writes each token's probability
    into every character position that token covers.

    Args:
        texts: iterable of raw strings.
        predictions: per-text sequence of token probabilities aligned with
            the tokenizer output (special tokens map to the empty span
            ``(0, 0)`` so their predictions are discarded).
        tokenizer: callable supporting ``add_special_tokens`` and
            ``return_offsets_mapping=True`` (e.g. a HuggingFace fast
            tokenizer) — TODO confirm against the caller.

    Returns:
        list of 1-D ``np.ndarray``, one per text, of length ``len(text)``.
    """
    results = [np.zeros(len(text)) for text in texts]
    for i, (text, prediction) in enumerate(zip(texts, predictions)):
        encoded = tokenizer(text,
                            add_special_tokens=True,
                            return_offsets_mapping=True)
        for (start, end), prob in zip(encoded['offset_mapping'], prediction):
            # Special tokens carry the empty span (0, 0); the slice
            # assignment below is then a no-op.
            results[i][start:end] = prob
    return results


def get_results(char_probs, th=0.5):
    """Convert per-character probabilities into ``"start end"`` span strings.

    Characters with probability >= ``th`` are grouped into maximal runs of
    consecutive indices.  Each run is emitted as ``"start end"`` with an
    *exclusive* end — so ``text[start:end]`` recovers exactly the run
    (matching how ``get_text`` slices) — and runs are joined with ``';'``.

    Example:
        probs >= th at indices [2, 3, 4, 5, 9, 10, 11] -> ``"2 6;9 12"``
    """
    results = []
    for char_prob in char_probs:
        above = np.where(char_prob >= th)[0]
        # Classic consecutive-run grouping: (value - running counter) is
        # constant within a run of consecutive integers.
        runs = [list(g) for _, g in itertools.groupby(
            above, key=lambda n, c=itertools.count(): n - next(c))]
        # Half-open [start, end): run[0] is inclusive start, run[-1] + 1
        # the exclusive end.  (The original emitted both endpoints shifted
        # by +1, which made downstream [start:end] slicing drop the first
        # character of every run.)
        spans = [f"{run[0]} {run[-1] + 1}" for run in runs]
        results.append(";".join(spans))
    return results


def get_predictions(results):
    """Parse ``"start end;start end"`` strings into ``[[start, end], ...]``.

    Args:
        results: iterable of location strings, e.g. ``["2 6;9 12", ""]``.

    Returns:
        list with one ``[[start, end], ...]`` list per input string; an
        empty string yields an empty list for that entry.
    """
    predictions = []
    for result in results:
        prediction = []
        if result:
            for loc in result.split(';'):
                parts = loc.split()
                prediction.append([int(parts[0]), int(parts[1])])
        predictions.append(prediction)
    return predictions


def inference_fn(test_loader, model, device):
    """Run ``model`` over ``test_loader`` and return sigmoid probabilities.

    Moves every input tensor to ``device``, evaluates without gradient
    tracking, and concatenates the per-batch sigmoid outputs along the
    batch dimension.

    Args:
        test_loader: yields dicts of input tensors (batch first —
            TODO confirm against the DataLoader/collate used by callers).
        model: torch module; called as ``model(inputs)``.
        device: target ``torch.device`` (or string).

    Returns:
        ``np.ndarray`` of concatenated probabilities on the CPU.
    """
    preds = []
    model.eval()
    model.to(device)
    progress = tqdm(test_loader, total=len(test_loader))
    for inputs in progress:
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        # .cpu() is required before .numpy(): calling .numpy() on a CUDA
        # tensor raises TypeError.
        preds.append(y_preds.sigmoid().cpu().numpy())
    return np.concatenate(preds)


def get_text(context, indexes):
    """Extract the substring(s) of ``context`` named by ``indexes``.

    Args:
        context: source text.
        indexes: ``""``/falsy (returns the not-found sentinel), a single
            ``"start end"`` span, or several spans joined with ``';'``.

    Returns:
        The sliced text.  For multiple spans each piece is prefixed with a
        single space (historical output format, preserved).
    """
    if not indexes:
        return 'Not found in this Context'
    if ';' in indexes:
        answer = ''
        for span in indexes.split(';'):
            parts = span.split(' ')
            start_index, end_index = int(parts[0]), int(parts[1])
            answer += ' '
            answer += context[start_index:end_index]
        return answer
    parts = indexes.split(' ')
    return context[int(parts[0]):int(parts[1])]