Nol00 commited on
Commit
c817ff0
·
verified ·
1 Parent(s): ef03bcf

Update script/serve_model.py

Browse files
Files changed (1) hide show
  1. script/serve_model.py +12 -1
script/serve_model.py CHANGED
@@ -301,6 +301,17 @@ if __name__ == "__main__":
301
  crs_model = CRSModel(crs_model=args.crs_model, **model_args)
302
  logger.info(f"Loaded {args.crs_model} model.")
303
 
 
 
 
 
 
 
 
 
 
304
  # Start CRS Flask server
305
- crs_server = CRSFlaskServer(crs_model, args.kg_dataset)
 
 
306
  crs_server.start(args.host, args.port)
 
301
  crs_model = CRSModel(crs_model=args.crs_model, **model_args)
302
  logger.info(f"Loaded {args.crs_model} model.")
303
 
304
+ # Generation arguments
305
+ response_generation_args = {}
306
+ if args.crs_model == "unicrs":
307
+ response_generation_args = {
308
+ "movie_token": (
309
+ "<movie>" if args.kg_dataset.startswith("redial") else "<mask>"
310
+ ),
311
+ }
312
+
313
  # Start CRS Flask server
314
+ crs_server = CRSFlaskServer(
315
+ crs_model, args.kg_dataset, response_generation_args
316
+ )
317
  crs_server.start(args.host, args.port)