Spaces:
Running
Running
""" | |
FastAPI server for the AIBOM Generator with minimal UI. | |
""" | |
import logging | |
import os | |
from typing import Dict, List, Optional, Any | |
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Form | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import HTMLResponse, JSONResponse | |
from fastapi.templating import Jinja2Templates | |
from pydantic import BaseModel | |
from aibom_generator.generator import AIBOMGenerator | |
from aibom_generator.utils import setup_logging, calculate_completeness_score | |
# Configure logging once at import time, then grab this module's logger.
setup_logging()
logger = logging.getLogger(__name__)
# Create the FastAPI application. The title, description and version are
# handed straight to the FastAPI constructor.
app = FastAPI(
    title="AIBOM Generator API",
    description=(
        "API for generating AI Bills of Materials (AIBOMs) "
        "in CycloneDX format for Hugging Face models."
    ),
    version="0.1.0",
)
# Add CORS middleware.
# Any origin, method and header is accepted. Credentials are deliberately
# disabled: a wildcard origin must not be combined with credentialed requests
# (FastAPI's CORS documentation requires explicit origins when
# allow_credentials=True). If cookies or Authorization headers are ever needed
# cross-origin, list the trusted origins explicitly and re-enable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize the Jinja2 template loader used by the HTML endpoints.
# NOTE(review): "templates" is a relative path, presumably resolved against
# the process working directory -- confirm the server is always launched from
# the directory that contains it.
templates = Jinja2Templates(directory="templates")
# Create the shared generator instance, configured from environment variables.
generator = AIBOMGenerator(
    hf_token=os.getenv("HF_TOKEN"),
    inference_model_url=os.getenv("AIBOM_INFERENCE_URL"),
    # Enabled unless AIBOM_USE_INFERENCE is set to something other than
    # "true" (compared case-insensitively).
    use_inference=os.getenv("AIBOM_USE_INFERENCE", "true").lower() == "true",
    cache_dir=os.getenv("AIBOM_CACHE_DIR"),
)
# Define request and response models
class GenerateRequest(BaseModel):
    # Request body accepted by the JSON generation handler.

    # Model identifier handed to the generator as `model_id`.
    model_id: str
    # Forwarded to the generator as `include_inference`; None when omitted.
    include_inference: Optional[bool] = None
    # Minimum completeness score the caller will accept; defaults to 0.
    completeness_threshold: Optional[int] = 0
class GenerateResponse(BaseModel):
    # Response body shape for the JSON generation handler.

    # The generated AIBOM document as a string-keyed mapping.
    aibom: Dict[str, Any]
    # Score computed for the generated AIBOM.
    completeness_score: int
    # The model identifier the AIBOM was generated for.
    model_id: str
class StatusResponse(BaseModel):
    # Response body shape for a service status report.

    # Short status string (the health handler reports "healthy").
    status: str
    # Version string of the service.
    version: str
# Web UI endpoints
async def home(request: Request):
    """Render the landing page from the ``index.html`` template."""
    # NOTE(review): no route decorator is visible for this handler in this
    # view of the file -- confirm it is registered on `app`.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
async def generate_from_ui(request: Request, model_id: str = Form(...)):
    """Generate an AIBOM for a model submitted through the web form.

    Args:
        request: Incoming request; placed in the template context.
        model_id: Value of the required ``model_id`` form field.

    Returns:
        The rendered ``result.html`` page on success, or the rendered
        ``error.html`` page carrying the error text when anything fails.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view of the file -- confirm it is registered on `app`.
    try:
        aibom = generator.generate_aibom(model_id=model_id)
        completeness_score = calculate_completeness_score(aibom)
        return templates.TemplateResponse(
            "result.html",
            {
                "request": request,
                "aibom": aibom,
                "completeness_score": completeness_score,
                "model_id": model_id,
            },
        )
    except Exception as e:
        # Top-level boundary for the UI: keep the traceback in the log
        # (logger.exception) and show the user a friendly error page instead
        # of a raw server error.
        logger.exception("Error generating AIBOM: %s", e)
        return templates.TemplateResponse(
            "error.html",
            {"request": request, "error": str(e)},
        )
# Original JSON API endpoints (kept unchanged)
async def generate_aibom(request: GenerateRequest):
    """Generate an AIBOM and return it as JSON-serialisable data.

    Args:
        request: Parsed request body with the model id, the optional
            inference flag and the optional completeness threshold.

    Returns:
        A dict with the keys ``aibom``, ``completeness_score`` and
        ``model_id``.

    Raises:
        HTTPException: 500 when generation or scoring fails; 400 when the
            completeness score is below the requested threshold.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view of the file -- confirm it is registered on `app`.
    try:
        aibom = generator.generate_aibom(
            model_id=request.model_id,
            include_inference=request.include_inference,
        )
        completeness_score = calculate_completeness_score(aibom)
    except Exception as e:
        # Only generation/scoring failures are mapped to 500; the traceback
        # is kept in the log.
        logger.exception("Error generating AIBOM: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Error generating AIBOM: {str(e)}",
        ) from e

    # The threshold check lives outside the try block so that its 400 is not
    # swallowed by the catch-all above and re-reported as a 500. The field is
    # Optional, so an explicit null is treated as "no threshold".
    threshold = request.completeness_threshold or 0
    if completeness_score < threshold:
        raise HTTPException(
            status_code=400,
            detail=f"AIBOM completeness score ({completeness_score}) is below threshold ({threshold})",
        )

    return {
        "aibom": aibom,
        "completeness_score": completeness_score,
        "model_id": request.model_id,
    }
async def health():
    """Report that the service is up, together with the API version.

    Returns:
        A dict with ``status`` set to ``"healthy"`` and ``version`` set to
        the version configured on the FastAPI app, matching the fields of
        ``StatusResponse``.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # view of the file -- confirm it is registered on `app`.
    return {"status": "healthy", "version": app.version}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces,
    # using the PORT environment variable when set (7860 otherwise).
    import uvicorn

    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)