a1c00l commited on
Commit
d21f364
·
verified ·
1 Parent(s): b1b6cd8

Update src/aibom_generator/api.py

Browse files
Files changed (1) hide show
  1. src/aibom_generator/api.py +36 -64
src/aibom_generator/api.py CHANGED
@@ -1,13 +1,15 @@
1
  """
2
- FastAPI server for the AIBOM Generator.
3
  """
4
 
5
  import logging
6
  import os
7
- from typing import Dict, List, Optional, Any, Union
8
 
9
- from fastapi import FastAPI, HTTPException, BackgroundTasks
10
  from fastapi.middleware.cors import CORSMiddleware
 
 
11
  from pydantic import BaseModel
12
 
13
  from aibom_generator.generator import AIBOMGenerator
@@ -33,6 +35,9 @@ app.add_middleware(
33
  allow_headers=["*"],
34
  )
35
 
 
 
 
36
  # Create generator instance
37
  generator = AIBOMGenerator(
38
  hf_token=os.environ.get("HF_TOKEN"),
@@ -60,36 +65,46 @@ class StatusResponse(BaseModel):
60
  version: str
61
 
62
 
63
- # Define API endpoints
64
- @app.get("/", response_model=StatusResponse)
65
- async def root():
66
- """Get API status."""
67
- return {
68
- "status": "ok",
69
- "version": "0.1.0",
70
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
- @app.post("/generate", response_model=GenerateResponse)
 
74
  async def generate_aibom(request: GenerateRequest):
75
- """Generate an AIBOM for a Hugging Face model."""
76
  try:
77
- # Generate the AIBOM
78
  aibom = generator.generate_aibom(
79
  model_id=request.model_id,
80
  include_inference=request.include_inference,
81
  )
82
-
83
- # Calculate completeness score
84
  completeness_score = calculate_completeness_score(aibom)
85
-
86
- # Check if it meets the threshold
87
  if completeness_score < request.completeness_threshold:
88
  raise HTTPException(
89
  status_code=400,
90
  detail=f"AIBOM completeness score ({completeness_score}) is below threshold ({request.completeness_threshold})",
91
  )
92
-
93
  return {
94
  "aibom": aibom,
95
  "completeness_score": completeness_score,
@@ -103,55 +118,12 @@ async def generate_aibom(request: GenerateRequest):
103
  )
104
 
105
 
106
- @app.post("/generate/async")
107
- async def generate_aibom_async(
108
- request: GenerateRequest,
109
- background_tasks: BackgroundTasks,
110
- ):
111
- """Generate an AIBOM asynchronously for a Hugging Face model."""
112
- # Add to background tasks
113
- background_tasks.add_task(
114
- _generate_aibom_background,
115
- request.model_id,
116
- request.include_inference,
117
- request.completeness_threshold,
118
- )
119
-
120
- return {
121
- "status": "accepted",
122
- "message": f"AIBOM generation for {request.model_id} started in the background",
123
- }
124
-
125
-
126
- async def _generate_aibom_background(
127
- model_id: str,
128
- include_inference: Optional[bool] = None,
129
- completeness_threshold: Optional[int] = 0,
130
- ):
131
- """Generate an AIBOM in the background."""
132
- try:
133
- # Generate the AIBOM
134
- aibom = generator.generate_aibom(
135
- model_id=model_id,
136
- include_inference=include_inference,
137
- )
138
-
139
- # Calculate completeness score
140
- completeness_score = calculate_completeness_score(aibom)
141
-
142
- # TODO: Store the result or notify the user
143
- logger.info(f"Background AIBOM generation completed for {model_id}")
144
- logger.info(f"Completeness score: {completeness_score}")
145
- except Exception as e:
146
- logger.error(f"Error in background AIBOM generation for {model_id}: {e}")
147
-
148
-
149
  @app.get("/health")
150
  async def health():
151
- """Health check endpoint."""
152
  return {"status": "healthy"}
153
 
154
 
155
  if __name__ == "__main__":
156
  import uvicorn
157
- uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 5000)))
 
 
1
  """
2
+ FastAPI server for the AIBOM Generator with minimal UI.
3
  """
4
 
5
  import logging
6
  import os
7
+ from typing import Dict, List, Optional, Any
8
 
9
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Form
10
  from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.responses import HTMLResponse, JSONResponse
12
+ from fastapi.templating import Jinja2Templates
13
  from pydantic import BaseModel
14
 
15
  from aibom_generator.generator import AIBOMGenerator
 
35
  allow_headers=["*"],
36
  )
37
 
38
+ # Initialize templates
39
+ templates = Jinja2Templates(directory="templates")
40
+
41
  # Create generator instance
42
  generator = AIBOMGenerator(
43
  hf_token=os.environ.get("HF_TOKEN"),
 
65
  version: str
66
 
67
 
68
+ # Web UI endpoint
69
+ @app.get("/", response_class=HTMLResponse)
70
+ async def home(request: Request):
71
+ return templates.TemplateResponse("index.html", {"request": request})
72
+
73
+
74
+ @app.post("/generate", response_class=HTMLResponse)
75
+ async def generate_from_ui(request: Request, model_id: str = Form(...)):
76
+ try:
77
+ aibom = generator.generate_aibom(model_id=model_id)
78
+ completeness_score = calculate_completeness_score(aibom)
79
+
80
+ return templates.TemplateResponse(
81
+ "result.html",
82
+ {"request": request, "aibom": aibom, "completeness_score": completeness_score, "model_id": model_id},
83
+ )
84
+ except Exception as e:
85
+ logger.error(f"Error generating AIBOM: {e}")
86
+ return templates.TemplateResponse(
87
+ "error.html",
88
+ {"request": request, "error": str(e)},
89
+ )
90
 
91
 
92
+ # Original JSON API endpoints (kept unchanged)
93
+ @app.post("/generate/json", response_model=GenerateResponse)
94
  async def generate_aibom(request: GenerateRequest):
 
95
  try:
 
96
  aibom = generator.generate_aibom(
97
  model_id=request.model_id,
98
  include_inference=request.include_inference,
99
  )
 
 
100
  completeness_score = calculate_completeness_score(aibom)
101
+
 
102
  if completeness_score < request.completeness_threshold:
103
  raise HTTPException(
104
  status_code=400,
105
  detail=f"AIBOM completeness score ({completeness_score}) is below threshold ({request.completeness_threshold})",
106
  )
107
+
108
  return {
109
  "aibom": aibom,
110
  "completeness_score": completeness_score,
 
118
  )
119
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  @app.get("/health")
122
  async def health():
 
123
  return {"status": "healthy"}
124
 
125
 
126
  if __name__ == "__main__":
127
  import uvicorn
128
+
129
+ uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))