# Author: Sudipta Nayak
# Last commit: 0499cdb ("file path changed")
# (exported from the repository file viewer: raw / history / blame, 2.94 kB)
import asyncio
import logging
import os
import subprocess
import sys
from pathlib import Path
from urllib.parse import quote

from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel

from config import settings
# The ASGI application. Its title and the OpenAPI schema location both come
# from the project settings; the schema is served under the versioned API prefix.
app = FastAPI(
    openapi_url=f"{settings.API_V1_STR}/openapi.json",
    title=settings.PROJECT_NAME,
)
# Serve uploaded images and detection results (app/static, including
# app/static/output) under /static. StaticFiles is the ASGI app for serving a
# directory. Mounting the bare FileResponse class is not a valid ASGI app and
# makes every /static request fail.
# StaticFiles raises at construction time if the directory is missing, so
# create it up front (uploads are written into it as well).
os.makedirs("app/static", exist_ok=True)
app.mount("/static", StaticFiles(directory="app/static"), name="static")
# HTML templates (index.html, result.html) live under app/templates.
templates = Jinja2Templates(directory="app/templates")
class Item(BaseModel):
    """Pydantic model holding a single uploaded file.

    NOTE(review): no endpoint in this file references this model; the
    /detect/ route takes ``UploadFile`` directly. Confirm before removing.
    """

    # The uploaded file (FastAPI's UploadFile wrapper).
    file: UploadFile
@app.get("/", response_class=HTMLResponse)
async def root(request: Request):
    """Render the landing page (``index.html``) containing the upload form.

    Args:
        request: The incoming request. Jinja2Templates requires it in the
            template context.

    Returns:
        The rendered ``index.html`` template response.
    """
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
def _safe_upload_name(filename):
    """Return only the final path component of a client-supplied file name.

    ``UploadFile.filename`` comes from the client. It may be ``None`` or
    empty, or contain directory parts such as ``../../`` (path traversal).
    Keeping only the last component guarantees the upload is written directly
    inside ``app/static``.

    Raises:
        HTTPException: 400 if no usable file name remains.
    """
    name = Path(filename or "").name
    if name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid upload file name")
    return name


@app.post("/detect/", response_class=HTMLResponse)
async def detect_objects(request: Request, file: UploadFile):
    """Run YOLOv7 detection on an uploaded image and render ``result.html``.

    The upload is saved to ``app/static/<name>``. ``app/detect.py`` then runs
    on it in a child process and writes its results under
    ``app/static/output``. The template receives URLs for the original and
    the annotated image.

    Args:
        request: The incoming request, needed by the template context.
        file: The uploaded image.

    Returns:
        The rendered ``result.html`` template response.

    Raises:
        HTTPException: 400 if the upload has no usable file name.
        subprocess.CalledProcessError: if the detection script exits non-zero.
    """
    file_name = _safe_upload_name(file.filename)
    try:
        logging.info("File name: %s", file_name)
        input_file_path = Path("app/static") / file_name
        output_file_path = Path("app/static/output")
        input_file_path.parent.mkdir(parents=True, exist_ok=True)

        # Save the uploaded file.
        with open(input_file_path, "wb") as input_file:
            input_file.write(await file.read())

        # Run YOLOv7 detection and save the output. sys.executable runs the
        # same interpreter (and installed packages) as this server, unlike
        # whatever "python" happens to resolve to on PATH.
        command = [
            sys.executable, "app/detect.py",
            "--conf", "0.5",
            "--img-size", "640",
            "--weights", "app/model/best.pt",
            "--no-trace",
            "--source", str(input_file_path),
            "--save-txt", "--save-conf", "--exist-ok",
            "--project", str(output_file_path),
        ]
        logging.info("Detect start")
        # subprocess.run blocks. Running it in the default executor keeps the
        # event loop free to serve other requests during detection.
        # check=True surfaces a failed detection instead of rendering a page
        # whose output image does not exist.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, lambda: subprocess.run(command, check=True)
        )
        logging.info("Detect end")

        # These URLs are relative to /detect/, so "../static/..." resolves to
        # /static/.... quote() keeps names containing spaces, '#', '?' or '%'
        # valid as URLs.
        # NOTE(review): an earlier version read results from
        # <project>/exp/<name>. Confirm that app/detect.py writes directly
        # into app/static/output, as this URL assumes.
        url_name = quote(file_name)
        original_image_path = f"../static/{url_name}"
        output_image_path = f"../static/output/{url_name}"
        logging.info("original_image path : %s", original_image_path)
        logging.info("output_image_path path : %s", output_image_path)

        # Render HTML using Jinja2Templates.
        return templates.TemplateResponse(
            "result.html",
            {
                "request": request,
                "original_image": original_image_path,
                "output_image": output_image_path,
            },
        )
    except Exception:
        # Log with the full traceback, then re-raise unchanged so FastAPI
        # produces its normal error response.
        logging.exception("Error in /detect endpoint")
        raise
# CORS is currently disabled. Uncomment the block below to allow the origins listed in settings.BACKEND_CORS_ORIGINS.
# if settings.BACKEND_CORS_ORIGINS:
# app.add_middleware(
# CORSMiddleware,
# allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
# Development entry point: running this module directly serves the app with
# uvicorn on all network interfaces, port 8001.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=8001, host="0.0.0.0")