a1c00l commited on
Commit
c5073ea
·
verified ·
1 Parent(s): 8acee36

Update src/aibom_generator/api.py

Browse files
Files changed (1) hide show
  1. src/aibom_generator/api.py +290 -78
src/aibom_generator/api.py CHANGED
@@ -1,109 +1,321 @@
1
- import logging
2
  import os
 
 
3
  from typing import Dict, List, Optional, Any, Union
 
 
4
 
5
- from fastapi import FastAPI, HTTPException, BackgroundTasks
6
- from fastapi.middleware.cors import CORSMiddleware
 
 
 
7
 
 
 
 
8
 
9
- from pydantic import BaseModel
 
10
 
11
- from aibom_generator.generator import AIBOMGenerator
12
- allow_headers=["*"],
13
- )
14
 
 
 
15
 
16
- # Create generator instance
17
- generator = AIBOMGenerator(
18
- hf_token=os.environ.get("HF_TOKEN"),
19
- version: str
20
 
 
 
21
 
22
- # Define API endpoints
23
- @app.get("/", response_model=StatusResponse)
24
- async def root():
25
- """Get API status."""
26
- return {
27
- "status": "ok",
28
- "version": "1.0.0",
29
- }
 
 
 
 
 
30
 
31
- @app.post("/generate", response_model=GenerateResponse)
 
 
 
32
 
33
- async def generate_aibom(request: GenerateRequest):
34
- """Generate an AI SBOM for a Hugging Face model."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  try:
36
- # Generate the AIBOM
37
- aibom = generator.generate_aibom(
38
- model_id=request.model_id,
39
- include_inference=request.include_inference,
40
- )
41
-
42
- # Calculate completeness score
43
- completeness_score = calculate_completeness_score(aibom)
44
 
45
- # Check if it meets the threshold
46
- if completeness_score < request.completeness_threshold:
47
- raise HTTPException(
48
- status_code=400,
49
- detail=f"AI SBOM completeness score ({completeness_score}) is below threshold ({request.completeness_threshold})",
50
- )
51
-
52
- return {
53
- "aibom": aibom,
54
- "completeness_score": completeness_score,
55
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
 
 
 
 
 
57
 
58
- @app.post("/generate/async")
59
- async def generate_aibom_async(
60
- request: GenerateRequest,
61
- background_tasks: BackgroundTasks,
62
- ):
63
- """Generate an AI SBOM asynchronously for a Hugging Face model."""
64
- # Add to background tasks
65
- background_tasks.add_task(
66
- _generate_aibom_background,
67
- request.model_id,
68
- request.include_inference,
69
- request.completeness_threshold,
70
- )
71
-
72
  return {
73
- "status": "accepted",
74
- "message": f"AI SBOM generation for {request.model_id} started in the background",
 
75
  }
76
 
77
-
78
- async def _generate_aibom_background(
79
- model_id: str,
80
- include_inference: Optional[bool] = None,
81
- completeness_threshold: Optional[int] = 0,
 
 
82
  ):
83
- """Generate an AI SBOM in the background."""
84
  try:
85
- # Generate the AIBOM
86
- aibom = generator.generate_aibom(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  model_id=model_id,
88
  include_inference=include_inference,
 
89
  )
90
 
91
- # Calculate completeness score
92
- completeness_score = calculate_completeness_score(aibom)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- # TODO: Store the result or notify the user
95
- logger.info(f"Background AI SBOM generation completed for {model_id}")
96
- logger.info(f"Completeness score: {completeness_score}")
 
 
 
 
 
 
 
 
 
 
97
  except Exception as e:
98
- logger.error(f"Error in background AI SBOM generation for {model_id}: {e}")
 
 
 
 
 
 
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- @app.get("/health")
102
- async def health():
103
- """Health check endpoint."""
104
- return {"status": "healthy"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- if __name__ == "__main__":
108
- import uvicorn
109
- uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 5000)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import json
3
+ import logging
4
  from typing import Dict, List, Optional, Any, Union
5
+ from datetime import datetime
6
+ from pathlib import Path
7
 
8
+ from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Query, File, UploadFile, Form, Request
9
+ from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
10
+ from fastapi.staticfiles import StaticFiles
11
+ from fastapi.templating import Jinja2Templates
12
+ from pydantic import BaseModel, Field
13
 
14
+ # Configure logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
 
18
+ # Define templates directory
19
+ templates_dir = "templates"
20
 
21
+ # Initialize templates with a simple path
22
+ templates = Jinja2Templates(directory=templates_dir)
 
23
 
24
+ # Create app
25
+ app = FastAPI(title="AI SBOM Generator API")
26
 
27
+ # Define output directory for generated AIBOMs
28
+ OUTPUT_DIR = "/tmp/aibom_output"
 
 
29
 
30
+ # Create output directory if it doesn't exist
31
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
32
 
33
+ # Mount output directory as static files
34
+ app.mount("/output", StaticFiles(directory=OUTPUT_DIR), name="output")
35
+
36
+ # Define models
37
+ class GenerateRequest(BaseModel):
38
+ model_id: str
39
+ include_inference: bool = False
40
+ use_best_practices: bool = True
41
+
42
+ class GenerateWithReportRequest(BaseModel):
43
+ model_id: str
44
+ include_inference: bool = False
45
+ use_best_practices: bool = True
46
 
47
+ class BatchGenerateRequest(BaseModel):
48
+ model_ids: List[str]
49
+ include_inference: bool = False
50
+ use_best_practices: bool = True
51
 
52
+ class ModelScoreRequest(BaseModel):
53
+ model_id: str
54
+
55
+ class StatusResponse(BaseModel):
56
+ status: str
57
+ version: str
58
+ generator_version: str
59
+
60
+ class BatchJobResponse(BaseModel):
61
+ job_id: str
62
+ status: str
63
+ model_ids: List[str]
64
+ message: str
65
+
66
+ # Startup event to ensure directories exist
67
+ @app.on_event("startup")
68
+ async def startup_event():
69
+ """Create necessary directories on startup."""
70
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
71
+ logger.info(f"Created output directory at {OUTPUT_DIR}")
72
+
73
+ # Root route serves the UI
74
+ @app.get("/", response_class=HTMLResponse)
75
+ async def root(request: Request):
76
+ """Serve the web UI interface as the default view."""
77
  try:
78
+ # Check if templates directory exists
79
+ if os.path.exists(templates_dir):
80
+ logger.info(f"Templates directory exists")
81
+ logger.info(f"Templates directory contents: {os.listdir(templates_dir)}")
82
+ else:
83
+ logger.warning(f"Templates directory does not exist: {templates_dir}")
 
 
84
 
85
+ # Try to render the template
86
+ try:
87
+ return templates.TemplateResponse("index.html", {"request": request})
88
+ except Exception as e:
89
+ logger.error(f"Error rendering template: {str(e)}")
90
+
91
+ # Try direct file reading as fallback
92
+ try:
93
+ with open(os.path.join(templates_dir, "index.html"), "r") as f:
94
+ html_content = f.read()
95
+ return HTMLResponse(content=html_content)
96
+ except Exception as file_e:
97
+ logger.error(f"Error reading file directly: {str(file_e)}")
98
+
99
+ # Last resort: try absolute path
100
+ try:
101
+ with open("/app/templates/index.html", "r") as f:
102
+ html_content = f.read()
103
+ return HTMLResponse(content=html_content)
104
+ except Exception as abs_e:
105
+ logger.error(f"Error reading file from absolute path: {str(abs_e)}")
106
+ raise HTTPException(status_code=500, detail=f"Error in UI endpoint: {str(e)}")
107
+ except Exception as outer_e:
108
+ logger.error(f"Outer error in UI endpoint: {str(outer_e)}")
109
+ raise HTTPException(status_code=500, detail=f"Error in UI endpoint: {str(outer_e)}")
110
 
111
+ # UI route for backward compatibility
112
+ @app.get("/ui", response_class=HTMLResponse)
113
+ async def ui(request: Request):
114
+ """Serve the web UI interface (kept for backward compatibility)."""
115
+ return await root(request)
116
 
117
+ # Status endpoint
118
+ @app.get("/status", response_model=StatusResponse)
119
+ async def get_status():
120
+ """Get the API status and version information."""
 
 
 
 
 
 
 
 
 
 
121
  return {
122
+ "status": "operational",
123
+ "version": "1.0.0",
124
+ "generator_version": "1.0.0",
125
  }
126
 
127
+ # Form-based generate endpoint for web UI
128
+ @app.post("/generate", response_class=HTMLResponse)
129
+ async def generate_form(
130
+ request: Request,
131
+ model_id: str = Form(...),
132
+ include_inference: bool = Form(False),
133
+ use_best_practices: bool = Form(True)
134
  ):
135
+ """Generate an AI SBOM from form data and render the result template."""
136
  try:
137
+ # Import the generator here to avoid circular imports
138
+ try:
139
+ from src.aibom_generator.generator import AIBOMGenerator
140
+ except ImportError:
141
+ try:
142
+ from aibom_generator.generator import AIBOMGenerator
143
+ except ImportError:
144
+ try:
145
+ from generator import AIBOMGenerator
146
+ except ImportError:
147
+ raise ImportError("Could not import AIBOMGenerator. Please check your installation.")
148
+
149
+ # Create generator instance
150
+ generator = AIBOMGenerator()
151
+
152
+ # Generate AIBOM
153
+ aibom, enhancement_report = generator.generate(
154
  model_id=model_id,
155
  include_inference=include_inference,
156
+ use_best_practices=use_best_practices
157
  )
158
 
159
+ # Save AIBOM to file
160
+ filename = f"{model_id.replace('/', '_')}_aibom.json"
161
+ filepath = os.path.join(OUTPUT_DIR, filename)
162
+ with open(filepath, "w") as f:
163
+ json.dump(aibom, f, indent=2)
164
+
165
+ # Create download URL
166
+ download_url = f"/output/{filename}"
167
+
168
+ # Create download script
169
+ download_script = """
170
+ <script>
171
+ function downloadJSON() {
172
+ const a = document.createElement('a');
173
+ a.href = '%s';
174
+ a.download = '%s';
175
+ document.body.appendChild(a);
176
+ a.click();
177
+ document.body.removeChild(a);
178
+ }
179
+ </script>
180
+ """ % (download_url, filename)
181
+
182
+ # Get completeness score
183
+ completeness_score = None
184
+ if hasattr(generator, 'get_completeness_score'):
185
+ try:
186
+ completeness_score = generator.get_completeness_score(model_id)
187
+ except Exception as e:
188
+ logger.error(f"Error getting completeness score: {str(e)}")
189
 
190
+ # Render result template
191
+ return templates.TemplateResponse(
192
+ "result.html",
193
+ {
194
+ "request": request,
195
+ "model_id": model_id,
196
+ "aibom": aibom,
197
+ "enhancement_report": enhancement_report,
198
+ "completeness_score": completeness_score,
199
+ "download_url": download_url,
200
+ "download_script": download_script
201
+ }
202
+ )
203
  except Exception as e:
204
+ logger.error(f"Error generating AI SBOM: {str(e)}")
205
+ return templates.TemplateResponse(
206
+ "error.html",
207
+ {
208
+ "request": request,
209
+ "error_message": f"Error generating AI SBOM: {str(e)}"
210
+ }
211
+ )
212
 
213
+ # JSON API endpoints
214
+ @app.post("/api/generate")
215
+ async def generate_api(request: GenerateRequest):
216
+ """Generate an AI SBOM and return it as JSON."""
217
+ try:
218
+ # Import the generator here to avoid circular imports
219
+ try:
220
+ from src.aibom_generator.generator import AIBOMGenerator
221
+ except ImportError:
222
+ try:
223
+ from aibom_generator.generator import AIBOMGenerator
224
+ except ImportError:
225
+ try:
226
+ from generator import AIBOMGenerator
227
+ except ImportError:
228
+ raise ImportError("Could not import AIBOMGenerator. Please check your installation.")
229
+
230
+ # Create generator instance
231
+ generator = AIBOMGenerator()
232
+
233
+ # Generate AIBOM
234
+ aibom, _ = generator.generate(
235
+ model_id=request.model_id,
236
+ include_inference=request.include_inference,
237
+ use_best_practices=request.use_best_practices
238
+ )
239
+
240
+ return aibom
241
+ except Exception as e:
242
+ raise HTTPException(status_code=500, detail=f"Error generating AI SBOM: {str(e)}")
243
 
244
+ @app.post("/api/generate-with-report")
245
+ async def generate_with_report_api(request: GenerateWithReportRequest):
246
+ """Generate an AI SBOM with enhancement report and return both as JSON."""
247
+ try:
248
+ # Import the generator here to avoid circular imports
249
+ try:
250
+ from src.aibom_generator.generator import AIBOMGenerator
251
+ except ImportError:
252
+ try:
253
+ from aibom_generator.generator import AIBOMGenerator
254
+ except ImportError:
255
+ try:
256
+ from generator import AIBOMGenerator
257
+ except ImportError:
258
+ raise ImportError("Could not import AIBOMGenerator. Please check your installation.")
259
+
260
+ # Create generator instance
261
+ generator = AIBOMGenerator()
262
+
263
+ # Generate AIBOM
264
+ aibom, enhancement_report = generator.generate(
265
+ model_id=request.model_id,
266
+ include_inference=request.include_inference,
267
+ use_best_practices=request.use_best_practices
268
+ )
269
+
270
+ return {
271
+ "aibom": aibom,
272
+ "enhancement_report": enhancement_report
273
+ }
274
+ except Exception as e:
275
+ raise HTTPException(status_code=500, detail=f"Error generating AI SBOM: {str(e)}")
276
 
277
+ @app.get("/api/model-score/{model_id}")
278
+ async def model_score_api(model_id: str):
279
+ """Get the completeness score for a model."""
280
+ try:
281
+ # Import the generator here to avoid circular imports
282
+ try:
283
+ from src.aibom_generator.generator import AIBOMGenerator
284
+ except ImportError:
285
+ try:
286
+ from aibom_generator.generator import AIBOMGenerator
287
+ except ImportError:
288
+ try:
289
+ from generator import AIBOMGenerator
290
+ except ImportError:
291
+ raise ImportError("Could not import AIBOMGenerator. Please check your installation.")
292
+
293
+ # Create generator instance
294
+ generator = AIBOMGenerator()
295
+
296
+ # Get completeness score
297
+ if hasattr(generator, 'get_completeness_score'):
298
+ completeness_score = generator.get_completeness_score(model_id)
299
+ return completeness_score
300
+ else:
301
+ raise HTTPException(status_code=501, detail="Completeness score calculation not implemented")
302
+ except Exception as e:
303
+ raise HTTPException(status_code=500, detail=f"Error getting model score: {str(e)}")
304
 
305
+ @app.get("/download/{filename}")
306
+ async def download_file(filename: str):
307
+ """Download a generated AI SBOM file."""
308
+ try:
309
+ filepath = os.path.join(OUTPUT_DIR, filename)
310
+ if not os.path.exists(filepath):
311
+ raise HTTPException(status_code=404, detail=f"File {filename} not found")
312
+
313
+ return FileResponse(
314
+ filepath,
315
+ media_type="application/json",
316
+ filename=filename
317
+ )
318
+ except HTTPException:
319
+ raise
320
+ except Exception as e:
321
+ raise HTTPException(status_code=500, detail=f"Error downloading file: {str(e)}")