File size: 1,957 Bytes
5953ef9 41eec42 5953ef9 41eec42 5953ef9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import traceback
from typing import Any, Dict, List
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
from rex.utils.initialization import set_seed_and_log_path
from rex.utils.logging import logger
from src.task import SchemaGuidedInstructBertTask
# Run rex's seed/log-path initialisation first, with log_path set to
# "debug.log", before the app and the model are created.
set_seed_and_log_path(log_path="debug.log")

# ASGI application served by uvicorn (see the __main__ launcher below).
app = FastAPI()

# CORS policy: any origin, any HTTP method and any request header may call
# this API (credentials are not enabled).
_cors_settings = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# Request body schema for ``POST /process``.
# NOTE: documented with comments rather than a docstring on purpose -- a
# pydantic class docstring becomes the schema "description" in /openapi.json.
class RequestData(BaseModel):
    # List of JSON objects (string-keyed dicts), handed unchanged to
    # task.predict() by process_data. Their expected keys are defined by the
    # task, not validated here.
    data: List[Dict[str, Any]]
# Load the task from its output directory once, at import time, so every
# request reuses the same instance. The flags are passed straight through
# to from_taskdir; judging by their names this loads the best checkpoint
# without re-initialising or re-dumping the config -- confirm against rex/src.
task = SchemaGuidedInstructBertTask.from_taskdir(
    "mirror_outputs/Mirror_Pretrain_AllExcluded_2",
    load_best_model=True,
    initialize=False,
    dump_configfile=False,
    # Config override: presumably reuse any existing data cache.
    update_config=dict(regenerate_cache=False),
)
@app.post("/process")
def process_data(data: RequestData):
    """Run ``task.predict`` on the posted records and wrap the outcome.

    Args:
        data: Request body; ``data.data`` is passed unchanged to
            ``task.predict``.

    Returns:
        A dict ``{"ok": bool, "msg": str, "results": ...}``. On success,
        ``ok`` is True, ``msg`` is ``"success"`` and ``results`` is whatever
        ``task.predict`` returned. On failure, ``ok`` is False, ``msg`` holds
        the formatted traceback and ``results`` is ``{}``.
    """
    input_data = data.data
    ok = True
    results = {}
    try:
        results = task.predict(input_data)
    except Exception:
        # KeyboardInterrupt / SystemExit derive from BaseException, not
        # Exception, so they are not caught here and propagate with their
        # original traceback intact.
        ok = False
        msg = traceback.format_exc()
        # Record the failure server-side too; previously it only reached
        # the client.
        logger.exception("Prediction failed")
        # NOTE(review): returning the full traceback to the HTTP client
        # exposes server internals; consider a generic message in production.
    else:
        msg = "success"
    logger.info(f"Data: {input_data}, Prediction: {results}")
    return {"ok": ok, "msg": msg, "results": results}
# Serve the HTML front-end page at the root URL. The path is relative, so
# it is resolved against the server process's current working directory.
# (Comments rather than a docstring: FastAPI would publish a docstring as
# the route's OpenAPI description.)
@app.get("/")
async def api():
    index_page = "./index.html"
    return FileResponse(index_page, media_type="text/html")
def timestamped_log_config(
    base: Dict[str, Any], prefix: str = "%(asctime)s | "
) -> Dict[str, Any]:
    """Return a copy of a uvicorn logging config with prefixed log formats.

    ``prefix`` is prepended to the ``fmt`` of the ``"access"`` and
    ``"default"`` formatters. ``base`` itself is never modified. It is
    normally ``uvicorn.config.LOGGING_CONFIG``, a module-level dict owned by
    uvicorn; editing it in place would leak the prefix to every other user
    of that dict and stack it if applied more than once.

    Raises:
        KeyError: if ``base`` lacks ``formatters.access.fmt`` or
            ``formatters.default.fmt``.
    """
    import copy

    config = copy.deepcopy(base)
    for formatter_name in ("access", "default"):
        formatter = config["formatters"][formatter_name]
        formatter["fmt"] = prefix + formatter["fmt"]
    return config


if __name__ == "__main__":
    # NOTE(review): with reload=True uvicorn imports the app by its dotted
    # path ("src.app.api_backend:app") in a separate worker process. The
    # module-level model load therefore likely runs twice: once in this
    # launcher and once in the worker. Consider reload=False when serving.
    uvicorn.run(
        "src.app.api_backend:app",
        host="0.0.0.0",
        port=7860,
        log_level="debug",
        log_config=timestamped_log_config(uvicorn.config.LOGGING_CONFIG),
        reload=True,
    )
|