fuhsiao commited on
Commit
89ea49b
·
1 Parent(s): dfcc660
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -1,26 +1,22 @@
1
  from utils import *
2
  import gradio as gr
3
 
4
-
5
- from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
6
 
7
  def download_model():
8
  # 下載並快取SentenceTransformer所需的模型和tokenizer
9
- sentence_transformer_model = "sentence-transformers/all-MiniLM-L6-v2"
10
- PreTrainedModel.from_pretrained(sentence_transformer_model)
11
- PreTrainedTokenizer.from_pretrained(sentence_transformer_model)
12
 
13
  # 下載並快取AutoTokenizer所需的模型
14
  biobart_model = "fuhsiao/BioBART-PMC-EXT-Section"
15
- PreTrainedModel.from_pretrained(biobart_model)
16
- AutoTokenizer.from_pretrained(biobart_model)
17
  AutoModel.from_pretrained(biobart_model)
 
18
 
19
  # 下載並快取AutoModelForSeq2SeqLM所需的模型
20
  bart_model = "fuhsiao/BART-PMC-EXT-Section"
21
- PreTrainedModel.from_pretrained(bart_model)
22
- AutoTokenizer.from_pretrained(bart_model)
23
  AutoModelForSeq2SeqLM.from_pretrained(bart_model)
 
24
 
25
  return True
26
 
@@ -28,6 +24,7 @@ def download_model():
28
 
29
 
30
 
 
31
  def main(file, ext_threshold, article_type):
32
 
33
  if file is None or ext_threshold is None or article_type is None:
 
1
  from utils import *
2
  import gradio as gr
3
 
4
+ from sentence_transformers import SentenceTransformer
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
6
 
7
  def download_model():
8
  # 下載並快取SentenceTransformer所需的模型和tokenizer
9
+ SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 
10
 
11
  # 下載並快取AutoTokenizer所需的模型
12
  biobart_model = "fuhsiao/BioBART-PMC-EXT-Section"
 
 
13
  AutoModel.from_pretrained(biobart_model)
14
+ AutoTokenizer.from_pretrained(biobart_model)
15
 
16
  # 下載並快取AutoModelForSeq2SeqLM所需的模型
17
  bart_model = "fuhsiao/BART-PMC-EXT-Section"
 
 
18
  AutoModelForSeq2SeqLM.from_pretrained(bart_model)
19
+ AutoTokenizer.from_pretrained(bart_model)
20
 
21
  return True
22
 
 
24
 
25
 
26
 
27
+
28
  def main(file, ext_threshold, article_type):
29
 
30
  if file is None or ext_threshold is None or article_type is None: