import os

import gradio as gr
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from configuration import CFG
from dataset import SingleInputDataset
from model import CustomModel
from utils import inference_fn, get_char_probs, get_results, get_text
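
# Run all inference on CPU.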
device = torch.device('cpu')
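
# Paths to the saved model config and the best fine-tuned DeBERTa checkpoint.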
config_path = os.path.join('models_file', 'config.pth')
model_path = os.path.join('models_file', 'microsoft-deberta-base_0.9449373420387531_8_best.pth')
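
# Load the tokenizer saved with the model and rebuild the network from the
# stored config (pretrained=False: the weights come from the checkpoint below).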
tokenizer = AutoTokenizer.from_pretrained('models_file/tokenizer')
model = CustomModel(CFG, config_path=config_path, pretrained=False)
state = torch.load(model_path, map_location=device)
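# The checkpoint stores the fine-tuned weights under the 'model' key.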
model.load_state_dict(state['model'])


def get_answer(context, feature):
    # Encode the patient history (context) together with the feature text as
    # a single sentence pair, padded to the model's maximum length.
    inputs_single = tokenizer(context, feature,
                              add_special_tokens=True,
                              max_length=CFG.max_len,
                              padding="max_length",
                              return_offsets_mapping=False)
    # Convert each encoded field to a LongTensor.
    for k, v in inputs_single.items():
        inputs_single[k] = torch.tensor(v, dtype=torch.long)

    # Wrap the single encoded sample in a dataset and a one-batch DataLoader
    # so it can go through the shared batched inference routine.
    single_input_dataset = SingleInputDataset(inputs_single)
    single_input_loader = DataLoader(
        single_input_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2
    )

    # Token-level probabilities for the single input, reshaped to (1, max_len).
    output = inference_fn(single_input_loader, model, device)
    prediction = output.reshape((1, CFG.max_len))

    # Project token-level probabilities onto characters. The np.mean over a
    # one-element list is effectively a no-op for a single model; it keeps the
    # shape of an ensemble average.
    char_probs = get_char_probs([context], prediction, tokenizer)
    predictions = np.mean([char_probs], axis=0)

    # Keep character spans with probability above 0.5 and map them back to text.
    results = get_results(predictions, th=0.5)
    print(results)
    return get_text(context, results[0])
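
# Example (hypothetical inputs, for illustration only):
#   get_answer("Patient history paragraph ...", "Feature text to locate")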

# Gradio UI: a long textbox for the context paragraph and a short one for the
# question. gr.Textbox replaces the deprecated gr.inputs/gr.outputs namespaces.
inputs = [gr.Textbox(label="Context Para", lines=10), gr.Textbox(label="Question", lines=1)]
output = gr.Textbox(label="Answer")
article = "<p style='text-align: center'><a href='https://www.xelpmoc.in/' target='_blank'>Made by Xelpmoc</a></p>"
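
# Single-function Gradio app wrapping get_answer.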
app = gr.Interface(
    fn=get_answer,
    inputs=inputs,
    outputs=output,
    allow_flagging='never',
    title="Phrase Extraction",
    article=article,
    cache_examples=False,
    css="footer {visibility: hidden}"
)

if __name__ == '__main__':
    # queue() replaces the deprecated enable_queue=True Interface argument.
    app.queue().launch()