pszemraj commited on
Commit
8312087
·
1 Parent(s): 04190ea

⚡️ improve performance, enable longer text

Browse files

Signed-off-by: peter szemraj <[email protected]>

Files changed (3) hide show
  1. app.py +3 -3
  2. summarize.py +11 -0
  3. utils.py +12 -3
app.py CHANGED
@@ -73,11 +73,11 @@ def predict(
73
  batch_length=token_batch_length,
74
  **settings,
75
  )
76
-
77
  del model
78
  del tokenizer
79
  gc.collect()
80
-
81
  return summaries
82
 
83
 
@@ -89,7 +89,7 @@ def proc_submission(
89
  length_penalty: float,
90
  repetition_penalty: float,
91
  no_repeat_ngram_size: int,
92
- max_input_length: int = 2048,
93
  ):
94
  """
95
  proc_submission - a helper function for the gradio module to process submissions
 
73
  batch_length=token_batch_length,
74
  **settings,
75
  )
76
+
77
  del model
78
  del tokenizer
79
  gc.collect()
80
+
81
  return summaries
82
 
83
 
 
89
  length_penalty: float,
90
  repetition_penalty: float,
91
  no_repeat_ngram_size: int,
92
+ max_input_length: int = 4096,
93
  ):
94
  """
95
  proc_submission - a helper function for the gradio module to process submissions
summarize.py CHANGED
@@ -6,6 +6,8 @@ import torch
6
  from tqdm.auto import tqdm
7
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
8
 
 
 
9
 
10
  def load_model_and_tokenizer(model_name: str) -> tuple:
11
  """
@@ -24,6 +26,15 @@ def load_model_and_tokenizer(model_name: str) -> tuple:
24
 
25
  logging.info(f"Loaded model {model_name} to {device}")
26
 
 
 
 
 
 
 
 
 
 
27
  return model, tokenizer
28
 
29
 
 
6
  from tqdm.auto import tqdm
7
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
8
 
9
+ from utils import validate_pytorch2
10
+
11
 
12
  def load_model_and_tokenizer(model_name: str) -> tuple:
13
  """
 
26
 
27
  logging.info(f"Loaded model {model_name} to {device}")
28
 
29
+ if validate_pytorch2():
30
+ try:
31
+ logging.info("Compiling model with Torch 2.0")
32
+ model = torch.compile(model)
33
+ except Exception as e:
34
+ logging.warning(f"Could not compile model with Torch 2.0: {e}")
35
+ else:
36
+ logging.info("Torch 2.0 not detected, skipping compilation")
37
+
38
  return model, tokenizer
39
 
40
 
utils.py CHANGED
@@ -3,10 +3,20 @@
3
  """
4
 
5
  import re
6
- from pathlib import Path
7
  from datetime import datetime
 
 
 
8
  from natsort import natsorted
9
- import subprocess
 
 
 
 
 
 
 
10
 
11
 
12
  def get_timestamp() -> str:
@@ -114,7 +124,6 @@ def saves_summary(summarize_output, outpath: str or Path = None, add_signature=T
114
  outpath,
115
  "a",
116
  ) as fo:
117
-
118
  fo.write("\n" * 3)
119
  fo.write(f"\n\nSection Scores:\n")
120
  fo.writelines(scores_text)
 
3
  """
4
 
5
  import re
6
+ import subprocess
7
  from datetime import datetime
8
+ from pathlib import Path
9
+
10
+ import torch
11
  from natsort import natsorted
12
+
13
+
14
+ def validate_pytorch2(torch_version: str = None):
15
+ torch_version = torch.__version__ if torch_version is None else torch_version
16
+
17
+ pattern = r"^2\.\d+(\.\d+)*"
18
+
19
+ return True if re.match(pattern, torch_version) else False
20
 
21
 
22
  def get_timestamp() -> str:
 
124
  outpath,
125
  "a",
126
  ) as fo:
 
127
  fo.write("\n" * 3)
128
  fo.write(f"\n\nSection Scores:\n")
129
  fo.writelines(scores_text)