MeetJivani commited on
Commit
5407b63
·
1 Parent(s): 663e530

Update summarize.py

Browse files
Files changed (1) hide show
  1. summarize.py +11 -3
summarize.py CHANGED
@@ -20,15 +20,23 @@ def load_model_and_tokenizer(model_name: str) -> tuple:
20
  :param str model_name: the model name/ID on the hub
21
  :return tuple: a tuple containing the model and tokenizer
22
  """
 
 
 
 
 
 
 
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
  model = AutoModelForSeq2SeqLM.from_pretrained(
25
- model_name,
26
  ).to(device)
27
  model = model.eval()
28
 
29
- tokenizer = AutoTokenizer.from_pretrained(model_name)
30
 
31
- logging.info(f"Loaded model {model_name} to {device}")
32
 
33
  if validate_pytorch2():
34
  try:
 
20
  :param str model_name: the model name/ID on the hub
21
  :return tuple: a tuple containing the model and tokenizer
22
  """
23
+ MODEL_OPTIONS = {
24
+ "Model 1": "pszemraj/long-t5-tglobal-base-16384-book-summary",
25
+ "Model 2": "pszemraj/long-t5-tglobal-base-sci-simplify",
26
+ "Model 3": "pszemraj/long-t5-tglobal-base-sci-simplify-elife",
27
+ "Model 4": "pszemraj/long-t5-tglobal-base-16384-booksci-summary-v1",
28
+ "Model 5": "pszemraj/pegasus-x-large-book-summary",
29
+ }
30
+ selected_model_identifier = MODEL_OPTIONS.get(user_selected_model_name)
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
  model = AutoModelForSeq2SeqLM.from_pretrained(
33
+ selected_model_identifier,
34
  ).to(device)
35
  model = model.eval()
36
 
37
+ tokenizer = AutoTokenizer.from_pretrained(selected_model_identifier)
38
 
39
+ logging.info(f"Loaded model {selected_model_identifier} to {device}")
40
 
41
  if validate_pytorch2():
42
  try: