Spaces:
Runtime error
Runtime error
File size: 2,907 Bytes
122d428 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
if __name__ == '__main__':
    # Gradio demo: tag the span(s) of a free-text "context" that answer a
    # short "question", using a fine-tuned DeBERTa checkpoint on CPU.
    import os

    import gradio as gr
    import numpy as np
    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from configuration import CFG
    from dataset import SingleInputDataset
    from model import CustomModel
    from utils import inference_fn, get_char_probs, get_results, get_text

    device = torch.device('cpu')
    config_path = os.path.join('models_file', 'config.pth')
    model_path = os.path.join(
        'models_file', 'microsoft-deberta-base_0.9449373420387531_8_best.pth')
    tokenizer = AutoTokenizer.from_pretrained('models_file/tokenizer')
    model = CustomModel(CFG, config_path=config_path, pretrained=False)
    # NOTE(review): torch.load unpickles the file, which is acceptable only
    # because this is a local, trusted checkpoint. Newer torch releases
    # default to weights_only=True, which may reject a checkpoint that stores
    # non-tensor objects next to 'model' -- confirm against the pinned torch.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model'])
    # Inference only; harmless if inference_fn also switches to eval mode.
    model.eval()

    def get_answer(context, feature):
        """Return the text of *context* that the model tags for *feature*.

        Args:
            context: Passage typed into the "Context Para" box.
            feature: Query typed into the "Question" box.

        Returns:
            The result of ``get_text(context, results[0])``, where ``results``
            comes from ``get_results`` on the character probabilities using a
            0.5 threshold.
        """
        # Encode (context, question) as one pair padded to CFG.max_len.
        # padding="max_length" alone does NOT truncate, so a long context
        # would yield more than max_len tokens and break the reshape below.
        # Truncate only the context so the question is always kept whole.
        inputs_single = tokenizer(context, feature,
                                  add_special_tokens=True,
                                  max_length=CFG.max_len,
                                  padding="max_length",
                                  truncation="only_first",
                                  return_offsets_mapping=False)
        for k, v in inputs_single.items():
            inputs_single[k] = torch.tensor(v, dtype=torch.long)

        # One-sample loader. num_workers=0 loads in-process instead of
        # spawning worker processes on every request.
        single_input_loader = DataLoader(
            SingleInputDataset(inputs_single),
            batch_size=1,
            shuffle=False,
            num_workers=0,
        )

        output = inference_fn(single_input_loader, model, device)
        prediction = output.reshape((1, CFG.max_len))
        # NOTE(review): when the context was truncated above, positions past
        # the context tokens belong to [SEP]/question tokens; confirm how
        # get_char_probs aligns token predictions with context characters.
        char_probs = get_char_probs([context], prediction, tokenizer)
        # Mean over a one-element leading axis: returns the same values as an
        # array (presumably left over from multi-model ensembling).
        predictions = np.mean([char_probs], axis=0)
        results = get_results(predictions, th=0.5)
        print(results)
        return get_text(context, results[0])

    # gr.inputs / gr.outputs were removed in Gradio 4; gr.Textbox works in
    # Gradio 3.x and later.
    inputs = [gr.Textbox(label="Context Para", lines=10),
              gr.Textbox(label="Question", lines=1)]
    output = gr.Textbox(label="Answer")
    article = "<p style='text-align: center'><a href='https://www.xelpmoc.in/' target='_blank'>Made by Xelpmoc</a></p>"
    app = gr.Interface(
        fn=get_answer,
        inputs=inputs,
        outputs=output,
        allow_flagging='never',
        title="Entity Extraction from Text",
        article=article,
        cache_examples=False)
    # Replaces the Interface(enable_queue=True) argument.
    app.queue()
    app.launch()
|