|
import os |
|
from django.core.asgi import get_asgi_application |
|
from fastapi import FastAPI, Request, Depends, HTTPException, status |
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.templating import Jinja2Templates |
|
from starlette.middleware.cors import CORSMiddleware |
|
from starlette.middleware.base import BaseHTTPMiddleware |
|
from starlette.responses import JSONResponse, RedirectResponse |
|
import gradio as gr |
|
import secrets |
|
|
|
from mysite.routers.gradio import setup_gradio_interfaces |
|
from mysite.routers.fastapi import setup_webhook_routes, include_routers |
|
from mysite.routers.database import setup_database_routes |
|
from mysite.config.asgi_config import init_django_app |
|
from interpreter import interpreter |
|
import mysite.interpreter.interpreter_config |
|
|
|
from mysite.logger import logger |
|
|
|
# Shared HTTP Basic scheme: used as a FastAPI dependency by ``authenticate``
# and invoked directly by ``BasicAuthMiddleware``.
security = HTTPBasic()


def _load_basic_auth_users(defaults):
    """Return the Basic-auth user table (username -> plaintext password).

    If the ``BASIC_AUTH_USERS`` environment variable is set to a non-empty
    value of the form ``user:pass[,user:pass...]``, it replaces *defaults*
    entirely. Otherwise *defaults* is returned unchanged.

    Whitespace around each entry is stripped. Only the first ``:`` in an
    entry separates the username from the password, so passwords may contain
    ``:`` but not ``,``. Entries with no ``:`` or an empty username are
    ignored.
    """
    raw = os.environ.get("BASIC_AUTH_USERS", "").strip()
    if not raw:
        return defaults
    parsed = {}
    for entry in raw.split(","):
        name, sep, password = entry.strip().partition(":")
        if name and sep:
            parsed[name] = password
    return parsed


# NOTE(review): these fallback credentials are committed to source control.
# Set BASIC_AUTH_USERS in the environment for any real deployment.
users = _load_basic_auth_users({
    "username1": "password1",
    "username2": "password2",
})
|
|
|
def authenticate(credentials: HTTPBasicCredentials = Depends(security)) -> str:
    """Validate HTTP Basic credentials against the module-level ``users`` table.

    Works as a FastAPI dependency and can also be called directly with
    already-parsed credentials (``BasicAuthMiddleware`` does this).

    Args:
        credentials: Username/password parsed from the ``Authorization`` header.

    Returns:
        The authenticated username.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            username is unknown or the password does not match.
    """
    expected_password = users.get(credentials.username)
    # Always run the constant-time comparison, even for unknown usernames, so
    # response timing does not reveal which usernames exist. Compare bytes,
    # because secrets.compare_digest raises TypeError on non-ASCII str input.
    supplied = credentials.password.encode("utf-8")
    expected = (expected_password or "").encode("utf-8")
    password_matches = secrets.compare_digest(supplied, expected)
    if expected_password is None or not password_matches:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username
|
|
|
# Point Django at the project settings unless the environment already chose a
# settings module. This must run before get_asgi_application() below.
# NOTE(review): the mysite.* imports at the top of this file execute before
# this line. If any of them touch Django settings or models, they need
# DJANGO_SETTINGS_MODULE already exported in the environment. Confirm.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "mysite.settings")

# Django's own ASGI callable.
application = get_asgi_application()

# FastAPI app that this module builds up and serves (see the uvicorn.run call
# at the bottom of the file).
app = FastAPI()

# NOTE(review): init_django_app comes from mysite.config.asgi_config. It
# presumably mounts or forwards to the Django ``application`` inside ``app``.
# Confirm.
init_django_app(app, application)
|
|
|
|
|
# Cross-origin policy: every origin, method and header is accepted, and
# credentialed (cookie/Authorization) requests are allowed.
# NOTE(review): with allow_credentials=True plus a wildcard origin, Starlette's
# CORSMiddleware echoes the caller's Origin back instead of "*". In effect,
# any website can make credentialed requests. Confirm this is intended;
# otherwise list the trusted origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
class BasicAuthMiddleware(BaseHTTPMiddleware):
    """Require HTTP Basic credentials for every request whose path starts with ``/gradio``.

    All other requests pass through untouched. Credentials are parsed with the
    module-level ``security`` scheme and checked by ``authenticate``.
    """

    async def dispatch(self, request, call_next):
        if request.url.path.startswith("/gradio"):
            try:
                # HTTPBasic.__call__ is a coroutine function, so it must be
                # awaited. It raises HTTPException (401) itself when the
                # Authorization header is missing or malformed, so it must
                # also sit inside the try.
                credentials = await security(request)
                authenticate(credentials)
            except HTTPException as exc:
                # Build the 401 response here: exceptions raised inside
                # BaseHTTPMiddleware are not routed through the app's
                # HTTPException handler. Forward exc.headers
                # (WWW-Authenticate) so browsers show a login prompt.
                return JSONResponse(
                    status_code=exc.status_code,
                    content={"detail": exc.detail},
                    headers=exc.headers,
                )
        return await call_next(request)
|
|
|
# Protect the Gradio UI with HTTP Basic auth.
# NOTE(review): this is registered after CORSMiddleware. Under Starlette's
# ordering the later-added middleware wraps the earlier one, so CORS
# preflight (OPTIONS) requests to /gradio may be rejected with 401 before
# CORS runs. Confirm.
app.add_middleware(BasicAuthMiddleware)

# Gradio interface(s) from mysite.routers.gradio; mounted at /gradio below.
gradio_interfaces = setup_gradio_interfaces()

# Register API, webhook and database routes from the mysite.routers package.
include_routers(app)
setup_webhook_routes(app)

setup_database_routes(app)

# Serve the ./static directory at /static. With html=True, a directory
# request serves its index.html.
app.mount("/static", StaticFiles(directory="static", html=True), name="static")

# Mount the Gradio app under /gradio and keep the returned app as ``app``.
app = gr.mount_gradio_app(app, gradio_interfaces, "/gradio")

# Jinja2 templates are loaded from the same "static" directory.
# NOTE(review): that directory is also served raw via the /static mount
# above, so template sources are publicly downloadable. Confirm this is
# acceptable.
templates = Jinja2Templates(directory="static")
|
|
|
@app.get("/")
def read_root():
    """Send visitors at the site root to the Gradio UI mounted at ``/gradio``."""
    target = "/gradio"
    return RedirectResponse(url=target)
|
|
|
@app.get("/test")
def get_some_page(request: Request):
    """Render ``index.html`` from the template directory.

    The request is passed in the context as Jinja2Templates requires.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
if __name__ == "__main__":
    # Local entry point: serve the combined app with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|
|