larkkin committed on
Commit
3010bd0
·
1 Parent(s): 7daaa6b

Add model wrapper and gradio app

Browse files
Files changed (2) hide show
  1. app.py +76 -0
  2. model_wrapper.py +111 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import model_wrapper
3
+
4
+
5
+
6
# Build the prediction model once at module import; predict() below reuses
# this single instance for every request.
model = model_wrapper.PredictionModel()
7
+
8
def pretty_print_opinion(opinion_dict):
    """Render one opinion as newline-terminated ``Key: value`` lines.

    Args:
        opinion_dict: Mapping of field name to value. The ``'Polarity'``
            entry is printed as-is; for every other entry the first element
            of the value (a sequence of token strings) is joined with spaces
            and wrapped in single quotes.

    Returns:
        One line per entry, in the mapping's iteration order, joined with
        ``'\\n'`` and followed by a trailing newline. An empty mapping
        yields just ``'\\n'``.
    """
    lines = []
    for key, value in opinion_dict.items():
        if key == 'Polarity':
            lines.append(f'{key}: {value}')
        else:
            # value[0] holds the tokens; any further elements (e.g. the
            # character spans) are intentionally not displayed.
            lines.append(f"{key}: '{' '.join(value[0])}'")
    return '\n'.join(lines) + '\n'
18
+
19
+
20
def predict(text):
    """Run the model on a single text and format every detected opinion.

    The text is sent to the module-level ``model`` as a one-element batch.
    Returns the string ``'No opinions detected'`` when the prediction's
    ``'opinions'`` entry is empty; otherwise the pretty-printed opinions
    joined by newlines.
    """
    first_prediction = model.predict([text])[0]
    opinions = first_prediction['opinions']
    if not opinions:
        return 'No opinions detected'
    return '\n'.join(pretty_print_opinion(opinion) for opinion in opinions)
30
+
31
+
32
+
33
# Markdown source passed to gr.Markdown in the UI below: a description of the
# space, the reported scores and a usage example for model_wrapper.
markdown_text = '''
<br>
<br>
This space provides a gradio demo and an easy-to-run wrapper of the pre-trained model for structured sentiment analysis in Norwegian language, pre-trained on the [NoReC dataset](https://huggingface.co/datasets/norec).
This model is an implementation of the paper "Direct parsing to sentiment graphs" (Samuel _et al._, ACL 2022). The main repository that also contains the scripts for training the model, can be found on the project [github](https://github.com/jerbarnes/direct_parsing_to_sent_graph).

The current model uses the 'labeled-edge' graph encoding, and achieves the following results on the NoReC dataset:

| Unlabeled sentiment tuple F1 | Target F1 | Relative polarity precision |
|:----------------------------:|:----------:|:---------------------------:|
| 0.393 | 0.468 | 0.939 |


The model can be easily used for predicting sentiment tuples as follows:

```python
>>> import model_wrapper
>>> model = model_wrapper.PredictionModel()
>>> model.predict(['vi liker svart kaffe'])
[{'sent_id': '0',
'text': 'vi liker svart kaffe',
'opinions': [{'Source': [['vi'], ['0:2']],
'Target': [['svart', 'kaffe'], ['9:14', '15:20']],
'Polar_expression': [['liker'], ['3:8']],
'Polarity': 'Positive'}]}]
```
'''
60
+
61
+
62
+
63
# Assemble the demo page: an input and an output textbox, a submit button
# wired to predict(), and the markdown description.
# NOTE(review): block nesting below is reconstructed from a flattened view of
# the file — confirm against the original indentation.
with gr.Blocks() as demo:
    with gr.Row(equal_height=False) as row:
        text_input = gr.Textbox(label="input")
        text_output = gr.Textbox(label="output")
    with gr.Row(scale=4) as row:
        text_button = gr.Button("submit").style(full_width=True)

    # Clicking the button sends the input textbox content to predict() and
    # writes its return value into the output textbox.
    text_button.click(fn=predict, inputs=text_input, outputs=text_output)

    gr.Markdown(markdown_text)


# Runs unconditionally when this module is executed or imported.
demo.launch()
model_wrapper.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tempfile
4
+ import sys
5
+ import datetime
6
+ import re
7
+ import string
8
+ sys.path.append('mtool')
9
+
10
+ import torch
11
+
12
+ from model.model import Model
13
+ from data.dataset import Dataset
14
+ from config.params import Params
15
+ from utility.initialize import initialize
16
+ from data.batch import Batch
17
+ from mtool.main import main as mtool_main
18
+
19
+
20
+ from tqdm import tqdm
21
+
22
class PredictionModel:
    """Inference wrapper around a trained sentiment-graph model checkpoint.

    The constructor restores the parameters and weights stored in a
    checkpoint; :meth:`predict` then turns raw texts into the JSON structure
    written by ``mtool`` for the ``norec`` framework.
    """

    # Matches any single ASCII punctuation character; compiled once because
    # clean_texts() applies it to every input text.
    _PUNCTUATION_RE = re.compile(f'([{re.escape(string.punctuation)}])')
    _MULTI_SPACE_RE = re.compile(r' +')

    def __init__(self, checkpoint_path=os.path.join('models', 'checkpoint.bin'), default_mrp_path=os.path.join('models', 'default.mrp'), verbose=False):
        """Load the checkpoint and build the dataset and model from it.

        Args:
            checkpoint_path: File passed to ``torch.load``. The loaded object
                must provide a ``'params'`` entry (fed to
                ``Params().load_state_dict``) and a ``'model'`` entry (the
                model state dict).
            default_mrp_path: Path assigned to the training, validation and
                test data arguments before the ``Dataset`` is constructed.
            verbose: When true, batch iteration during prediction is wrapped
                in a ``tqdm`` progress bar.
        """
        self.verbose = verbose
        # NOTE: torch.load unpickles the file, so only load checkpoints from
        # a trusted source. Tensors are mapped to CPU first; the model is
        # moved to self.device afterwards.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.args = Params().load_state_dict(self.checkpoint['params'])
        self.args.log_wandb = False
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.args.training_data = default_mrp_path
        self.args.validation_data = default_mrp_path
        self.args.test_data = default_mrp_path
        self.args.only_train = False
        self.args.encoder = os.path.join('models', 'encoder')
        initialize(self.args, init_wandb=False)
        self.dataset = Dataset(self.args, verbose=False)
        self.model = Model(self.dataset, self.args).to(self.device)
        self.model.load_state_dict(self.checkpoint["model"])
        self.model.eval()

    def _mrp_to_text(self, mrp_list, graph_mode='labeled-edge'):
        """Convert MRP graph dicts to the ``norec`` JSON output via mtool.

        Args:
            mrp_list: Iterable of JSON-serialisable graph dicts; each is
                written as one line of a temporary MRP file.
            graph_mode: ``'labeled-edge'`` or ``'node-centric'``; the latter
                adds the ``--node_centric`` flag to the mtool invocation.

        Returns:
            The object parsed from the JSON file that mtool writes.

        Raises:
            ValueError: If ``graph_mode`` is not one of the supported values.
        """
        framework = 'norec'
        # Validate before creating any temporary file so nothing is left
        # behind on a bad argument.
        if graph_mode == 'labeled-edge':
            mode_flags = []
        elif graph_mode == 'node-centric':
            mode_flags = ['--node_centric']
        else:
            raise ValueError(f'Unknown graph mode: {graph_mode}')

        # delete=False: the files are closed here and re-opened by name
        # (by mtool and by the open() below), then removed explicitly.
        with tempfile.NamedTemporaryFile(delete=False, mode='w') as output_text_file:
            output_text_filename = output_text_file.name
        try:
            with tempfile.NamedTemporaryFile(delete=False, mode='w') as mrp_file:
                mrp_filename = mrp_file.name
                mrp_file.write('\n'.join(json.dumps(entry) for entry in mrp_list))
            try:
                mtool_main([
                    *mode_flags,
                    '--strings',
                    '--ids',
                    '--read', 'mrp',
                    '--write', framework,
                    mrp_filename, output_text_filename
                ])
                with open(output_text_filename) as f:
                    return json.load(f)
            finally:
                os.unlink(mrp_filename)
        finally:
            os.unlink(output_text_filename)

    def clean_texts(self, texts):
        """Surround punctuation with spaces and collapse repeated spaces.

        Args:
            texts: Iterable of strings.

        Returns:
            A new list of strings. Every ASCII punctuation character is
            padded with one space on each side, then runs of spaces are
            reduced to a single space (so text ending in punctuation ends
            with a trailing space).
        """
        padded = (self._PUNCTUATION_RE.sub(r' \1 ', t) for t in texts)
        return [self._MULTI_SPACE_RE.sub(' ', t) for t in padded]

    def _predict_to_mrp(self, texts, graph_mode='labeled-edge'):
        """Run the model on ``texts`` and return MRP-style graph dicts.

        Args:
            texts: Iterable of raw strings; they are normalised with
                :meth:`clean_texts` first.
            graph_mode: Accepted for symmetry with :meth:`predict`; not used
                by this method.

        Returns:
            Dict keyed by the sentence index as a string (``'0'``, ``'1'``,
            ...). Each value starts with ``input``, ``id``, ``time``,
            ``framework``, ``language`` and empty ``nodes``/``edges``/``tops``
            and is then overwritten with whatever the model predicts for that
            sentence id.
        """
        texts = self.clean_texts(texts)
        framework, language = self.args.framework, self.args.language
        data = self.dataset.load_sentences(texts, self.args)
        date_str = datetime.datetime.now().date().isoformat()
        res_sentences = {
            str(i): {
                'input': sentence,
                'id': str(i),
                'time': date_str,
                'framework': framework,
                'language': language,
                'nodes': [],
                'edges': [],
                'tops': [],
            }
            for i, sentence in enumerate(texts)
        }
        for batch in (tqdm(data) if self.verbose else data):
            with torch.no_grad():
                predictions = self.model(Batch.to(batch, self.device), inference=True)
            for prediction in predictions:
                # The prediction's own 'id' selects which sentence it fills.
                res_sentences[prediction['id']].update(prediction.items())
        return res_sentences

    def predict(self, text_list, graph_mode='labeled-edge', language='no'):
        """Predict sentiment structures for a list of texts.

        Args:
            text_list: Texts to analyse.
            graph_mode: ``'labeled-edge'`` or ``'node-centric'``; forwarded
                to the MRP conversion step.
            language: Currently unused; the language written into the graphs
                comes from the checkpoint parameters.

        Returns:
            The structure parsed from mtool's JSON output, or an empty list
            when ``text_list`` is an empty sequence.
        """
        if isinstance(text_list, (list, tuple)) and not text_list:
            # Nothing to predict: skip the dataset/model/mtool round trip.
            return []
        mrp_predictions = self._predict_to_mrp(text_list, graph_mode)
        return self._mrp_to_text(mrp_predictions.values(), graph_mode)