# a1c00l's picture
# Update src/aibom_generator/api.py
# d21f364 verified
# raw
# history blame
# 3.76 kB
"""
FastAPI server for the AIBOM Generator with minimal UI.
"""
import logging
import os
from typing import Dict, List, Optional, Any
from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from aibom_generator.generator import AIBOMGenerator
from aibom_generator.utils import setup_logging, calculate_completeness_score
# Apply the project's logging configuration (aibom_generator.utils.setup_logging),
# then obtain the logger used by every handler in this module.
setup_logging()
logger = logging.getLogger(name=__name__)

# FastAPI application; title/description/version feed the OpenAPI schema.
app = FastAPI(
    title="AIBOM Generator API",
    description=(
        "API for generating AI Bills of Materials (AIBOMs) in CycloneDX format "
        "for Hugging Face models."
    ),
    version="0.1.0",
)

# Fully permissive CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins together with allow_credentials=True lets
# Starlette reflect the caller's Origin on credentialed (cookie) requests —
# confirm against the installed version and restrict allow_origins if the API
# ever relies on cookies or other browser credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Jinja2 templates for the HTML UI. The relative path is resolved against the
# process working directory, not against this file's location.
templates = Jinja2Templates(directory="templates")

# One shared generator for all requests, configured from environment variables:
#   HF_TOKEN             - passed as hf_token (may be unset -> None)
#   AIBOM_INFERENCE_URL  - passed as inference_model_url (may be unset -> None)
#   AIBOM_USE_INFERENCE  - inference enabled only when this equals "true"
#                          case-insensitively; defaults to "true"
#   AIBOM_CACHE_DIR      - passed as cache_dir (may be unset -> None)
generator = AIBOMGenerator(
    hf_token=os.getenv("HF_TOKEN"),
    inference_model_url=os.getenv("AIBOM_INFERENCE_URL"),
    use_inference=os.getenv("AIBOM_USE_INFERENCE", "true").lower() == "true",
    cache_dir=os.getenv("AIBOM_CACHE_DIR"),
)
# Define request and response models
# Request body for POST /generate/json. Comments (not a docstring) are used so
# the generated OpenAPI schema is left untouched.
class GenerateRequest(BaseModel):
    # Model identifier handed to AIBOMGenerator.generate_aibom(model_id=...).
    model_id: str
    # Forwarded as generate_aibom(include_inference=...). None presumably defers
    # to the generator's configured use_inference setting — TODO confirm.
    include_inference: Optional[bool] = None
    # Minimum acceptable completeness score; the /generate/json handler rejects
    # AIBOMs scoring below it. Declared Optional, so an explicit JSON null passes
    # validation — handlers must not compare against it without a fallback.
    completeness_threshold: Optional[int] = 0
# Response body of POST /generate/json (declared as its response_model).
class GenerateResponse(BaseModel):
    # The AIBOM document returned by AIBOMGenerator.generate_aibom
    # (CycloneDX per the API description).
    aibom: Dict[str, Any]
    # Value computed by calculate_completeness_score(aibom).
    completeness_score: int
    # Echo of the model identifier from the request.
    model_id: str
# Status payload carrying a status string and the service version.
# NOTE(review): not referenced by any endpoint in this module; /health returns a
# bare dict with only "status". Either wire it up or remove it.
class StatusResponse(BaseModel):
    status: str
    version: str
# Web UI endpoint: renders the index.html template.
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    # TemplateResponse expects the current request inside the template context.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/generate", response_class=HTMLResponse)
async def generate_from_ui(request: Request, model_id: str = Form(...)):
    """Generate an AIBOM for a form-submitted model id and render it as HTML.

    On success renders ``result.html`` with the AIBOM, its completeness score
    and the model id. Any failure is logged with its traceback and rendered
    through ``error.html`` instead of propagating to the client.
    """
    # fastapi is already a dependency; importing locally keeps this edit
    # self-contained.
    from fastapi.concurrency import run_in_threadpool

    try:
        # generate_aibom is a synchronous call. Running it directly inside this
        # coroutine would block the event loop, and every other request, until
        # it finished, so it is offloaded to the threadpool.
        aibom = await run_in_threadpool(generator.generate_aibom, model_id=model_id)
        completeness_score = calculate_completeness_score(aibom)
        # Rendering stays inside the try so template errors also fall back to
        # the error page, as before.
        return templates.TemplateResponse(
            "result.html",
            {
                "request": request,
                "aibom": aibom,
                "completeness_score": completeness_score,
                "model_id": model_id,
            },
        )
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Error generating AIBOM: %s", e)
        return templates.TemplateResponse(
            "error.html",
            {"request": request, "error": str(e)},
        )
# JSON API endpoint
@app.post("/generate/json", response_model=GenerateResponse)
async def generate_aibom(request: GenerateRequest):
    """Generate an AIBOM for ``request.model_id`` and return it as JSON.

    Raises:
        HTTPException: 400 when the completeness score is below
            ``request.completeness_threshold`` (a null threshold counts as 0);
            500 when AIBOM generation or scoring fails.
    """
    # fastapi is already a dependency; importing locally keeps this edit
    # self-contained.
    from fastapi.concurrency import run_in_threadpool

    # completeness_threshold is Optional[int], so clients may send null. Treat
    # that as "no threshold" instead of failing on `score < None`.
    threshold = request.completeness_threshold or 0

    try:
        # generate_aibom is synchronous; offload it so the event loop stays free.
        aibom = await run_in_threadpool(
            generator.generate_aibom,
            model_id=request.model_id,
            include_inference=request.include_inference,
        )
        completeness_score = calculate_completeness_score(aibom)
    except Exception as e:
        logger.exception("Error generating AIBOM: %s", e)
        raise HTTPException(
            status_code=500,
            detail=f"Error generating AIBOM: {str(e)}",
        ) from e

    # Checked outside the try block: raised inside it, this 400 was caught by
    # the generic handler above and re-reported to the client as a 500.
    if completeness_score < threshold:
        raise HTTPException(
            status_code=400,
            detail=f"AIBOM completeness score ({completeness_score}) is below threshold ({threshold})",
        )

    return {
        "aibom": aibom,
        "completeness_score": completeness_score,
        "model_id": request.model_id,
    }
@app.get("/health")
async def health():
    # Liveness check: always answers with a constant payload and does not
    # touch the generator.
    return dict(status="healthy")
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    # The PORT environment variable overrides the default port 7860.
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)