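"""Gradio demo: extract the spans of a context paragraph that match a given
feature/question, using a fine-tuned microsoft/deberta-base model."""
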
from transformers import AutoTokenizer
from model import CustomModel
import torch
from configuration import CFG
from dataset import SingleInputDataset
from torch.utils.data import DataLoader
from utils import inference_fn, get_char_probs, get_results, get_text
import numpy as np
import gradio as gr
import os

if __name__ == '__main__':

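    # Load the tokenizer and model artifacts from the local models_file
    # directory; inference on a single sample is cheap enough to run on CPU.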
    device = torch.device('cpu')
    config_path = os.path.join('models_file', 'config.pth')
    model_path  = os.path.join('models_file', 'microsoft-deberta-base_0.9449373420387531_8_best.pth')
    tokenizer = AutoTokenizer.from_pretrained('models_file/tokenizer')
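
    # Rebuild the model skeleton from the saved config (pretrained=False skips
    # downloading the base weights) and load the fine-tuned checkpoint, which
    # stores the weights under the 'model' key.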
    model = CustomModel(CFG, config_path=config_path, pretrained=False)
    state = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state['model'])
    model.eval()  # disable dropout so inference is deterministic

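    # Given a context paragraph and a feature/question string, return the
    # parts of the context that the model tags as the answer.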
    def get_answer(context, feature):
        # Tokenize the (patient history, feature text) pair as one input;
        # truncation keeps long contexts within the model's max_len so the
        # reshape below cannot fail.
        inputs_single = tokenizer(context, feature,
                                  add_special_tokens=True,
                                  max_length=CFG.max_len,
                                  padding="max_length",
                                  truncation=True,
                                  return_offsets_mapping=False)

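        # Convert the tokenizer's Python lists into LongTensors for the model.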
        for k, v in inputs_single.items():
            inputs_single[k] = torch.tensor(v, dtype=torch.long)

        # Wrap the single sample in a dataset and DataLoader so the shared
        # inference_fn helper can consume it like a normal evaluation set.
        single_input_dataset = SingleInputDataset(inputs_single)
        single_input_loader = DataLoader(single_input_dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=0)  # worker processes add overhead for one sample

        # Perform inference on the single input
        output = inference_fn(single_input_loader, model, device)

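        # Map token-level probabilities back to character offsets, average
        # across models (a single model here, so the mean is a no-op), and
        # threshold at 0.5 to extract character spans.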
        prediction = output.reshape((1, CFG.max_len))
        char_probs = get_char_probs([context], prediction, tokenizer)
        predictions = np.mean([char_probs], axis=0)
        results = get_results(predictions, th=0.5)

        print(results)  # log the raw character spans for debugging
        return get_text(context, results[0])
    
    # Gradio UI: two text inputs (context and feature/question) and one text
    # output. The gr.inputs/gr.outputs namespaces are deprecated (and removed
    # in recent Gradio releases), so the components are used directly.
    inputs = [gr.Textbox(label="Context Para", lines=10),
              gr.Textbox(label="Question", lines=1)]
    output = gr.Textbox(label="Answer")
    article = "<p style='text-align: center'><a href='https://www.xelpmoc.in/' target='_blank'>Made by Xelpmoc</a></p>"

    app = gr.Interface(
        fn=get_answer,
        inputs=inputs,
        outputs=output,
        allow_flagging='never',
        title="Entity Extraction from Text",
        article=article,
        cache_examples=False)

    # enable_queue is no longer an Interface argument; enable queuing on the
    # app instead before launching.
    app.queue()
    app.launch()