Update main.py
Browse files
main.py
CHANGED
@@ -160,6 +160,24 @@ app = FastAPI()
|
|
160 |
|
161 |
# Load the FLAN-T5-small text2text pipeline once at module import time so the
# /infer_t5 endpoint reuses a single model instance across requests.
# NOTE(review): this downloads/loads the model at import — confirm startup cost
# is acceptable for the deployment target.
pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
|
162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
@app.get("/infer_t5")
|
164 |
def t5(input):
|
165 |
output = pipe_flan(input)
|
|
|
160 |
|
161 |
pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
|
162 |
|
163 |
+
def model_init():
    """Build and return the Word_BERT model in eval mode on the chosen device.

    Reads command-line options via ``set_args()``. Side effects: configures a
    logger and sets the ``CUDA_VISIBLE_DEVICES`` environment variable.

    Returns:
        Word_BERT: the model, moved to ``cuda`` or ``cpu`` and set to eval mode.
    """
    args = set_args()
    logger = create_logger(args)
    # Restrict visible GPUs BEFORE any CUDA query: once torch has initialized
    # CUDA, changing CUDA_VISIBLE_DEVICES has no effect. The original code set
    # it after torch.cuda.is_available(), which made it a no-op.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # Use the GPU only when one is available and the user did not opt out.
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    # Removed an unused BertTokenizerFast local: it read args.vocab_path and
    # was immediately discarded. Construct the tokenizer where it is consumed.
    model = Word_BERT()
    # NOTE(review): no checkpoint is loaded here — the original had
    # `model = model.load_state_dict(torch.load(args.model_path))` commented
    # out (which would also be wrong: load_state_dict returns key-matching
    # info, not the module). Confirm whether pretrained weights must be
    # restored via `model.load_state_dict(torch.load(args.model_path))`.
    model = model.to(device)
    model.eval()  # disable dropout/batch-norm updates for inference
    return model
|
178 |
+
|
179 |
+
# Initialize the model once at module import so every request handler shares
# the same instance. NOTE(review): this runs set_args()/logging at import —
# confirm that is intended when the module is imported by an ASGI server.
model = model_init()
|
180 |
+
|
181 |
@app.get("/infer_t5")
|
182 |
def t5(input):
|
183 |
output = pipe_flan(input)
|