Spaces:
Running
Running
Update src/aibom_generator/api.py
Browse files- src/aibom_generator/api.py +95 -23
src/aibom_generator/api.py
CHANGED
@@ -13,6 +13,10 @@ from datasets import Dataset, load_dataset, concatenate_datasets
|
|
13 |
import os
|
14 |
import logging
|
15 |
from urllib.parse import urlparse
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# Configure logging
|
18 |
logging.basicConfig(level=logging.INFO)
|
@@ -45,21 +49,51 @@ class StatusResponse(BaseModel):
|
|
45 |
version: str
|
46 |
generator_version: str
|
47 |
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def _normalise_model_id(raw_id: str) -> str:
|
50 |
"""
|
51 |
-
Accept either 'owner/model' or a full URL like
|
52 |
'https://huggingface.co/owner/model'. Return 'owner/model'.
|
|
|
53 |
"""
|
54 |
if raw_id.startswith(("http://", "https://") ):
|
55 |
path = urlparse(raw_id).path.lstrip("/")
|
56 |
-
# path can contain extra segments (e.g. /commit/...), keep first two
|
57 |
parts = path.split("/")
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
# --- End Model ID
|
|
|
63 |
|
64 |
# --- Add Counter Helper Functions ---
|
65 |
def log_sbom_generation(model_id: str):
|
@@ -370,9 +404,45 @@ async def generate_form(
|
|
370 |
use_best_practices: bool = Form(True)
|
371 |
):
|
372 |
sbom_count = get_sbom_count() # Get count early for context
|
373 |
-
|
374 |
-
|
375 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
try:
|
377 |
# Try different import paths for AIBOMGenerator
|
378 |
generator = None
|
@@ -391,32 +461,31 @@ async def generate_form(
|
|
391 |
logger.error("Could not import AIBOMGenerator from any known location")
|
392 |
raise ImportError("Could not import AIBOMGenerator from any known location")
|
393 |
|
394 |
-
# Generate AIBOM
|
395 |
aibom = generator.generate_aibom(
|
396 |
-
model_id=
|
397 |
include_inference=include_inference,
|
398 |
use_best_practices=use_best_practices
|
399 |
)
|
400 |
enhancement_report = generator.get_enhancement_report()
|
401 |
|
402 |
-
# Save AIBOM to file
|
403 |
# Corrected: Removed unnecessary backslashes around '/' and '_'
|
404 |
# Save AIBOM to file using normalized ID
|
405 |
-
filename = f"{normalized_model_id.replace('/', '_')}
|
406 |
filepath = os.path.join(OUTPUT_DIR, filename)
|
407 |
|
408 |
with open(filepath, "w") as f:
|
409 |
json.dump(aibom, f, indent=2)
|
410 |
|
411 |
# --- Log Generation Event ---
|
412 |
-
log_sbom_generation(
|
413 |
sbom_count = get_sbom_count() # Refresh count after logging
|
414 |
# --- End Log ---
|
415 |
|
416 |
download_url = f"/output/{filename}"
|
417 |
|
418 |
# Create download and UI interaction scripts
|
419 |
-
# Corrected: Removed unnecessary backslashes in script string
|
420 |
download_script = f"""
|
421 |
<script>
|
422 |
function downloadJSON() {{
|
@@ -461,17 +530,16 @@ async def generate_form(
|
|
461 |
"""
|
462 |
|
463 |
# Get completeness score or create a comprehensive one if not available
|
|
|
464 |
completeness_score = None
|
465 |
-
# Corrected: Removed unnecessary backslash in 'get_completeness_score'
|
466 |
if hasattr(generator, 'get_completeness_score'):
|
467 |
try:
|
468 |
-
completeness_score = generator.get_completeness_score(
|
469 |
logger.info("Successfully retrieved completeness_score from generator")
|
470 |
except Exception as e:
|
471 |
logger.error(f"Completeness score error from generator: {str(e)}")
|
472 |
|
473 |
# If completeness_score is None or doesn't have field_checklist, use comprehensive one
|
474 |
-
# Corrected: Removed unnecessary backslash in doesn't and 'field_checklist'
|
475 |
if completeness_score is None or not isinstance(completeness_score, dict) or 'field_checklist' not in completeness_score:
|
476 |
logger.info("Using comprehensive completeness_score with field_checklist")
|
477 |
completeness_score = create_comprehensive_completeness_score(aibom)
|
@@ -537,15 +605,19 @@ async def generate_form(
|
|
537 |
"display_names": display_names,
|
538 |
"tooltips": tooltips,
|
539 |
"weights": weights,
|
540 |
-
"sbom_count": sbom_count
|
|
|
|
|
|
|
541 |
}
|
542 |
)
|
|
|
543 |
except Exception as e:
|
544 |
logger.error(f"Error generating AI SBOM: {str(e)}")
|
545 |
-
# Ensure count is passed to error template as well
|
546 |
sbom_count = get_sbom_count() # Refresh count just in case
|
|
|
547 |
return templates.TemplateResponse(
|
548 |
-
"error.html", {"request": request, "error": str(e), "sbom_count": sbom_count, "model_id": normalized_model_id}
|
549 |
)
|
550 |
|
551 |
@app.get("/download/{filename}")
|
|
|
13 |
import os
|
14 |
import logging
|
15 |
from urllib.parse import urlparse
|
16 |
+
import re # Import regex module
|
17 |
+
import html # Import html module for escaping
|
18 |
+
from huggingface_hub import HfApi
|
19 |
+
from huggingface_hub.utils import RepositoryNotFoundError # For specific error handling
|
20 |
|
21 |
# Configure logging
|
22 |
logging.basicConfig(level=logging.INFO)
|
|
|
49 |
version: str
|
50 |
generator_version: str
|
51 |
|
52 |
+
|
53 |
+
# --- Model ID Validation and Normalization Helpers ---
|
54 |
+
# Regex for valid Hugging Face ID parts (alphanumeric, -, _, .)
|
55 |
+
# Allows owner/model format
|
56 |
+
# One owner or model segment: ASCII letters, digits, '.', '-' and '_'.
_HF_ID_PART = r"[a-zA-Z0-9._-]+"
_HF_ID_PART_REGEX = re.compile(_HF_ID_PART)
# Full 'owner/model' identifier. Anchored with \Z rather than $ because $ also
# matches just before a trailing newline, which would let "owner/model\n" pass.
HF_ID_REGEX = re.compile(rf"^{_HF_ID_PART}/{_HF_ID_PART}\Z")


def is_valid_hf_input(input_str: str) -> bool:
    """Check whether the input is a well-formed Hugging Face model ID or URL.

    Two shapes are accepted:

    * a bare ``owner/model`` identifier, and
    * an ``http(s)://huggingface.co/owner/model[/...]`` URL; extra path
      segments after the model name (e.g. ``/tree/main``) are allowed.

    This is a purely syntactic check; it does not contact the Hub.

    Args:
        input_str: Raw user-supplied model reference.

    Returns:
        True if the input matches one of the accepted shapes, False
        otherwise (including empty, non-string, or over-long input).
    """
    # Basic type/length check; 200 characters is the upper bound on accepted input.
    if not isinstance(input_str, str) or not input_str or len(input_str) > 200:
        return False

    if not input_str.startswith(("http://", "https://")):
        # Bare identifier: the whole string must be exactly owner/model.
        return HF_ID_REGEX.fullmatch(input_str) is not None

    try:
        parsed = urlparse(input_str)
    except ValueError:
        # urlparse rejects some malformed authorities (e.g. bad IPv6 brackets).
        return False

    if parsed.netloc != "huggingface.co":
        return False

    # Must have at least owner/model; anything after the second segment is ignored.
    path_parts = parsed.path.strip("/").split("/")
    if len(path_parts) < 2:
        return False

    # fullmatch also rejects empty segments and segments with a trailing newline.
    return all(_HF_ID_PART_REGEX.fullmatch(part) for part in path_parts[:2])
|
81 |
+
|
82 |
def _normalise_model_id(raw_id: str) -> str:
|
83 |
"""
|
84 |
+
Accept either validated 'owner/model' or a validated full URL like
|
85 |
'https://huggingface.co/owner/model'. Return 'owner/model'.
|
86 |
+
Assumes input has already been validated by is_valid_hf_input.
|
87 |
"""
|
88 |
if raw_id.startswith(("http://", "https://") ):
|
89 |
path = urlparse(raw_id).path.lstrip("/")
|
|
|
90 |
parts = path.split("/")
|
91 |
+
# We know from validation that parts[0] and parts[1] exist
|
92 |
+
return f"{parts[0]}/{parts[1]}"
|
93 |
+
return raw_id # Already in owner/model format
|
94 |
+
|
95 |
+
# --- End Model ID Helpers ---
|
96 |
+
|
97 |
|
98 |
# --- Add Counter Helper Functions ---
|
99 |
def log_sbom_generation(model_id: str):
|
|
|
404 |
use_best_practices: bool = Form(True)
|
405 |
):
|
406 |
sbom_count = get_sbom_count() # Get count early for context
|
407 |
+
|
408 |
+
|
409 |
+
# --- Input Sanitization ---
|
410 |
+
sanitized_model_id = html.escape(model_id)
|
411 |
+
|
412 |
+
# --- Input Format Validation ---
|
413 |
+
if not is_valid_hf_input(sanitized_model_id):
|
414 |
+
error_message = "Invalid input format. Please provide a valid Hugging Face model ID (e.g., 'owner/model') or a full model URL (e.g., 'https://huggingface.co/owner/model') ."
|
415 |
+
logger.warning(f"Invalid model input format received: {model_id}") # Log original input
|
416 |
+
# Try to display sanitized input in error message
|
417 |
+
return templates.TemplateResponse(
|
418 |
+
"error.html", {"request": request, "error": error_message, "sbom_count": sbom_count, "model_id": sanitized_model_id}
|
419 |
+
)
|
420 |
+
|
421 |
+
# --- Normalize the SANITIZED and VALIDATED model ID ---
|
422 |
+
normalized_model_id = _normalise_model_id(sanitized_model_id)
|
423 |
+
|
424 |
+
# --- Check if the ID corresponds to an actual HF Model ---
|
425 |
+
try:
|
426 |
+
hf_api = HfApi()
|
427 |
+
logger.info(f"Attempting to fetch model info for: {normalized_model_id}")
|
428 |
+
model_info = hf_api.model_info(normalized_model_id)
|
429 |
+
logger.info(f"Successfully fetched model info for: {normalized_model_id}")
|
430 |
+
except RepositoryNotFoundError:
|
431 |
+
error_message = f"Error: The provided ID \"{normalized_model_id}\" could not be found on Hugging Face or does not correspond to a model repository."
|
432 |
+
logger.warning(f"Repository not found for ID: {normalized_model_id}")
|
433 |
+
return templates.TemplateResponse(
|
434 |
+
"error.html", {"request": request, "error": error_message, "sbom_count": sbom_count, "model_id": normalized_model_id}
|
435 |
+
)
|
436 |
+
except Exception as api_err: # Catch other potential API errors
|
437 |
+
error_message = f"Error verifying model ID with Hugging Face API: {str(api_err)}"
|
438 |
+
logger.error(f"HF API error for {normalized_model_id}: {str(api_err)}")
|
439 |
+
return templates.TemplateResponse(
|
440 |
+
"error.html", {"request": request, "error": error_message, "sbom_count": sbom_count, "model_id": normalized_model_id}
|
441 |
+
)
|
442 |
+
# --- End Model Existence Check ---
|
443 |
+
|
444 |
+
|
445 |
+
# --- Main Generation Logic ---
|
446 |
try:
|
447 |
# Try different import paths for AIBOMGenerator
|
448 |
generator = None
|
|
|
461 |
logger.error("Could not import AIBOMGenerator from any known location")
|
462 |
raise ImportError("Could not import AIBOMGenerator from any known location")
|
463 |
|
464 |
+
# Generate AIBOM (pass SANITIZED ID)
|
465 |
aibom = generator.generate_aibom(
|
466 |
+
model_id=sanitized_model_id, # Use sanitized ID
|
467 |
include_inference=include_inference,
|
468 |
use_best_practices=use_best_practices
|
469 |
)
|
470 |
enhancement_report = generator.get_enhancement_report()
|
471 |
|
472 |
+
# Save AIBOM to file, use industry term ai_sbom in file name
|
473 |
# Corrected: Removed unnecessary backslashes around '/' and '_'
|
474 |
# Save AIBOM to file using normalized ID
|
475 |
+
filename = f"{normalized_model_id.replace('/', '_')}_ai_sbom.json"
|
476 |
filepath = os.path.join(OUTPUT_DIR, filename)
|
477 |
|
478 |
with open(filepath, "w") as f:
|
479 |
json.dump(aibom, f, indent=2)
|
480 |
|
481 |
# --- Log Generation Event ---
|
482 |
+
log_sbom_generation(sanitized_model_id) # Use sanitized ID
|
483 |
sbom_count = get_sbom_count() # Refresh count after logging
|
484 |
# --- End Log ---
|
485 |
|
486 |
download_url = f"/output/{filename}"
|
487 |
|
488 |
# Create download and UI interaction scripts
|
|
|
489 |
download_script = f"""
|
490 |
<script>
|
491 |
function downloadJSON() {{
|
|
|
530 |
"""
|
531 |
|
532 |
# Get completeness score or create a comprehensive one if not available
|
533 |
+
# Use sanitized_model_id
|
534 |
completeness_score = None
|
|
|
535 |
if hasattr(generator, 'get_completeness_score'):
|
536 |
try:
|
537 |
+
completeness_score = generator.get_completeness_score(sanitized_model_id)
|
538 |
logger.info("Successfully retrieved completeness_score from generator")
|
539 |
except Exception as e:
|
540 |
logger.error(f"Completeness score error from generator: {str(e)}")
|
541 |
|
542 |
# If completeness_score is None or doesn't have field_checklist, use comprehensive one
|
|
|
543 |
if completeness_score is None or not isinstance(completeness_score, dict) or 'field_checklist' not in completeness_score:
|
544 |
logger.info("Using comprehensive completeness_score with field_checklist")
|
545 |
completeness_score = create_comprehensive_completeness_score(aibom)
|
|
|
605 |
"display_names": display_names,
|
606 |
"tooltips": tooltips,
|
607 |
"weights": weights,
|
608 |
+
"sbom_count": sbom_count,
|
609 |
+
"display_names": display_names,
|
610 |
+
"tooltips": tooltips,
|
611 |
+
"weights": weights
|
612 |
}
|
613 |
)
|
614 |
+
# --- Main Exception Handling ---
|
615 |
except Exception as e:
|
616 |
logger.error(f"Error generating AI SBOM: {str(e)}")
|
|
|
617 |
sbom_count = get_sbom_count() # Refresh count just in case
|
618 |
+
# Pass count, added normalized model ID
|
619 |
return templates.TemplateResponse(
|
620 |
+
"error.html", {"request": request, "error": str(e), "sbom_count": sbom_count, "model_id": normalized_model_id}
|
621 |
)
|
622 |
|
623 |
@app.get("/download/{filename}")
|