Update script/serve_model.py
Browse files- script/serve_model.py +12 -1
script/serve_model.py
CHANGED
@@ -301,6 +301,17 @@ if __name__ == "__main__":
|
|
301 |
crs_model = CRSModel(crs_model=args.crs_model, **model_args)
|
302 |
logger.info(f"Loaded {args.crs_model} model.")
|
303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
# Start CRS Flask server
|
305 |
-
crs_server = CRSFlaskServer(
|
|
|
|
|
306 |
crs_server.start(args.host, args.port)
|
|
|
301 |
crs_model = CRSModel(crs_model=args.crs_model, **model_args)
|
302 |
logger.info(f"Loaded {args.crs_model} model.")
|
303 |
|
304 |
+
# Generation arguments
|
305 |
+
response_generation_args = {}
|
306 |
+
if args.crs_model == "unicrs":
|
307 |
+
response_generation_args = {
|
308 |
+
"movie_token": (
|
309 |
+
"<movie>" if args.kg_dataset.startswith("redial") else "<mask>"
|
310 |
+
),
|
311 |
+
}
|
312 |
+
|
313 |
# Start CRS Flask server
|
314 |
+
crs_server = CRSFlaskServer(
|
315 |
+
crs_model, args.kg_dataset, response_generation_args
|
316 |
+
)
|
317 |
crs_server.start(args.host, args.port)
|