KawshikManikantan committed on
Commit
98e2ea5
1 Parent(s): 9da50ee

upload_trial

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +14 -0
  3. __init__.py +0 -0
  4. app.py +119 -0
  5. conf/config.yaml +37 -0
  6. conf/datasets/aft.yaml +5 -0
  7. conf/datasets/aft_increase.yaml +5 -0
  8. conf/datasets/aft_ind.yaml +5 -0
  9. conf/datasets/all.yaml +11 -0
  10. conf/datasets/animal.yaml +5 -0
  11. conf/datasets/avengers.yaml +5 -0
  12. conf/datasets/fantasy.yaml +7 -0
  13. conf/datasets/joint_lf.yaml +3 -0
  14. conf/datasets/litbank.yaml +8 -0
  15. conf/datasets/movie.yaml +5 -0
  16. conf/datasets/movie_cased.yaml +5 -0
  17. conf/datasets/ontonotes.yaml +12 -0
  18. conf/datasets/preco.yaml +6 -0
  19. conf/experiment/eval_all.yaml +13 -0
  20. conf/experiment/lf_coref_id.yaml +26 -0
  21. conf/experiment/lf_eval.yaml +23 -0
  22. conf/experiment/lf_extment.yaml +31 -0
  23. conf/experiment/lf_hybrid.yaml +25 -0
  24. conf/experiment/lf_static.yaml +25 -0
  25. conf/experiment/litbank.yaml +21 -0
  26. conf/experiment/onto_pseudo_hybrid.yaml +29 -0
  27. conf/experiment/onto_pseudo_static.yaml +29 -0
  28. conf/experiment/ontonotes.yaml +17 -0
  29. conf/experiment/ontonotes_pseudo.yaml +27 -0
  30. conf/infra/local.yaml +8 -0
  31. conf/infra/slurm.yaml +10 -0
  32. conf/model/doc_encoder/transformer/longformer_large.yaml +5 -0
  33. conf/model/doc_encoder/transformer_encoder.yaml +10 -0
  34. conf/model/memory/mem_type/unbounded.yaml +3 -0
  35. conf/model/memory/memory.yaml +15 -0
  36. conf/model/model.yaml +24 -0
  37. conf/optimizer/adam.yaml +4 -0
  38. conf/trainer/train.yaml +13 -0
  39. configs.py +4 -0
  40. coref_utils/__init__.py +0 -0
  41. coref_utils/conll.py +126 -0
  42. coref_utils/metrics.py +198 -0
  43. coref_utils/utils.py +43 -0
  44. data_utils/__init__.py +0 -0
  45. data_utils/tensorize_dataset.py +76 -0
  46. data_utils/utils.py +95 -0
  47. error_analysis/__init__.py +0 -0
  48. error_analysis/missing_clusters.py +99 -0
  49. error_analysis/singleton_analysis.py +120 -0
  50. experiment.py +1052 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+models/met_joint_f78b0fa9c1d7718b9ed703ddcf621ec9_lf_sd_train_gen_4/ filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
+models/
+models_orig/
+baseline_src/wandb
+data/raw_data
+**/wandb/
+**/trash/
+**/.env
+**/__pycache__/
+**/.hydra/
+**/*result*.jsonl
+**/*nohup.out**
+**/extras/
+models_7_6_24/**
+results_old/**
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,119 @@
+import time
+import spacy
+import json
+
+import gradio as gr
+from spacy.tokens import Doc, Span
+from spacy import displacy
+import matplotlib.pyplot as plt
+from matplotlib.colors import to_hex
+
+from inference.model_inference import Inference
+from configs import *
+
+
+def get_MEIRa_clusters(doc_name, text, model_type):
+    model_str = MODELS[model_type]
+    model = Inference(model_str)
+
+    output_dict = model.perform_coreference(text, doc_name)
+    return output_dict
+
+
+def coref_visualizer(doc_name, text, model_type):
+    coref_output = get_MEIRa_clusters(doc_name, text, model_type)
+
+    tokens = coref_output["tokenized_doc"]
+    clusters = coref_output["clusters"]
+    labels = coref_output["representative_names"]
+
+    ## Build a categorical color palette, one color per entity label
+    color_palette = {
+        label: to_hex(plt.cm.get_cmap("tab20", len(labels))(i))
+        for i, label in enumerate(labels)
+    }
+
+    nlp = spacy.blank("en")
+    doc = Doc(nlp.vocab, words=tokens)
+
+    print("Tokens:", tokens, flush=True)
+    # print("Doc:", doc, flush=True)
+    print(color_palette)
+
+    # The last cluster holds the non-major mentions, so it is skipped here
+    spans = []
+    for cluster_ind, cluster in enumerate(clusters[:-1]):
+        label = labels[cluster_ind]
+        for (start, end), mention in cluster:
+            span = Span(doc, start, end + 1, label=label)
+            spans.append(span)
+
+    doc.spans["coref_spans"] = spans
+
+    print("Rendering the visualization...")
+
+    # color_map = {label: color_palette[i] for i, label in enumerate(labels)}
+    # Generate the HTML output
+    html = displacy.render(
+        doc,
+        style="span",
+        options={
+            "spans_key": "coref_spans",
+            "colors": color_palette,
+        },
+        jupyter=False,
+    )
+
+    ## Create a hash based on time and doc_name (assumes gradio_outputs/ exists)
+    time_hash = hash(str(time.time()) + doc_name)
+    html_file = f"gradio_outputs/output_{time_hash}.html"
+    json_file = f"gradio_outputs/output_{time_hash}.json"
+
+    with open(html_file, "w") as f:
+        f.write(html)
+
+    with open(json_file, "w") as f:
+        json.dump(coref_output, f)
+
+    return (
+        html_file,
+        json_file,
+        gr.DownloadButton(value=html_file, visible=True),
+        gr.DownloadButton(value=json_file, visible=True),
+    )
+
+
+def download_html():
+    return gr.DownloadButton(visible=False)
+
+
+def download_json():
+    return gr.DownloadButton(visible=False)
+
+
+options = ["static", "hybrid"]
+
+with gr.Blocks() as demo:
+    html_file = gr.File(visible=False)
+    json_file = gr.File(visible=False)
+    html_button = gr.DownloadButton("Download HTML", visible=False)
+    json_button = gr.DownloadButton("Download JSON", visible=False)
+
+    html_button.click()
+    json_button.click()
+    iface = gr.Interface(
+        fn=coref_visualizer,
+        inputs=[
+            gr.Textbox(lines=1, placeholder="Enter document name:"),
+            gr.Textbox(lines=100, placeholder="Enter text for coreference resolution:"),
+            gr.Radio(choices=options, label="Select an Option"),
+        ],
+        outputs=[
+            html_file,
+            json_file,
+            html_button,
+            json_button,
+        ],
+        title="Coreference Resolution Visualizer",
+    )
+
+demo.launch(debug=True)
conf/config.yaml ADDED
@@ -0,0 +1,37 @@
+metrics: ['MUC', 'Bcub', 'CEAFE']
+keep_singletons: True
+seed: 45
+train: True
+use_wandb: True
+desc: "Major Entity Tracking"
+override_encoder: False
+log_vals: False
+# Useful for testing models with a different memory architecture than the one trained on
+override_memory: False
+log_dir_add: ""
+device: "cuda:0"
+key: ""
+
+defaults:
+  - _self_
+  - datasets: litbank
+  - model: model
+  - optimizer: adam
+  - trainer: train
+  - infra: local
+  - experiment: debug
+
+paths:
+  resource_dir: "../data/"
+  base_data_dir: ${paths.resource_dir}/raw_data
+  conll_scorer: ${paths.resource_dir}/reference-coreference-scorers/scorer.pl
+  base_model_dir: ${infra.work_dir}/../models  ## remove /../
+  model_dir: null
+  best_model_dir: null
+  model_filename: 'model.pth'
+  model_name: null
+  model_name_prefix: 'met_'
+  model_path: null
+  best_model_path: null
+  doc_encoder_dirname: 'doc_encoder'
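For orientation (not part of the commit): a minimal sketch of composing this Hydra config tree with the compose API. The override choice is illustrative, and the sketch assumes the full conf/ tree from the repository (including the default experiment file, which falls outside this 50-file view) and that it runs from the repository root.

    from hydra import compose, initialize

    # A minimal sketch, assuming conf/ is resolvable relative to this script.
    with initialize(version_base=None, config_path="conf"):
        cfg = compose(config_name="config", overrides=["experiment=lf_static"])
        print(cfg.model.memory.type)  # "static", set by conf/experiment/lf_static.yaml
        print(cfg.trainer.max_evals)  # 25, overriding the 20 in conf/trainer/train.yaml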
conf/datasets/aft.yaml ADDED
@@ -0,0 +1,5 @@
+aft:
+  name: "aft"
+  targeted_eval: False
+  num_test_docs: 3
+  has_conll: True
conf/datasets/aft_increase.yaml ADDED
@@ -0,0 +1,5 @@
+aft_increase:
+  name: "aft_increase"
+  targeted_eval: False
+  num_test_docs: 23
+  has_conll: False
conf/datasets/aft_ind.yaml ADDED
@@ -0,0 +1,5 @@
+aft_ind:
+  name: "aft_ind"
+  targeted_eval: False
+  num_test_docs: 23
+  has_conll: False
conf/datasets/all.yaml ADDED
@@ -0,0 +1,11 @@
+defaults:
+  - litbank
+  - fantasy
+  - aft
+  - aft_increase
+  - aft_ind
+  - animal
+  - pride
+  - movie
+  - movie_cased
+  - avengers
conf/datasets/animal.yaml ADDED
@@ -0,0 +1,5 @@
+animal:
+  name: "animal"
+  targeted_eval: False
+  num_test_docs: 3
+  has_conll: True
conf/datasets/avengers.yaml ADDED
@@ -0,0 +1,5 @@
+avengers:
+  name: "avengers"
+  targeted_eval: False
+  num_test_docs: 1
+  has_conll: True
conf/datasets/fantasy.yaml ADDED
@@ -0,0 +1,7 @@
+fantasy:
+  name: "fantasy"
+  targeted_eval: False
+  num_train_docs: 171
+  num_dev_docs: 20
+  num_test_docs: 20
+  has_conll: True
conf/datasets/joint_lf.yaml ADDED
@@ -0,0 +1,3 @@
+defaults:
+  - litbank
+  - fantasy
conf/datasets/litbank.yaml ADDED
@@ -0,0 +1,8 @@
+litbank:
+  name: "LitBank"
+  cross_val_split: 0
+  targeted_eval: False
+  num_train_docs: 80
+  num_dev_docs: 10
+  num_test_docs: 10
+  has_conll: True
conf/datasets/movie.yaml ADDED
@@ -0,0 +1,5 @@
+movie:
+  name: "movie"
+  targeted_eval: False
+  num_test_docs: 6
+  has_conll: False
conf/datasets/movie_cased.yaml ADDED
@@ -0,0 +1,5 @@
+movie_cased:
+  name: "movie_cased"
+  targeted_eval: False
+  num_test_docs: 6
+  has_conll: False
conf/datasets/ontonotes.yaml ADDED
@@ -0,0 +1,12 @@
+ontonotes:
+  name: "OntoNotes"
+  targeted_eval: False
+  num_train_docs: 2802
+  num_dev_docs: 343
+  num_test_docs: 348
+  has_conll: True
+  # OntoNotes-specific attributes
+  # use_genre_feature: False  # Whether to use document genre as a feature or not
+  # default_genre: "nw"
+  # genres: [ "bc", "bn", "mz", "nw", "pt", "tc", "wb" ]
+  singleton_file: null  # File path with pseudo-singletons
conf/datasets/preco.yaml ADDED
@@ -0,0 +1,6 @@
+preco:
+  name: "PreCo"
+  targeted_eval: False
+  num_train_docs: 3000
+  num_dev_docs: 500
+  num_test_docs: 500
conf/experiment/eval_all.yaml ADDED
@@ -0,0 +1,13 @@
+# @package _global_
+
+# Evaluate all models
+
+
+defaults:
+  - override /datasets: all
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
+
+train: False
conf/experiment/lf_coref_id.yaml ADDED
@@ -0,0 +1,26 @@
+# @package _global_
+
+# This configuration trains a joint model on LitBank and Fantasy (joint_lf)
+# with the pseudo-distance feature disabled.
+
+# Model name in CRAC 2021: longdoc^S Joint
+
+
+defaults:
+  - override /datasets: joint_lf
+  - override /trainer: train.yaml
+  - override /model: model.yaml
+
+trainer:
+  log_frequency: 500
+  max_evals: 20
+  eval_per_k_steps: null
+  patience: 10
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
+  memory:
+    pseudo_dist: False
+
+log_dir_add: "coref_id"
conf/experiment/lf_eval.yaml ADDED
@@ -0,0 +1,23 @@
+# @package _global_
+
+# This configuration trains a joint model on LitBank and Fantasy (joint_lf).
+
+# Model name in CRAC 2021: longdoc^S Joint
+
+
+defaults:
+  - override /datasets: joint_lf
+  - override /trainer: train.yaml
+  - override /model: model.yaml
+
+trainer:
+  log_frequency: 500
+  max_evals: 25
+  eval_per_k_steps: null
+  patience: 10
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
conf/experiment/lf_extment.yaml ADDED
@@ -0,0 +1,31 @@
+# @package _global_
+
+# This configuration trains a joint model on LitBank and Fantasy (joint_lf)
+# using externally detected mentions (ext_ment).
+
+# Model name in CRAC 2021: longdoc^S Joint
+
+
+defaults:
+  - override /datasets: joint_lf
+  - override /trainer: train.yaml
+  - override /model: model.yaml
+
+trainer:
+  log_frequency: 500
+  max_evals: 20
+  eval_per_k_steps: null
+  patience: 10
+
+datasets:
+  litbank:
+    external_md_file: "litbank/longformer_speaker/0/mentions_ment_model_litbank_eval.jsonl"
+  fantasy:
+    external_md_file: "fantasy/longformer_speaker/mentions_ment_model_fantasy_eval.jsonl"
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
+  mention_params:
+    ext_ment: True
conf/experiment/lf_hybrid.yaml ADDED
@@ -0,0 +1,25 @@
+# @package _global_
+
+# This configuration trains a joint model on LitBank and Fantasy (joint_lf)
+# with the hybrid memory.
+
+# Model name in CRAC 2021: longdoc^S Joint
+
+
+defaults:
+  - override /datasets: joint_lf
+  - override /trainer: train.yaml
+  - override /model: model.yaml
+
+trainer:
+  log_frequency: 500
+  max_evals: 25
+  eval_per_k_steps: null
+  patience: 10
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
+  memory:
+    type: hybrid
conf/experiment/lf_static.yaml ADDED
@@ -0,0 +1,25 @@
+# @package _global_
+
+# This configuration trains a joint model on LitBank and Fantasy (joint_lf)
+# with the static memory.
+
+# Model name in CRAC 2021: longdoc^S Joint + PS 30K
+
+
+defaults:
+  - override /datasets: joint_lf
+  - override /trainer: train.yaml
+  - override /model: model.yaml
+
+trainer:
+  log_frequency: 500
+  max_evals: 25
+  eval_per_k_steps: null
+  patience: 10
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
+  memory:
+    type: "static"
conf/experiment/litbank.yaml ADDED
@@ -0,0 +1,21 @@
+# @package _global_
+
+# Vanilla LitBank configuration
+
+# Model name in CRAC 2021: longdoc LB_0
+
+
+defaults:
+  - override /datasets: litbank
+  - override /trainer: train.yaml
+
+trainer:
+  log_frequency: 10
+  max_evals: 40
+  patience: 20
+  eval_per_k_steps: null
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
conf/experiment/onto_pseudo_hybrid.yaml ADDED
@@ -0,0 +1,29 @@
+# @package _global_
+
+# This configuration trains on OntoNotes with pseudo-singletons
+# and the hybrid memory.
+
+# Model name in CRAC 2021: longdoc^S Joint
+
+
+defaults:
+  - override /datasets: ontonotes
+  - override /trainer: train.yaml
+  - override /model: model.yaml
+
+trainer:
+  log_frequency: 250
+  max_evals: 20
+  eval_per_k_steps: null
+  patience: 10
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
+  memory:
+    type: hybrid
+
+datasets:
+  ontonotes:
+    singleton_file: ontonotes/ment_singletons_longformer_speaker/30.jsonlines
conf/experiment/onto_pseudo_static.yaml ADDED
@@ -0,0 +1,29 @@
+# @package _global_
+
+# This configuration trains on OntoNotes with pseudo-singletons
+# and the static memory.
+
+# Model name in CRAC 2021: longdoc^S Joint + PS 30K
+
+
+defaults:
+  - override /datasets: ontonotes
+  - override /trainer: train.yaml
+  - override /model: model.yaml
+
+trainer:
+  log_frequency: 500
+  max_evals: 20
+  eval_per_k_steps: null
+  patience: 10
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
+  memory:
+    type: "static"
+
+datasets:
+  ontonotes:
+    singleton_file: ontonotes/ment_singletons_longformer_speaker/30.jsonlines
conf/experiment/ontonotes.yaml ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+
+# Vanilla OntoNotes configuration which doesn't assume any upstream features
+# such as speaker and document genre
+
+# Model name in CRAC 2021: longdoc ON
+
+defaults:
+  - override /datasets: ontonotes
+  - override /trainer: train.yaml
+
+trainer:
+  log_frequency: 250
+  patience: 10
+  eval_per_k_steps: 5000
conf/experiment/ontonotes_pseudo.yaml ADDED
@@ -0,0 +1,27 @@
+# @package _global_
+
+# This configuration trains on OntoNotes using the speaker information and pseudo-singletons.
+# This is the best OntoNotes configuration in our CRAC 2021 work.
+# Note that this configuration doesn't assume other upstream features such as document genre.
+
+# Model name in CRAC 2021: longdoc^S ON + PS 60K
+
+
+defaults:
+  - override /datasets: ontonotes
+  - override /trainer: train.yaml
+  - override /model: model.yaml
+
+trainer:
+  log_frequency: 250
+  max_evals: 20
+  patience: 10
+  eval_per_k_steps: null
+
+model:
+  doc_encoder:
+    add_speaker_tokens: True
+
+datasets:
+  ontonotes:
+    singleton_file: ontonotes/ment_singletons_longformer_speaker/30.jsonlines
conf/infra/local.yaml ADDED
@@ -0,0 +1,8 @@
+is_local: True
+work_dir: "./"
+
+
+#hydra:
+#  run:
+#    dir:
+#      "~/Research/fast-coref/models"
conf/infra/slurm.yaml ADDED
@@ -0,0 +1,10 @@
+is_local: False
+job_time: 14280
+job_id: null
+work_dir: "./"
+
+
+#hydra:
+#  run:
+#    dir:
+#      /share/data/speech/shtoshni/research/fast-coref/slurm_scripts/thesis/${job_id}.log
conf/model/doc_encoder/transformer/longformer_large.yaml ADDED
@@ -0,0 +1,5 @@
+name: 'longformer'
+model_size: 'large'
+model_str: 'allenai/longformer-large-4096'
+max_encoder_segment_len: 4096
+max_segment_len: 4096
conf/model/doc_encoder/transformer_encoder.yaml ADDED
@@ -0,0 +1,10 @@
+defaults:
+  - transformer: longformer_large
+
+chunking: independent
+finetune: true  # Add logic of finetuning depending on the training logic
+add_speaker_tokens: true  # Change this value depending on the dataset
+speaker_start: '[SPEAKER_START]'
+speaker_end: '[SPEAKER_END]'
conf/model/memory/mem_type/unbounded.yaml ADDED
@@ -0,0 +1,3 @@
+name: unbounded
+max_ents: null
+eval_max_ents: null
conf/model/memory/memory.yaml ADDED
@@ -0,0 +1,15 @@
+defaults:
+  - mem_type: unbounded
+
+emb_size: 20
+mlp_size: 3000
+mlp_depth: 1
+sim_func: hadamard
+entity_rep: wt_avg
+num_feats: 2  ## Change this to remove position information.
+thresh: 0.0
+rep_pos: "learned"
+pseudo_dist: True
+num_embeds: 10
+type: "dyn"
+batch_size: 64
conf/model/model.yaml ADDED
@@ -0,0 +1,24 @@
+defaults:
+  - doc_encoder: transformer_encoder
+  - memory: memory
+
+mention_params:
+  max_span_width: 20
+  ment_emb: attn
+  use_gold_ments: false
+  ext_ment: false
+  use_topk: false
+  top_span_ratio: 0.4
+  emb_size: 20
+  mlp_size: 3000
+  mlp_depth: 1
+  ment_emb_to_size_factor:
+    attn: 3
+    endpoint: 2
+    max: 1
+  ignore_non_gold: True
+
+metadata_params:
+  use_genre_feature: False
+  default_genre: "nw"
+  genres: [ "bc", "bn", "mz", "nw", "pt", "tc", "wb" ]
conf/optimizer/adam.yaml ADDED
@@ -0,0 +1,4 @@
+init_lr: 3e-4
+fine_tune_lr: 1e-5
+max_gradient_norm: 1.0
+lr_decay: linear
conf/trainer/train.yaml ADDED
@@ -0,0 +1,13 @@
+dropout_rate: 0.3
+label_smoothing_wt: 0.1
+ment_loss_mode: 'all'
+normalize_loss: False
+ment_loss_incl: True
+max_evals: 20
+to_save_model: False
+log_frequency: 500
+patience: 10
+eval_per_k_steps: null
+num_training_steps: null
+max_training_segments: 1
+generalise: True
configs.py ADDED
@@ -0,0 +1,4 @@
+MODELS = {
+    "static": "models/met_joint_efbc65248a6aedce066a04a9f4f40084_lf_s_train_gen_5",
+    "hybrid": "models/met_joint_f78b0fa9c1d7718b9ed703ddcf621ec9_lf_sd_train_gen_4",
+}
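A hedged usage sketch (not part of the commit) tying MODELS to the Inference wrapper used in app.py; inference/model_inference.py falls outside this 50-file view, so the call signature and output keys are taken from app.py above, and the input sentence is illustrative.

    from inference.model_inference import Inference
    from configs import MODELS

    # Resolve a checkpoint directory and run coreference on raw text,
    # mirroring get_MEIRa_clusters() in app.py.
    model = Inference(MODELS["hybrid"])
    output = model.perform_coreference("Alice met Bob. She greeted him.", "demo_doc")
    print(output["clusters"])
    print(output["representative_names"])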
coref_utils/__init__.py ADDED
File without changes
coref_utils/conll.py ADDED
@@ -0,0 +1,126 @@
+import re
+import subprocess
+import operator
+import collections
+
+BEGIN_DOCUMENT_REGEX = re.compile(r"#begin document \((.*)\); part (\d+)")
+COREF_RESULTS_REGEX = re.compile(
+    r".*Coreference: Recall: \([0-9.]+ / [0-9.]+\) ([0-9.]+)%\tPrecision: \([0-9.]+ / [0-9.]+\) "
+    r"([0-9.]+)%\tF1: ([0-9.]+)%.*",
+    re.DOTALL,
+)
+
+
+def get_doc_key(doc_id, part):
+    return "{}_{}".format(doc_id, int(part))
+
+
+def output_conll(input_file, output_file, predictions, subtoken_map):
+    prediction_map = {}
+    for doc_key, clusters in predictions.items():
+        start_map = collections.defaultdict(list)
+        end_map = collections.defaultdict(list)
+        word_map = collections.defaultdict(list)
+        for cluster_id, mentions in enumerate(clusters):
+            for start, end in mentions:
+                start, end = subtoken_map[doc_key][start], subtoken_map[doc_key][end]
+                if start == end:
+                    word_map[start].append(cluster_id)
+                else:
+                    start_map[start].append((cluster_id, end))
+                    end_map[end].append((cluster_id, start))
+        for k, v in start_map.items():
+            start_map[k] = [
+                cluster_id
+                for cluster_id, end in sorted(
+                    v, key=operator.itemgetter(1), reverse=True
+                )
+            ]
+        for k, v in end_map.items():
+            end_map[k] = [
+                cluster_id
+                for cluster_id, start in sorted(
+                    v, key=operator.itemgetter(1), reverse=True
+                )
+            ]
+        prediction_map[doc_key] = (start_map, end_map, word_map)
+
+    word_index = 0
+    for line in input_file.readlines():
+        row = line.split()
+        if len(row) == 0:
+            output_file.write("\n")
+        elif row[0].startswith("#"):
+            begin_match = re.match(BEGIN_DOCUMENT_REGEX, line)
+            if begin_match:
+                doc_key = get_doc_key(begin_match.group(1), begin_match.group(2))
+                start_map, end_map, word_map = prediction_map[doc_key]
+                word_index = 0
+            output_file.write(line)
+            # output_file.write("\n")
+        else:
+            assert get_doc_key(row[0], row[1]) == doc_key
+            coref_list = []
+            if word_index in end_map:
+                for cluster_id in end_map[word_index]:
+                    coref_list.append("{})".format(cluster_id))
+            if word_index in word_map:
+                for cluster_id in word_map[word_index]:
+                    coref_list.append("({})".format(cluster_id))
+            if word_index in start_map:
+                for cluster_id in start_map[word_index]:
+                    coref_list.append("({}".format(cluster_id))
+
+            if len(coref_list) == 0:
+                row[-1] = "-"
+            else:
+                row[-1] = "|".join(coref_list)
+
+            output_file.write(" ".join(row))
+            output_file.write("\n")
+            word_index += 1
+
+
+def official_conll_eval(
+    conll_scorer, gold_path, predicted_path, metric, official_stdout=False
+):
+    cmd = [conll_scorer, metric, gold_path, predicted_path, "none"]
+    # shell=True was dropped: passing an argument list together with shell=True
+    # would execute only the scorer binary and silently ignore its arguments.
+    process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
+    stdout, stderr = process.communicate()
+    process.wait()
+
+    stdout = stdout.decode("utf-8")
+    if stderr is not None:
+        print(stderr)
+
+    if official_stdout:
+        print("Official result for {}".format(metric))
+        print(stdout)
+
+    coref_results_match = re.match(COREF_RESULTS_REGEX, stdout)
+    recall = float(coref_results_match.group(1))
+    precision = float(coref_results_match.group(2))
+    f1 = float(coref_results_match.group(3))
+    return {"r": recall, "p": precision, "f": f1}
+
+
+def evaluate_conll(
+    conll_scorer,
+    gold_path,
+    predictions,
+    subtoken_maps,
+    prediction_path,
+    all_metrics=False,
+    official_stdout=False,
+):
+    with open(prediction_path, "w") as prediction_file:
+        with open(gold_path, "r") as gold_file:
+            output_conll(gold_file, prediction_file, predictions, subtoken_maps)
+
+    result = {
+        metric: official_conll_eval(
+            conll_scorer, gold_file.name, prediction_file.name, metric, official_stdout
+        )
+        for metric in ("muc", "bcub", "ceafe")
+    }
+    return result
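A quick sanity check (not part of the commit) of COREF_RESULTS_REGEX against a scorer-style summary line; the sample string is constructed to match the expected format and the numbers are illustrative.

    import re
    from coref_utils.conll import COREF_RESULTS_REGEX

    sample = (
        "====== TOTALS =======\n"
        "Coreference: Recall: (81.2 / 100.0) 81.2%\tPrecision: (79.5 / 100.0) 79.5%"
        "\tF1: 80.3%\n"
    )
    match = re.match(COREF_RESULTS_REGEX, sample)
    print(match.group(1), match.group(2), match.group(3))  # 81.2 79.5 80.3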
coref_utils/metrics.py ADDED
@@ -0,0 +1,198 @@
+import numpy as np
+from collections import Counter
+from scipy.optimize import linear_sum_assignment
+
+
+def f1(p_num, p_den, r_num, r_den, beta=1):
+    p = 0 if p_den == 0 else p_num / float(p_den)
+    r = 0 if r_den == 0 else r_num / float(r_den)
+    return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)
+
+
+class CorefEvaluator(object):
+    def __init__(self):
+        self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]
+
+    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
+        for e in self.evaluators:
+            e.update(predicted, gold, mention_to_predicted, mention_to_gold)
+
+    def get_f1(self):
+        return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)
+
+    def get_recall(self):
+        return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)
+
+    def get_precision(self):
+        return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)
+
+    def get_prf(self):
+        return self.get_precision(), self.get_recall(), self.get_f1()
+
+
+class F1Evaluator(object):
+    def __init__(self):
+        self.f1_macro_sum = 0.0
+        self.f1_micro_sum = 0.0
+        self.macro_support = 0
+        self.micro_support = 0
+
+    def update(self, predicted, gold):
+        if gold:
+            for cluster_ind, cluster in enumerate(gold):
+                predicted_set = set(predicted[cluster_ind])
+                correct = set(cluster).intersection(predicted_set)
+                num_correct = len(correct)
+                num_predicted = len(predicted_set)
+                num_gt = len(cluster)
+                precision = num_correct / num_predicted if num_predicted > 0 else 0
+                recall = num_correct / num_gt if num_gt > 0 else 0
+                f1_score = (
+                    2 * precision * recall / (precision + recall)
+                    if precision + recall > 0
+                    else 0
+                )
+                support_entity_micro = num_gt
+                support_entity_macro = 1
+                self.f1_macro_sum += f1_score * support_entity_macro
+                self.f1_micro_sum += f1_score * support_entity_micro
+                self.macro_support += support_entity_macro
+                self.micro_support += support_entity_micro
+
+    def get_numbers(self):
+        f1_macro = (
+            (self.f1_macro_sum / self.macro_support) * 100
+            if self.macro_support > 0
+            else 0
+        )
+        f1_micro = (
+            (self.f1_micro_sum / self.micro_support) * 100
+            if self.micro_support > 0
+            else 0
+        )
+        return f1_macro, f1_micro
+
+
+class Evaluator(object):
+    def __init__(self, metric, beta=1):
+        self.p_num = 0
+        self.p_den = 0
+        self.r_num = 0
+        self.r_den = 0
+        self.metric = metric
+        self.beta = beta
+
+    def update(self, predicted, gold, mention_to_predicted, mention_to_gold):
+        if self.metric == ceafe:
+            pn, pd, rn, rd = self.metric(predicted, gold)
+        else:
+            pn, pd = self.metric(predicted, mention_to_gold)
+            rn, rd = self.metric(gold, mention_to_predicted)
+        self.p_num += pn
+        self.p_den += pd
+        self.r_num += rn
+        self.r_den += rd
+
+    def get_f1(self):
+        return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
+
+    def get_recall(self):
+        return 0 if self.r_num == 0 else self.r_num / float(self.r_den)
+
+    def get_precision(self):
+        return 0 if self.p_num == 0 else self.p_num / float(self.p_den)
+
+    def get_prf(self):
+        return self.get_precision(), self.get_recall(), self.get_f1()
+
+    def get_counts(self):
+        return self.p_num, self.p_den, self.r_num, self.r_den
+
+    def get_prf_str(self):
+        perf_str = (
+            f"Recall: {self.get_recall() * 100}, Precision: {self.get_precision() * 100}, "
+            f"F-score: {self.get_f1() * 100}\n"
+        )
+
+        return perf_str
+
+
+def evaluate_documents(documents, metric, beta=1):
+    evaluator = Evaluator(metric, beta=beta)
+    for document in documents:
+        # Each document is a (predicted, gold, mention_to_predicted, mention_to_gold)
+        # tuple; the original call passed it as a single argument.
+        evaluator.update(*document)
+    return evaluator.get_precision(), evaluator.get_recall(), evaluator.get_f1()
+
+
+def b_cubed(clusters, mention_to_gold):
+    num, dem = 0, 0
+
+    for c in clusters:
+        gold_counts = Counter()
+        correct = 0
+        for m in c:
+            if m in mention_to_gold:
+                gold_counts[tuple(mention_to_gold[m])] += 1
+        for c2, count in gold_counts.items():
+            correct += count * count
+
+        num += correct / float(len(c))
+        dem += len(c)
+
+    return num, dem
+
+
+def muc(clusters, mention_to_gold):
+    tp, p = 0, 0
+    for c in clusters:
+        p += len(c) - 1
+        tp += len(c)
+        linked = set()
+        for m in c:
+            if m in mention_to_gold:
+                linked.add(mention_to_gold[m])
+            else:
+                tp -= 1
+        tp -= len(linked)
+    return tp, p
+
+
+def phi4(c1, c2):
+    return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))
+
+
+def ceafe(clusters, gold_clusters):
+    scores = np.zeros((len(gold_clusters), len(clusters)))
+    for i in range(len(gold_clusters)):
+        for j in range(len(clusters)):
+            scores[i, j] = phi4(gold_clusters[i], clusters[j])
+    matching = linear_sum_assignment(-scores)
+    matching = np.asarray(matching)
+    matching = np.transpose(matching)
+
+    similarity = sum(scores[matching[:, 0], matching[:, 1]])
+    return similarity, len(clusters), similarity, len(gold_clusters)
+
+
+def lea(clusters, mention_to_gold):
+    num, dem = 0, 0
+
+    for c in clusters:
+        if len(c) == 1:
+            continue
+
+        common_links = 0
+        all_links = len(c) * (len(c) - 1) / 2.0
+        for i, m in enumerate(c):
+            if m in mention_to_gold:
+                for m2 in c[i + 1 :]:
+                    if (
+                        m2 in mention_to_gold
+                        and mention_to_gold[m] == mention_to_gold[m2]
+                    ):
+                        common_links += 1
+
+        num += len(c) * common_links / float(all_links)
+        dem += len(c)
+
+    return num, dem
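A toy example (not part of the commit) of driving CorefEvaluator directly; mentions are (start, end) token spans, the clusters are made up, and the mention-to-cluster maps come from coref_utils/utils.py shown next.

    from coref_utils.metrics import CorefEvaluator
    from coref_utils.utils import get_mention_to_cluster

    gold = [((0, 0), (5, 5)), ((2, 3),)]  # two gold entities
    pred = [((0, 0), (5, 5), (2, 3))]     # one predicted entity
    evaluator = CorefEvaluator()
    evaluator.update(
        pred, gold, get_mention_to_cluster(pred), get_mention_to_cluster(gold)
    )
    # Averaged (precision, recall, F1) over MUC, B-cubed, and CEAF-e
    print(evaluator.get_prf())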
coref_utils/utils.py ADDED
@@ -0,0 +1,43 @@
+from typing import List, Dict, Tuple
+
+
+def filter_clusters(clusters: List, threshold: int = 1) -> List:
+    """Filter out clusters with fewer mentions than the specified threshold."""
+
+    return [
+        tuple(tuple(mention) for mention in cluster)
+        for cluster_ind, cluster in enumerate(clusters)
+        if len(cluster) >= threshold and cluster_ind != len(clusters) - 1  # The last cluster is always removed.
+    ]
+
+
+def get_mention_to_cluster(clusters: List) -> Dict:
+    """Get mention to cluster mapping."""
+
+    clusters = [tuple(tuple(mention) for mention in cluster) for cluster in clusters]
+    mention_to_cluster_dict = {}
+    for cluster in clusters:
+        for mention in cluster:
+            mention_to_cluster_dict[mention] = cluster
+    return mention_to_cluster_dict
+
+
+def get_mention_to_cluster_idx(clusters: List) -> Dict:
+    """Get mention to cluster index mapping."""
+
+    clusters = [tuple(tuple(mention) for mention in cluster) for cluster in clusters]
+    mention_to_cluster_dict = {}
+    for cluster_idx, cluster in enumerate(clusters):
+        for mention in cluster:
+            mention_to_cluster_dict[mention] = cluster_idx
+    return mention_to_cluster_dict
+
+
+def is_aligned(span1: Tuple[int, int], span2: Tuple[int, int]) -> bool:
+    """Return True if one of the spans is contained within the other."""
+
+    if span1[0] >= span2[0] and span1[1] <= span2[1]:
+        return True
+    if span2[0] >= span1[0] and span2[1] <= span1[1]:
+        return True
+    return False
data_utils/__init__.py ADDED
File without changes
data_utils/tensorize_dataset.py ADDED
@@ -0,0 +1,76 @@
+import torch
+from typing import List, Dict, Union
+from transformers import PreTrainedTokenizerFast
+from torch import Tensor
+
+
+class TensorizeDataset:
+    def __init__(
+        self, tokenizer: PreTrainedTokenizerFast, remove_singletons: bool = False
+    ) -> None:
+        self.tokenizer = tokenizer
+        self.remove_singletons = remove_singletons
+        self.device = torch.device("cpu")
+
+    def tensorize_data(
+        self, split_data: List[Dict], training: bool = False
+    ) -> List[Dict]:
+        tensorized_data = []
+        for document in split_data:
+            tensorized_data.append(
+                self.tensorize_instance_independent(document, training=training)
+            )
+
+        return tensorized_data
+
+    def process_segment(self, segment: List) -> List:
+        if self.tokenizer.sep_token_id is None:
+            # print("SentencePiece Tokenizer")
+            return [self.tokenizer.bos_token_id] + segment + [self.tokenizer.eos_token_id]
+        else:
+            # print("WordPiece Tokenizer")
+            return [self.tokenizer.cls_token_id] + segment + [self.tokenizer.sep_token_id]
+
+    def tensorize_instance_independent(
+        self, document: Dict, training: bool = False
+    ) -> Dict:
+        segments: List[List[int]] = document["sentences"]
+        clusters: List = document.get("clusters", [])
+        ext_predicted_mentions: List = document.get("ext_predicted_mentions", [])
+        sentence_map: List[int] = document["sentence_map"]
+        subtoken_map: List[int] = document["subtoken_map"]
+        representatives: List = document.get("representatives", [])
+        representative_embs: List = document.get("representative_embs", [])
+
+        tensorized_sent: List[Tensor] = [
+            torch.unsqueeze(
+                torch.tensor(self.process_segment(sent), device=self.device), dim=0
+            )
+            for sent in segments
+        ]
+
+        sent_len_list = [len(sent) for sent in segments]
+        output_dict = {
+            "tensorized_sent": tensorized_sent,
+            "sentences": segments,
+            "sent_len_list": sent_len_list,
+            "doc_key": document.get("doc_key", None),
+            "clusters": clusters,
+            "ext_predicted_mentions": ext_predicted_mentions,
+            "subtoken_map": subtoken_map,
+            "sentence_map": torch.tensor(sentence_map, device=self.device),
+            "representatives": representatives,
+            "representative_embs": representative_embs,
+        }
+
+        # Pass along other metadata
+        for key in document:
+            if key not in output_dict:
+                output_dict[key] = document[key]
+
+        if self.remove_singletons:
+            output_dict["clusters"] = [
+                cluster for cluster in output_dict["clusters"] if len(cluster) > 1
+            ]
+
+        return output_dict
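For orientation (not part of the commit): a hypothetical toy document shaped like one line of a *.met.jsonlines file, limited to the keys that tensorize_instance_independent reads above; real files carry additional metadata, and the subtoken ids are made up.

    document = {
        "doc_key": "demo_0",
        "sentences": [[1037, 2450, 2003, 2182]],  # segments of subtoken ids
        "sentence_map": [0, 0, 0, 0],             # subtoken -> sentence index
        "subtoken_map": [0, 1, 2, 3],             # subtoken -> original token index
        "clusters": [[[0, 0], [2, 2]]],           # [start, end] subtoken spans per entity
    }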
data_utils/utils.py ADDED
@@ -0,0 +1,95 @@
+import json
+from os import path
+from typing import Dict, Optional
+import jsonlines
+
+
+def get_data_file(data_dir: str, split: str, max_segment_len: int) -> Optional[str]:
+    jsonl_file = path.join(
+        data_dir, "{}.{}.met.jsonlines".format(split, max_segment_len)
+    )
+    print("File access: ", jsonl_file)
+    if path.exists(jsonl_file):
+        return jsonl_file
+    else:
+        jsonl_file = path.join(data_dir, "{}.met.jsonlines".format(split))
+        if path.exists(jsonl_file):
+            return jsonl_file
+    # No matching file for this split; callers check for None.
+    return None
+
+
+def load_dataset(
+    data_dir: str,
+    singleton_file: str = None,
+    max_segment_len: int = 2048,
+    num_train_docs: int = None,
+    num_dev_docs: int = None,
+    num_test_docs: int = None,
+    dataset_name: str = None,
+) -> Dict:
+    all_splits = []
+    for split in ["train", "dev", "test"]:
+        jsonl_file = get_data_file(data_dir, split, max_segment_len)
+        if jsonl_file is None:
+            raise ValueError(f"No relevant files at {data_dir}")
+        split_data = []
+        with open(jsonl_file) as f:
+            for line in f:
+                load_dict = json.loads(line.strip())
+                load_dict["dataset_name"] = dataset_name
+                split_data.append(load_dict)
+        all_splits.append(split_data)
+
+    train_data, dev_data, test_data = all_splits
+
+    if singleton_file is not None and path.exists(singleton_file):
+        num_singletons = 0
+        with open(singleton_file) as f:
+            singleton_data = json.loads(f.read())
+
+        for instance in train_data:
+            doc_key = instance["doc_key"]
+            if doc_key in singleton_data:
+                if len(instance["clusters"]) != 0:
+                    num_singletons += len(singleton_data[doc_key])
+                    instance["clusters"][-1].extend(
+                        [cluster[0] for cluster in singleton_data[doc_key]]
+                    )
+
+        print("Added %d singletons" % num_singletons)
+
+    return {
+        "train": train_data[:num_train_docs],
+        "dev": dev_data[:num_dev_docs],
+        "test": test_data[:num_test_docs],
+    }
+
+
+def load_eval_dataset(
+    data_dir: str, external_md_file: str, max_segment_len: int, dataset_name: str = None
+) -> Dict:
+    data_dict = {}
+    for split in ["dev", "test"]:
+        jsonl_file = get_data_file(data_dir, split, max_segment_len)
+        if jsonl_file is not None:
+            split_data = []
+            with open(jsonl_file) as f:
+                for line in f:
+                    load_dict = json.loads(line.strip())
+                    load_dict["dataset_name"] = dataset_name
+                    split_data.append(load_dict)
+
+            data_dict[split] = split_data
+
+    if external_md_file is not None and path.exists(external_md_file):
+        predicted_mentions = {}
+        with jsonlines.open(external_md_file, mode="r") as reader:
+            for line in reader:
+                predicted_mentions[line["doc_key"]] = line
+        for split in ["dev", "test"]:
+            for instance in data_dict[split]:
+                doc_key = instance["doc_key"]
+                if doc_key in predicted_mentions:
+                    instance["ext_predicted_mentions"] = sorted(
+                        predicted_mentions[doc_key]["pred_mentions"]
+                    )
+    return data_dict
error_analysis/__init__.py ADDED
File without changes
error_analysis/missing_clusters.py ADDED
@@ -0,0 +1,99 @@
+import argparse
+import os
+import logging
+import json
+import numpy as np
+from coref_utils.metrics import CorefEvaluator
+from coref_utils.utils import get_mention_to_cluster
+
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+logging.basicConfig(format="%(message)s", level=logging.INFO)
+logger = logging.getLogger()
+
+
+def process_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser()
+
+    # Add arguments to parser
+    parser.add_argument("log_file", help="Log file", type=str)
+
+    args = parser.parse_args()
+    return args
+
+
+def singleton_analysis(data):
+    """Find the largest gold cluster none of whose mentions was predicted."""
+    max_length = 0
+    max_doc_id = ""
+    max_cluster = []
+
+    for instance in data:
+
+        # get_mention_to_cluster returns only the mention -> cluster mapping, so
+        # the normalized cluster lists are built separately (the original
+        # tuple-unpacking of its return value would fail at runtime).
+        gold_clusters = [
+            tuple(tuple(mention) for mention in cluster)
+            for cluster in instance["clusters"]
+        ]
+        gold_mentions_to_cluster = get_mention_to_cluster(instance["clusters"])
+        pred_clusters = [
+            tuple(tuple(mention) for mention in cluster)
+            for cluster in instance["predicted_clusters"]
+        ]
+        pred_mentions_to_cluster = get_mention_to_cluster(
+            instance["predicted_clusters"]
+        )
+
+        for cluster in gold_clusters:
+            all_mention_unseen = True
+            for mention in cluster:
+                if mention in pred_mentions_to_cluster:
+                    all_mention_unseen = False
+                    break
+
+            if all_mention_unseen:
+                if len(cluster) > max_length:
+                    max_length = len(cluster)
+                    max_doc_id = instance["doc_key"]
+                    max_cluster = cluster
+
+    print(max_doc_id)
+    print(max_length, max_cluster)
+
+
+def reverse_analysis(data):
+    """Find the largest predicted cluster none of whose mentions is in the gold."""
+    max_length = 0
+    max_doc_id = ""
+    max_cluster = []
+
+    for instance in data:
+
+        gold_mentions_to_cluster = get_mention_to_cluster(instance["clusters"])
+        pred_clusters = [
+            tuple(tuple(mention) for mention in cluster)
+            for cluster in instance["predicted_clusters"]
+        ]
+
+        for cluster in pred_clusters:
+            all_mention_unseen = True
+            for mention in cluster:
+                if mention in gold_mentions_to_cluster:
+                    all_mention_unseen = False
+                    break
+
+            if all_mention_unseen:
+                if len(cluster) > max_length:
+                    max_length = len(cluster)
+                    max_doc_id = instance["doc_key"]
+                    max_cluster = cluster
+
+    print(max_doc_id)
+    print(max_length, max_cluster)
+
+
+def main():
+    args = process_args()
+    data = []
+    with open(args.log_file) as f:
+        for line in f:
+            data.append(json.loads(line))
+    singleton_analysis(data)
+    reverse_analysis(data)
+
+
+if __name__ == "__main__":
+    main()
error_analysis/singleton_analysis.py ADDED
@@ -0,0 +1,120 @@
+import argparse
+import os
+import logging
+import json
+import numpy as np
+from coref_utils.metrics import CorefEvaluator
+from coref_utils.utils import get_mention_to_cluster, filter_clusters
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+logging.basicConfig(format="%(message)s", level=logging.INFO)
+logger = logging.getLogger()
+
+
+def process_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser()
+
+    # Add arguments to parser
+    parser.add_argument("log_file", help="Log file", type=str)
+
+    args = parser.parse_args()
+    return args
+
+
+def singleton_analysis(data):
+    gold_singletons = 0
+    pred_singletons = 0
+
+    # singleton_evaluator = CorefEvaluator()
+    non_singleton_evaluator = CorefEvaluator()
+
+    gold_cluster_lens = []
+    pred_cluster_lens = []
+
+    overlap_sing = 0
+    total_sing = 0
+    pred_sing = 0
+
+    for instance in data:
+        # Singleton performance
+        gold_clusters = set(
+            [tuple(cluster[0]) for cluster in instance["clusters"] if len(cluster) == 1]
+        )
+        pred_clusters = set(
+            [
+                tuple(cluster[0])
+                for cluster in instance["predicted_clusters"]
+                if len(cluster) == 1
+            ]
+        )
+
+        total_sing += len(gold_clusters)
+        pred_sing += len(pred_clusters)
+        overlap_sing += len(gold_clusters.intersection(pred_clusters))
+
+        gold_singletons += len(gold_clusters)
+        pred_singletons += len(pred_clusters)
+
+        # predicted_clusters, mention_to_predicted = get_mention_to_cluster(pred_clusters, threshold=1)
+        # gold_clusters, mention_to_gold = get_mention_to_cluster(gold_clusters, threshold=1)
+        # singleton_evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)
+
+        # Non-singleton performance
+        gold_clusters = filter_clusters(instance["clusters"], threshold=2)
+        pred_clusters = filter_clusters(instance["predicted_clusters"], threshold=2)
+
+        gold_cluster_lens.extend([len(cluster) for cluster in instance["clusters"]])
+        pred_cluster_lens.extend(
+            [len(cluster) for cluster in instance["predicted_clusters"]]
+        )
+
+        # gold_clusters = filter_clusters(gold_clusters, threshold=1)
+        # pred_clusters = filter_clusters(pred_clusters, threshold=1)
+
+        mention_to_predicted = get_mention_to_cluster(pred_clusters)
+        mention_to_gold = get_mention_to_cluster(gold_clusters)
+        non_singleton_evaluator.update(
+            pred_clusters, gold_clusters, mention_to_predicted, mention_to_gold
+        )
+
+    logger.info(
+        "\nGT singletons: %d, Pred singletons: %d\n"
+        % (gold_singletons, pred_singletons)
+    )
+    recall_sing = overlap_sing / total_sing
+    # Renamed from the original, which overwrote pred_sing in place: this is precision.
+    precision_sing = overlap_sing / pred_sing
+    f_sing = 2 * recall_sing * precision_sing / (recall_sing + precision_sing)
+    logger.info(
+        f"\nSingletons - Recall: {recall_sing * 100}, Precision: {precision_sing * 100}, "
+        f"F1: {f_sing * 100}\n"
+    )
+    logger.info(
+        f"\nNon-singleton cluster lengths, Gold: {np.mean(gold_cluster_lens):.2f}, "
+        f"Pred: {np.mean(pred_cluster_lens):.2f}\n"
+    )
+
+    for evaluator, evaluator_str in zip([non_singleton_evaluator], ["Non-singleton"]):
+        perf_str = ""
+        indv_metrics_list = ["MUC", "BCub", "CEAFE"]
+        for indv_metric, indv_evaluator in zip(indv_metrics_list, evaluator.evaluators):
+            # perf_str += ", " + indv_metric + ": {:.1f}".format(indv_evaluator.get_f1() * 100)
+            perf_str += "{} - {}".format(indv_metric, indv_evaluator.get_prf_str())
+
+        fscore = evaluator.get_f1() * 100
+        perf_str += "{} ".format(fscore)
+        perf_str = perf_str.strip(", ")
+        logger.info("\n%s\n%s\n" % (evaluator_str, perf_str))
+
+
+def main():
+    args = process_args()
+    data = []
+    with open(args.log_file) as f:
+        for line in f:
+            data.append(json.loads(line))
+    singleton_analysis(data)
+
+
+if __name__ == "__main__":
+    main()
experiment.py ADDED
@@ -0,0 +1,1052 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import time
4
+ import logging
5
+ import torch
6
+
7
+ ## Uncomment the following line to make the code deterministic and use CUBLAS_WORKSPACE_CONFIG=:4096:8
8
+ torch.use_deterministic_algorithms(True)
9
+ import json
10
+ import numpy as np
11
+ import random
12
+ import wandb
13
+
14
+ from omegaconf import OmegaConf, open_dict
15
+ from os import path
16
+ from collections import OrderedDict, defaultdict
17
+ from transformers import get_linear_schedule_with_warmup
18
+ from transformers import AutoModel, AutoTokenizer
19
+
20
+ from data_utils.utils import load_dataset, load_eval_dataset
21
+ import pytorch_utils.utils as utils
22
+ from torch.profiler import profile, record_function, ProfilerActivity
23
+
24
+ from model.entity_ranking_model import EntityRankingModel
25
+ from model.mention_proposal import MentionProposalModule
26
+ from data_utils.tensorize_dataset import TensorizeDataset
27
+ from pytorch_utils.optimization_utils import get_inverse_square_root_decay
28
+
29
+ from utils_evaluate import coref_evaluation
30
+
31
+ from typing import Dict, Union, List, Optional
32
+ from omegaconf import DictConfig
33
+ import copy
34
+
35
+ logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
36
+ logger = logging.getLogger()
37
+
38
+ loss_acc_template_dict = {
39
+ "total": 0.0,
40
+ "ment_loss": 0.0,
41
+ "coref": 0.0,
42
+ "mention_count": 0.0,
43
+ "processed_docs": 0.0,
44
+ "ment_correct": 0.0,
45
+ "ment_total": 0.0,
46
+ "ment_tp": 0.0,
47
+ "ment_pp": 0.0,
48
+ "ment_ap": 0.0,
49
+ }
50
+
51
+
52
+ class Experiment:
53
+ """Class for training and evaluating coreference models."""
54
+
55
+ def __init__(self, config: DictConfig):
56
+ self.config = config
57
+
58
+ print("Seeded: ", config.seed)
59
+ print("Cuda Available: ", torch.cuda.is_available())
60
+
61
+ # Whether to train or not
62
+ self.eval_model: bool = not self.config.train
63
+
64
+ # Initialize dictionary to track key training variables
65
+ self.train_info = {
66
+ "val_perf": 0.0,
67
+ "global_steps": 0,
68
+ "num_stuck_evals": 0,
69
+ "peak_memory": 0.0,
70
+ }
71
+
72
+ self.wandbdata = {}
73
+
74
+ # Initialize model path attributes
75
+ self.model_path = self.config.paths.model_path
76
+ self.best_model_path = self.config.paths.best_model_path
77
+
78
+ if not self.eval_model:
79
+ # Step 1 - Initialize model
80
+ self._build_model()
81
+ # Step 2 - Load Data - Data processing choices such as tokenizer will depend on the model
82
+ self._load_data()
83
+ # Step 3 - Resume training
84
+ self._setup_training()
85
+ # Step 4 - Loading the checkpoint also restores the training metadata
86
+ self._load_previous_checkpoint()
87
+
88
+ # All set to resume training
89
+ # But first check if training is remaining
90
+ if self._is_training_remaining():
91
+ self.train()
92
+
93
+ # Perform final evaluation
94
+ if path.exists(self.best_model_path):
95
+ # Step 1 - Initialize model
96
+ self._initialize_best_model()
97
+ # Step 2 - Load evaluation data
98
+ self._load_data()
99
+ # Step 3 - Perform evaluation
100
+ self.perform_final_eval()
101
+ else:
102
+ logger.info("No model accessible!")
103
+ sys.exit(1)
104
+
105
+ def _build_model(self) -> None:
106
+ """Constructs the model with given config."""
107
+
108
+ model_params: DictConfig = self.config.model
109
+ train_config: DictConfig = self.config.trainer
110
+
111
+ self.model = EntityRankingModel(
112
+ model_config=model_params, train_config=train_config
113
+ )
114
+
115
+ if torch.cuda.is_available():
116
+ self.model.cuda(device=self.config.device)
117
+
118
+ # Print model
119
+ utils.print_model_info(self.model)
120
+ sys.stdout.flush()
121
+
122
+ def _load_data(self):
123
+ """Loads and processes the training and evaluation data.
124
+
125
+ Loads the data concerning all the specified datasets for training and eval.
126
+ The first part of this method loads all the data from the preprocessed jsonline files.
127
+ In the second half, the loaded data is tensorized for consumption by the model.
128
+
129
+ Apart from loading and processing the data, the method also populates important
130
+ attributes such as:
131
+ num_train_docs_map (dict): Dictionary to maintain the number of training
132
+ docs per dataset which is useful for implementing sampling in joint training.
133
+ num_training_steps (int): Number of total training steps.
134
+ eval_per_k_steps (int): Number of gradient updates before each evaluation.
135
+ """
136
+
137
+ self.data_iter_map, self.conll_data_dir, self.num_split_docs_map = (
138
+ {},
139
+ {},
140
+ {"train": {}, "dev": {}, "test": {}},
141
+ )
142
+ raw_data_map = {}
143
+
144
+        max_segment_len: int = self.config.model.doc_encoder.transformer.max_segment_len
+        model_name: str = self.config.model.doc_encoder.transformer.name
+        add_speaker_tokens: bool = self.config.model.doc_encoder.add_speaker_tokens
+        base_data_dir: str = path.abspath(self.config.paths.base_data_dir)
+
+        # Load data
+        for dataset_name, attributes in self.config.datasets.items():
+            num_train_docs: Optional[int] = attributes.get("num_train_docs", None)
+            num_dev_docs: Optional[int] = attributes.get("num_dev_docs", None)
+            num_test_docs: Optional[int] = attributes.get("num_test_docs", None)
+            singleton_file: Optional[str] = attributes.get("singleton_file", None)
+            external_md_file: Optional[str] = attributes.get("external_md_file", None)
+
+            if singleton_file is not None:
+                singleton_file = path.join(base_data_dir, singleton_file)
+                if path.exists(singleton_file):
+                    logger.info(f"Singleton file found: {singleton_file}")
+
+            if external_md_file is not None:
+                external_md_file = path.join(base_data_dir, external_md_file)
+                if path.exists(external_md_file):
+                    logger.info(
+                        f"External mention detector file found: {external_md_file}"
+                    )
+
+            # The data directory is a function of the dataset name and the tokenizer used
+            data_dir = path.join(base_data_dir, dataset_name, model_name)
+            # Check if speaker tokens are added
+            if add_speaker_tokens:
+                pot_data_dir = path.join(
+                    base_data_dir, dataset_name, model_name + "_speaker"
+                )
+                if path.exists(pot_data_dir):
+                    data_dir = pot_data_dir
+
+            # Datasets such as LitBank have cross-validation splits
+            if attributes.get("cross_val_split", None) is not None:
+                data_dir = path.join(data_dir, str(attributes.get("cross_val_split")))
+
+            logger.info("Data directory: %s" % data_dir)
+
+            # CoNLL data dir
+            if attributes.get("has_conll", False):
+                conll_dir = path.join(base_data_dir, dataset_name, "conll")
+                if attributes.get("cross_val_split", None) is not None:
+                    # LitBank-like datasets have cross-validation splits
+                    conll_dir = path.join(
+                        conll_dir, str(attributes.get("cross_val_split"))
+                    )
+
+                if path.exists(conll_dir):
+                    self.conll_data_dir[dataset_name] = conll_dir
+
+            self.num_split_docs_map["train"][dataset_name] = num_train_docs
+            self.num_split_docs_map["dev"][dataset_name] = num_dev_docs
+            self.num_split_docs_map["test"][dataset_name] = num_test_docs
+
+            if self.eval_model:
+                print("In Eval Model DataLoader")
+                raw_data_map[dataset_name] = load_eval_dataset(
+                    data_dir,
+                    external_md_file=external_md_file,
+                    max_segment_len=max_segment_len,
+                    dataset_name=dataset_name,
+                )
+            else:
+                raw_data_map[dataset_name] = load_dataset(
+                    data_dir,
+                    singleton_file=singleton_file,
+                    num_dev_docs=num_dev_docs,
+                    num_test_docs=num_test_docs,
+                    max_segment_len=max_segment_len,
+                    dataset_name=dataset_name,
+                )
+
+        # Tensorize data
+        data_processor = TensorizeDataset(
+            self.model.get_tokenizer(),
+            remove_singletons=(not self.config.keep_singletons),
+        )
+
+        if self.eval_model:
+            for split in ["dev", "test"]:
+                self.data_iter_map[split] = {}
+
+            for dataset in raw_data_map:
+                for split in raw_data_map[dataset]:
+                    self.data_iter_map[split][dataset] = data_processor.tensorize_data(
+                        raw_data_map[dataset][split], training=False
+                    )
+        else:
+            # Training
+            for split in ["train", "dev", "test"]:
+                self.data_iter_map[split] = {}
+                training = split == "train"
+                for dataset in raw_data_map:
+                    self.data_iter_map[split][dataset] = data_processor.tensorize_data(
+                        raw_data_map[dataset][split], training=training
+                    )
+
+        # Estimate the number of training steps
+        if self.config.trainer.eval_per_k_steps is None:
+            # One eval per "epoch" (with subsampling) over all datasets used in joint training
+            self.config.trainer.eval_per_k_steps = sum(
+                self.num_split_docs_map["train"].values()
+            )
+
+        self.config.trainer.num_training_steps = (
+            self.config.trainer.eval_per_k_steps * self.config.trainer.max_evals
+        )
+        logger.info(
+            f"Number of training steps: {self.config.trainer.num_training_steps}"
+        )
+        logger.info(f"Eval per k steps: {self.config.trainer.eval_per_k_steps}")
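A quick sanity check on the step arithmetic above, with hypothetical numbers: if `num_train_docs` is 80 for one dataset and 500 for another, `eval_per_k_steps` defaults to 80 + 500 = 580, and with `max_evals` set to 25 the run is capped at 580 × 25 = 14,500 gradient updates.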
+    def _load_previous_checkpoint(self):
+        """Load the last checkpoint, if one exists, to resume training."""
+
+        # Resume training
+        print("Model Path: ", self.model_path)
+        if torch.cuda.is_available():
+            print("Model Initialised:", torch.cuda.memory_summary(self.config.device))
+        if path.exists(self.model_path):
+            self.load_model(self.model_path, last_checkpoint=True)
+            logger.info("Model loaded\n")
+            if torch.cuda.is_available():
+                print(
+                    "Loaded Model Returned:",
+                    torch.cuda.memory_summary(self.config.device),
+                )
+        else:
+            # Starting training from scratch
+            logger.info("Model initialized\n")
+            sys.stdout.flush()
+    def _is_training_remaining(self):
+        """Check whether training should continue.
+
+        Training stops (and we move to the final evaluation) in two cases:
+        (a) Dev performance has not improved for `patience` consecutive evaluations.
+        (b) The number of gradient updates is already >= the total training steps.
+
+        Returns:
+            bool: True if training should resume; False to run the final evaluation.
+        """
+
+        if self.train_info["num_stuck_evals"] >= self.config.trainer.patience:
+            return False
+        if self.train_info["global_steps"] >= self.config.trainer.num_training_steps:
+            return False
+
+        return True
+    def _setup_training(self):
+        """Initialize the optimizers and bookkeeping variables for training."""
+
+        # Dictionary to track key training variables
+        self.train_info = {
+            "val_perf": 0.0,
+            "global_steps": 0,
+            "num_stuck_evals": 0,
+            "peak_memory": 0.0,
+            "max_mem": 0.0,
+        }
+
+        # Initialize optimizers
+        self._initialize_optimizers()
+    def _initialize_optimizers(self):
+        """Initialize the optimizers and schedulers: one group for the document
+        encoder ("doc") and one for the clustering/memory parameters ("mem")."""
+
+        optimizer_config: DictConfig = self.config.optimizer
+        train_config: DictConfig = self.config.trainer
+        self.optimizer, self.optim_scheduler = {}, {}
+
+        if torch.cuda.is_available():
+            # Gradient scaler required for mixed-precision training
+            self.scaler = torch.GradScaler("cuda")
+        else:
+            self.scaler = None
+
+        # Optimizer for clustering params
+        self.optimizer["mem"] = torch.optim.Adam(
+            self.model.get_params()[1], lr=optimizer_config.init_lr, eps=1e-6
+        )
+
+        if optimizer_config.lr_decay == "inv":
+            self.optim_scheduler["mem"] = get_inverse_square_root_decay(
+                self.optimizer["mem"], num_warmup_steps=0
+            )
+        else:
+            # No warmup steps for the clustering params
+            self.optim_scheduler["mem"] = get_linear_schedule_with_warmup(
+                self.optimizer["mem"],
+                num_warmup_steps=0,
+                num_training_steps=train_config.num_training_steps,
+            )
+
+        if self.config.model.doc_encoder.finetune:
+            # Optimizer for the document encoder
+            no_decay = [
+                "bias",
+                "LayerNorm.weight",
+            ]  # No weight decay for bias and LayerNorm weights
+            encoder_params = self.model.get_params(named=True)[0]
+            grouped_param = [
+                {
+                    "params": [
+                        p
+                        for n, p in encoder_params
+                        if not any(nd in n for nd in no_decay)
+                    ],
+                    "lr": optimizer_config.fine_tune_lr,
+                    "weight_decay": 1e-2,
+                },
+                {
+                    "params": [
+                        p for n, p in encoder_params if any(nd in n for nd in no_decay)
+                    ],
+                    "lr": optimizer_config.fine_tune_lr,
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            self.optimizer["doc"] = torch.optim.AdamW(
+                grouped_param, lr=optimizer_config.fine_tune_lr, eps=1e-6
+            )
+
+            # Scheduler for the document encoder: 10% of the steps are warmup
+            num_warmup_steps = int(0.1 * train_config.num_training_steps)
+            if optimizer_config.lr_decay == "inv":
+                self.optim_scheduler["doc"] = get_inverse_square_root_decay(
+                    self.optimizer["doc"], num_warmup_steps=num_warmup_steps
+                )
+            else:
+                self.optim_scheduler["doc"] = get_linear_schedule_with_warmup(
+                    self.optimizer["doc"],
+                    num_warmup_steps=num_warmup_steps,
+                    num_training_steps=train_config.num_training_steps,
+                )
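For reference, `get_inverse_square_root_decay` is imported from elsewhere in the repo; the sketch below is a hypothetical reimplementation of such a schedule (linear warmup, then ~1/√step decay), not the repo's exact helper:

from torch.optim.lr_scheduler import LambdaLR

def inverse_square_root_decay_sketch(optimizer, num_warmup_steps=0):
    # Scale the base LR linearly during warmup, then decay proportional to 1/sqrt(step)
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step + 1) / float(max(1, num_warmup_steps))
        return (float(max(1, num_warmup_steps)) / float(max(1, step))) ** 0.5
    return LambdaLR(optimizer, lr_lambda)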
+    def agg(self, datadepdict):
+        """Aggregate per-dataset loss/accuracy counts into overall metrics."""
+        agg_dict = defaultdict(float)
+        for dataset in datadepdict:
+            for key in datadepdict[dataset]:
+                agg_dict[key] += datadepdict[dataset][key]
+
+        agg_dict["loss_norm"] = (
+            agg_dict["coref"] / agg_dict["mention_count"]
+            + agg_dict["ment_loss"] / agg_dict["ment_total"]
+            if agg_dict["mention_count"] > 0 and agg_dict["ment_total"] > 0
+            else 0
+        )
+        agg_dict["ment_acc"] = (
+            agg_dict["ment_correct"] / agg_dict["ment_total"]
+            if agg_dict["ment_total"] > 0
+            else 0
+        )
+        agg_dict["ment_prec"] = (
+            agg_dict["ment_tp"] / agg_dict["ment_pp"] if agg_dict["ment_pp"] > 0 else 0
+        )
+        agg_dict["ment_rec"] = (
+            agg_dict["ment_tp"] / agg_dict["ment_ap"] if agg_dict["ment_ap"] > 0 else 0
+        )
+        agg_dict["ment_f1"] = (
+            2
+            * (agg_dict["ment_prec"] * agg_dict["ment_rec"])
+            / (agg_dict["ment_prec"] + agg_dict["ment_rec"])
+            if (agg_dict["ment_prec"] + agg_dict["ment_rec"]) > 0
+            else 0
+        )
+
+        return agg_dict
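A toy illustration of the derived mention metrics, with hypothetical counts: if the summed counts give ment_tp = 45, ment_pp = 50, and ment_ap = 60, then ment_prec = 45/50 = 0.90, ment_rec = 45/60 = 0.75, and ment_f1 = 2 × 0.90 × 0.75 / (0.90 + 0.75) ≈ 0.82.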
+    def train(self) -> None:
+        """Method for training the model.
+
+        This method implements the training loop.
+        Within the training loop, the model is periodically evaluated on the dev set(s).
+        """
+
+        model, optimizer, scheduler, scaler = (
+            self.model,
+            self.optimizer,
+            self.optim_scheduler,
+            self.scaler,
+        )
+        model.train()
+
+        optimizer_config, train_config = self.config.optimizer, self.config.trainer
+
+        start_time = time.time()
+        eval_time = {"total_time": 0, "num_evals": 0}
+        print("Started training...")
+        while True:
+            logger.info("Steps done %d" % (self.train_info["global_steps"]))
+
+            train_data = self.runtime_load_dataset("train")
+            np.random.shuffle(train_data)
+            logger.info("Per-epoch training steps: %d" % len(train_data))
+
+            encoder_params, task_params = model.get_params()
+            stat_per_dataset = defaultdict(
+                lambda: copy.deepcopy(loss_acc_template_dict)
+            )
+            agg_stat = self.agg
+
+            # Training "epoch" -> May not correspond to an actual epoch
+            for cur_document in train_data:
+
+                def handle_example(document: Dict) -> Union[None, float]:
+                    self.train_info["global_steps"] += 1
+                    for key in optimizer:
+                        optimizer[key].zero_grad()
+                    loss_dict: Dict = model.forward_training(document)
+
+                    total_loss = loss_dict["total"]
+
+                    if total_loss is None or torch.isnan(total_loss):
+                        print("Problem with loss; this should be rare")
+                        return None
+
+                    total_loss.backward()
+
+                    # Gradient clipping; with error_if_nonfinite=True this also acts
+                    # as a guard: on inf/nan gradients the update is skipped entirely
+                    try:
+                        for name_ind, param_group in enumerate(
+                            [encoder_params, task_params]
+                        ):
+                            torch.nn.utils.clip_grad_norm_(
+                                param_group,
+                                optimizer_config.max_gradient_norm,
+                                error_if_nonfinite=True,
+                            )
+                    except RuntimeError:
+                        print("Non-finite gradient")
+                        return None
+
+                    for key in optimizer:
+                        self.wandbdata[key + "_lr"] = scheduler[key].get_last_lr()[0]
+
+                    for key in optimizer:
+                        optimizer[key].step()
+                        scheduler[key].step()
+
+                    loss_dict_items = {}
+                    for key in loss_dict:
+                        loss_dict_items[key] = loss_dict[key].item()
+
+                    dataset_name = document["dataset_name"]
+                    # print(f"Total loss {document['doc_key']}: {total_loss.item()}")
+
+                    for key in loss_dict_items:
+                        stat_per_dataset[dataset_name][key] += loss_dict_items[key]
+
+                    stat_per_dataset[dataset_name]["processed_docs"] += 1
+
+                    return total_loss.item()
+
+                loss = handle_example(cur_document)
+
+                if self.train_info["global_steps"] % train_config.log_frequency == 0:
+                    max_mem = (
+                        (
+                            torch.cuda.max_memory_allocated(self.config.device)
+                            / (1024**3)
+                        )
+                        if torch.cuda.is_available()
+                        else 0.0
+                    )
+                    if self.train_info.get("max_mem", 0.0) < max_mem:
+                        self.train_info["max_mem"] = max_mem
+
+                    if loss is not None:
+                        logger.info(
+                            "{} {:.3f} Max mem {:.1f} GB".format(
+                                cur_document["doc_key"],
+                                loss,
+                                max_mem,
+                            )
+                        )
+                    sys.stdout.flush()
+                    if torch.cuda.is_available():
+                        torch.cuda.reset_peak_memory_stats()
+
+                if train_config.eval_per_k_steps and (
+                    self.train_info["global_steps"] % train_config.eval_per_k_steps == 0
+                ):
+                    print("Eval needs to be done here")
+                    coref_dict = {}
+                    print(stat_per_dataset)
+                    if self.config.use_wandb:
+                        self._wandb_log(
+                            split="train",
+                            stat_per_dataset=stat_per_dataset,
+                            agg_stat=agg_stat,
+                            coref_dict=coref_dict,
+                            step=self.train_info["global_steps"]
+                            // train_config.eval_per_k_steps,
+                        )
+
+                    stat_per_dataset = defaultdict(
+                        lambda: copy.deepcopy(loss_acc_template_dict)
+                    )
+
+                    macro_fscore = self.periodic_model_eval()
+
+                    model.train()
+                    # Get elapsed time
+                    elapsed_time = time.time() - start_time
+                    start_time = time.time()
+                    logger.info(
+                        "Steps: %d, Macro F1: %.1f, Max Macro F1: %.1f, Time: %.2f"
+                        % (
+                            self.train_info["global_steps"],
+                            macro_fscore,
+                            self.train_info["val_perf"],
+                            elapsed_time,
+                        )
+                    )
+
+                    # Check stopping criteria (breaks out of the document loop)
+                    if not self._is_training_remaining():
+                        break
+
+            # Check stopping criteria (breaks out of the training loop)
+            if not self._is_training_remaining():
+                break
+
+        logger.handlers[0].flush()
+    def runtime_load_dataset(self, split):
+        """Shuffle and (optionally) subsample the data for the given split."""
+        data = []
+        for dataset, dataset_data in self.data_iter_map[split].items():
+            # Shuffle within the dataset; comment this out for deterministic ordering
+            np.random.shuffle(dataset_data)
+            if self.num_split_docs_map[split].get(dataset, None) is not None:
+                # Subsample the data - useful in joint training
+                logger.info(
+                    f"{dataset}: Subsampled {self.num_split_docs_map[split].get(dataset)}"
+                )
+                data += dataset_data[: self.num_split_docs_map[split].get(dataset)]
+            else:
+                data += dataset_data
+        return data
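Note the interaction between the shuffle and the subsampling above: with, say, 500 available documents and a hypothetical `num_train_docs` of 100, each call returns a fresh random subset of 100 documents, so successive "epochs" do not train on a fixed slice.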
+    def _wandb_log(self, split, stat_per_dataset, agg_stat, coref_dict, step=None):
+        """Log per-dataset stats, aggregate stats, and coreference metrics to wandb."""
+        for dataset_name in stat_per_dataset:
+            for metric in stat_per_dataset[dataset_name]:
+                wandb.log(
+                    data={
+                        f"{split}/{dataset_name}/{metric}": stat_per_dataset[
+                            dataset_name
+                        ][metric]
+                    },
+                    step=step,
+                )
+            if stat_per_dataset[dataset_name]["mention_count"] > 0.0:
+                ment_prec = (
+                    stat_per_dataset[dataset_name]["ment_tp"]
+                    / stat_per_dataset[dataset_name]["ment_pp"]
+                    if stat_per_dataset[dataset_name]["ment_pp"] > 0
+                    else 0
+                )
+                ment_rec = (
+                    stat_per_dataset[dataset_name]["ment_tp"]
+                    / stat_per_dataset[dataset_name]["ment_ap"]
+                    if stat_per_dataset[dataset_name]["ment_ap"] > 0
+                    else 0
+                )
+                ment_f1 = (
+                    2 * (ment_prec * ment_rec) / (ment_prec + ment_rec)
+                    if (ment_prec + ment_rec) > 0
+                    else 0
+                )
+                wandb.log(
+                    data={
+                        f"{split}/{dataset_name}/loss_norm": stat_per_dataset[
+                            dataset_name
+                        ]["coref"]
+                        / stat_per_dataset[dataset_name]["mention_count"]
+                        + stat_per_dataset[dataset_name]["ment_loss"]
+                        / stat_per_dataset[dataset_name]["ment_total"],
+                        f"{split}/{dataset_name}/ment_acc": stat_per_dataset[
+                            dataset_name
+                        ]["ment_correct"]
+                        / stat_per_dataset[dataset_name]["ment_total"],
+                        f"{split}/{dataset_name}/ment_prec": ment_prec,
+                        f"{split}/{dataset_name}/ment_rec": ment_rec,
+                        f"{split}/{dataset_name}/ment_f1": ment_f1,
+                    },
+                    step=step,
+                )
+            else:
+                print("No mentions processed; this should be rare.")
+
+        if agg_stat:
+            agg_metrics = agg_stat(stat_per_dataset)
+            for metric in agg_metrics:
+                wandb.log(
+                    data={f"{split}/{metric}": agg_metrics[metric]},
+                    step=step,
+                )
+
+        for dataset in coref_dict:
+            for key in coref_dict[dataset]:
+                # Log results for the individual metrics
+                if isinstance(coref_dict[dataset][key], dict):
+                    wandb.log(
+                        data={
+                            f"{split}/{dataset}/{key}": coref_dict[dataset][key].get(
+                                "fscore", 0.0
+                            )
+                        },
+                        step=step,
+                    )
+
+            # Log the overall F-scores
+            wandb.log(
+                data={
+                    f"{split}/{dataset}/CoNLL": coref_dict[dataset].get("fscore", 0.0)
+                },
+                step=step,
+            )
+            wandb.log(
+                data={
+                    f"{split}/{dataset}/Micro-F1": coref_dict[dataset].get(
+                        "f1_micro", 0.0
+                    )
+                },
+                step=step,
+            )
+            wandb.log(
+                data={
+                    f"{split}/{dataset}/Macro-F1": coref_dict[dataset].get(
+                        "f1_macro", 0.0
+                    )
+                },
+                step=step,
+            )
+
+        wandb.log(data=self.wandbdata, step=step)
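For orientation, the resulting wandb metric namespace looks roughly as follows (the dataset name `litbank` is just an example):

# train/litbank/ment_f1   - per-dataset mention-detection F1
# dev/litbank/CoNLL       - CoNLL F-score on the dev split
# dev/litbank/Macro-F1    - macro-averaged coreference F1
# dev/loss_norm           - aggregate normalized loss across datasets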
+    @torch.no_grad()
+    def periodic_model_eval(self) -> float:
+        """Evaluate and save the model during the training loop.
+
+        Returns:
+            float: Macro F-score averaged over the development sets of all datasets.
+        """
+
+        self.model.eval()
+
+        # Dev loss calculation
+        dev_data = self.runtime_load_dataset("dev")
+        np.random.shuffle(dev_data)
+        stat_per_dataset = defaultdict(lambda: copy.deepcopy(loss_acc_template_dict))
+        agg_stat = self.agg
+
+        for cur_document in dev_data:
+
+            def handle_example(document: Dict) -> Union[None, float]:
+                loss_dict: Dict = self.model.forward_training(document)
+                total_loss = loss_dict["total"]
+                if total_loss is None or torch.isnan(total_loss):
+                    print("Problem with loss; this should be rare")
+                    return None
+
+                loss_dict_items = {}
+                for key in loss_dict:
+                    loss_dict_items[key] = loss_dict[key].item()
+
+                dataset_name = document["dataset_name"]
+
+                for key in loss_dict_items:
+                    stat_per_dataset[dataset_name][key] += loss_dict_items[key]
+
+                stat_per_dataset[dataset_name]["processed_docs"] += 1
+                return total_loss.item()
+
+            loss = handle_example(cur_document)
+            if loss is None:
+                continue
+
+        # Dev performance
+        coref_dict = {}
+        train_config = self.config.trainer
+        for dataset in self.data_iter_map["dev"]:
+            # Placeholder loops for gold-mention / teacher-forcing evaluation modes
+            for go in [False]:
+                for tf in [False]:
+                    result_dict = coref_evaluation(
+                        self.config,
+                        self.model,
+                        self.data_iter_map,
+                        dataset,
+                        teacher_force=tf,
+                        gold_mentions=go,
+                        _iter="_"
+                        + str(
+                            self.train_info["global_steps"]
+                            // train_config.eval_per_k_steps
+                        ),
+                        conll_data_dir=self.conll_data_dir,
+                    )
+
+                    coref_dict[dataset] = result_dict
+
+        if self.config.use_wandb:
+            self._wandb_log(
+                split="dev",
+                stat_per_dataset=stat_per_dataset,
+                agg_stat=agg_stat,
+                coref_dict=coref_dict,
+                step=self.train_info["global_steps"] // train_config.eval_per_k_steps,
+            )
+
+        # Calculate the mean F-scores across datasets
+        fscore = sum([coref_dict[dataset]["fscore"] for dataset in coref_dict]) / len(
+            coref_dict
+        )
+        micro_fscore = sum(
+            [coref_dict[dataset]["f1_micro"] for dataset in coref_dict]
+        ) / len(coref_dict)
+        macro_fscore = sum(
+            [coref_dict[dataset]["f1_macro"] for dataset in coref_dict]
+        ) / len(coref_dict)
+
+        logger.info(
+            "Avg Macro F1: %.1f, Best Macro F1: %.1f"
+            % (macro_fscore, self.train_info["val_perf"])
+        )
+
+        # Update model if dev performance improves
+        if macro_fscore > self.train_info["val_perf"]:
+            # Update training bookkeeping variables
+            self.train_info["num_stuck_evals"] = 0
+            self.train_info["val_perf"] = macro_fscore
+
+            # Save the best model
+            logger.info("Saving best model")
+            self.save_model(self.best_model_path, last_checkpoint=False)
+        else:
+            self.train_info["num_stuck_evals"] += 1
+
+        # Save the last checkpoint
+        if self.config.trainer.to_save_model:
+            self.save_model(self.model_path, last_checkpoint=True)
+
+        # Go back to training mode
+        self.model.train()
+        return macro_fscore
+
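Concretely, with hypothetical scores: if two dev sets report `f1_macro` of 75.0 and 81.0, the checkpoint-selection criterion is their mean, (75.0 + 81.0) / 2 = 78.0, and `val_perf` tracks the best such mean seen so far.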
+    @torch.no_grad()
+    def perform_final_eval(self) -> None:
+        """Evaluate the model on the dev and test splits after training has finished."""
+
+        self.model.eval()
+        base_output_dict = OmegaConf.to_container(self.config)
+        perf_summary = {"best_perf": self.train_info["val_perf"]}
+        if self.config.paths.model_dir:
+            perf_summary["model_dir"] = path.normpath(self.config.paths.model_dir)
+
+        logger.info(
+            "Max training memory: %.1f GB" % self.train_info.get("max_mem", 0.0)
+        )
+        logger.info("Validation performance: %.1f" % self.train_info["val_perf"])
+
+        perf_file_dict = {}
+        dataset_output_dict = {}
+
+        for split in ["dev", "test"]:
+            perf_summary[split] = {}
+            logger.info("\n")
+            logger.info("%s" % split.capitalize())
+            coref_dict = {}
+            for dataset in self.data_iter_map.get(split, []):
+                dataset_dir = path.join(self.config.paths.model_dir, dataset)
+                if not path.exists(dataset_dir):
+                    os.makedirs(dataset_dir)
+
+                if dataset not in dataset_output_dict:
+                    dataset_output_dict[dataset] = {}
+                if dataset not in perf_file_dict:
+                    perf_file_dict[dataset] = path.join(dataset_dir, "perf.json")
+
+                print("Dataset Name:", self.config.datasets[dataset].name)
+                logger.info("Dataset: %s\n" % self.config.datasets[dataset].name)
+
+                # Placeholder loops for gold-mention / teacher-forcing evaluation modes
+                for go in [False]:
+                    for tf in [False]:
+                        result_dict = coref_evaluation(
+                            self.config,
+                            self.model,
+                            self.data_iter_map,
+                            dataset=dataset,
+                            split=split,
+                            teacher_force=tf,
+                            gold_mentions=go,
+                            final_eval=True,
+                            conll_data_dir=self.conll_data_dir,
+                        )
+                        coref_dict[dataset] = result_dict
+                        dataset_output_dict[dataset][split] = result_dict
+                        perf_summary[split][dataset] = result_dict["f1_micro"]
+
+            if self.config.use_wandb:
+                self._wandb_log(
+                    split=split,
+                    stat_per_dataset={},
+                    agg_stat=None,
+                    coref_dict=coref_dict,
+                    step=None,
+                )
+
+            sys.stdout.flush()
+
+        for dataset, output_dict in dataset_output_dict.items():
+            perf_file = perf_file_dict[dataset]
+            json.dump(output_dict, open(perf_file, "w"), indent=2)
+            logger.info("Final performance summary at %s" % path.abspath(perf_file))
+
+        summary_file = path.join(self.config.paths.model_dir, "perf.json")
+        json.dump(perf_summary, open(summary_file, "w"), indent=2)
+        logger.info("Performance summary file: %s" % path.abspath(summary_file))
+
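The per-dataset `perf.json` written above nests one result dict per split; a sketch of its shape (keys as read off `coref_evaluation`'s result dict in this file, values hypothetical):

# perf.json (values hypothetical):
# {
#   "dev":  {"fscore": 77.3, "f1_micro": 76.8, "f1_macro": 77.9, ...},
#   "test": {"fscore": 75.1, "f1_micro": 74.6, "f1_macro": 75.8, ...}
# }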
+    def _initialize_best_model(self):
+        """Initialize the model from the best checkpoint, applying any config overrides."""
+        checkpoint = torch.load(
+            self.best_model_path,
+            map_location="cpu",
+        )
+        config = checkpoint["config"]
+
+        # Due to version changes, the config updates below are necessary
+
+        # Override encoder
+        if self.config.get("override_encoder", False):
+            model_config = config.model
+            print(type(self.config.model.doc_encoder.transformer))
+            print(self.config.model.doc_encoder.transformer)
+            model_config.doc_encoder.transformer = (
+                self.config.model.doc_encoder.transformer
+            )
+
+        # Override memory
+        # E.g., can test with a different bounded memory size
+        if self.config.get("override_memory", False):
+            model_config = config.model
+            model_config.memory = self.config.model.memory
+
+        with open_dict(config):
+            print("Config change")
+            config.model.mention_params.ext_ment = (
+                self.config.model.mention_params.ext_ment
+            )
+            config = utils.fill_missing_configs(config, self.config)
+            print("Type: ", config.model.memory.type)
+
+        self.config.model = config.model
+
+        self.train_info = checkpoint["train_info"]
+
+        if self.config.model.doc_encoder.finetune:
+            # Load the document encoder params if the encoder was finetuned
+            doc_encoder_dir = path.join(
+                path.dirname(self.best_model_path),
+                self.config.paths.doc_encoder_dirname,
+            )
+            if path.exists(doc_encoder_dir):
+                logger.info(
+                    "Loading document encoder from %s" % path.abspath(doc_encoder_dir)
+                )
+                config.model.doc_encoder.transformer.model_str = doc_encoder_dir
+
+        self.model = EntityRankingModel(config.model, config.trainer)
+
+        # Document encoder parameters will be loaded via the huggingface initialization
+        self.model.load_state_dict(checkpoint["model"], strict=False)
+
+        if torch.cuda.is_available():
+            self.model.cuda(device=self.config.device)
+
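Because the repo is driven by Hydra configs (see conf/config.yaml and the conf/model/memory/mem_type group), these override flags would typically be set from the command line; a purely hypothetical invocation (entry-point name assumed):

# python main.py override_memory=True model/memory/mem_type=unbounded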
+    def load_model(self, location: str, last_checkpoint=True) -> None:
+        """Load model from the given location.
+
+        Args:
+            location: str
+                Location of the checkpoint
+            last_checkpoint: bool
+                Whether the checkpoint is the last one saved.
+                If False, don't load optimizers, schedulers, and other training variables.
+        """
+
+        checkpoint = torch.load(location, map_location="cpu")
+        logger.info("Loading model from %s" % path.abspath(location))
+
+        # self.config = checkpoint["config"]  # Intentionally skipped so the current
+        # config is not overwritten by the trained model's config
+
+        # The encoder is stored separately from this checkpoint, so strict=False is
+        # required; no other weights are missing (checked)
+        self.model.load_state_dict(checkpoint["model"], strict=False)
+
+        # self.train_info = checkpoint["train_info"]  # Train info is not transferred either
+
+        if self.config.model.doc_encoder.finetune:
+            # Load the document encoder params if the encoder was finetuned
+            doc_encoder_dir = path.join(
+                path.dirname(location), self.config.paths.doc_encoder_dirname
+            )
+            logger.info(
+                "Loading document encoder from %s" % path.abspath(doc_encoder_dir)
+            )
+
+            # Load the encoder
+            self.model.mention_proposer.doc_encoder.lm_encoder = (
+                AutoModel.from_pretrained(pretrained_model_name_or_path=doc_encoder_dir)
+            )
+            self.model.mention_proposer.doc_encoder.tokenizer = (
+                AutoTokenizer.from_pretrained(
+                    pretrained_model_name_or_path=doc_encoder_dir,
+                    clean_up_tokenization_spaces=True,
+                )
+            )
+            if self.model.mention_proposer.doc_encoder.config.finetune:
+                self.model.mention_proposer.doc_encoder.lm_encoder.gradient_checkpointing_enable()
+
+        if torch.cuda.is_available():
+            self.model.cuda(device=self.config.device)
+            print("Loaded Model:", torch.cuda.memory_summary())
+
+        print(
+            "Gradient checkpointing enabled? ",
+            self.model.mention_proposer.doc_encoder.lm_encoder.is_gradient_checkpointing,
+        )
+        del checkpoint
+        torch.cuda.empty_cache()
+
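Taken together with save_model below, a checkpoint directory has a two-part layout along these lines (the actual file and directory names come from the paths config and are assumptions here):

# model_dir/
# ├── model.pth      # coref head weights, train_info, optimizer/scheduler states
# └── doc_encoder/   # finetuned encoder + tokenizer in Hugging Face format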
+    def save_model(self, location: os.PathLike, last_checkpoint=True) -> None:
+        """Save the model.
+
+        Args:
+            location: Location of the checkpoint
+            last_checkpoint:
+                Whether the checkpoint is the last one saved.
+                If False, don't save the optimizers and schedulers, which take up a lot of space.
+        """
+
+        model_state_dict = OrderedDict(self.model.state_dict())
+        doc_encoder_state_dict = {}
+
+        # Separate out the doc_encoder state dict
+        # The model is saved in two parts:
+        # (a) Doc encoder parameters - useful for the final upload to huggingface
+        # (b) Rest of the model parameters, optimizers, schedulers, and other bookkeeping variables
+        for key in self.model.state_dict():
+            if "lm_encoder." in key:
+                doc_encoder_state_dict[key] = model_state_dict[key]
+                del model_state_dict[key]
+
+        # Save the document encoder params
+        if self.config.model.doc_encoder.finetune:
+            doc_encoder_dir = path.join(
+                path.dirname(location), self.config.paths.doc_encoder_dirname
+            )
+            if not path.exists(doc_encoder_dir):
+                os.makedirs(doc_encoder_dir)
+
+            logger.info(f"Encoder saved at {path.abspath(doc_encoder_dir)}")
+            # Save the encoder
+            self.model.mention_proposer.doc_encoder.lm_encoder.save_pretrained(
+                save_directory=doc_encoder_dir, save_config=True
+            )
+            # Save the tokenizer
+            self.model.mention_proposer.doc_encoder.tokenizer.save_pretrained(
+                doc_encoder_dir
+            )
+
+        save_dict = {
+            "train_info": self.train_info,
+            "model": model_state_dict,
+            "rng_state": torch.get_rng_state(),
+            "np_rng_state": np.random.get_state(),
+            "config": self.config,
+        }
+
+        if self.scaler is not None:
+            save_dict["scaler"] = self.scaler.state_dict()
+
+        if last_checkpoint:
+            # For the last checkpoint, save the optimizer and scheduler states as well
+            save_dict["optimizer"] = {}
+            save_dict["scheduler"] = {}
+
+            param_groups: List[str] = (
+                ["mem", "doc"] if self.config.model.doc_encoder.finetune else ["mem"]
+            )
+            for param_group in param_groups:
+                save_dict["optimizer"][param_group] = self.optimizer[
+                    param_group
+                ].state_dict()
+                save_dict["scheduler"][param_group] = self.optim_scheduler[
+                    param_group
+                ].state_dict()
+
+        torch.save(save_dict, location)
+        logger.info(f"Model saved at: {path.abspath(location)}")