a1c00l committed on
Commit
4f95bff
·
verified ·
1 Parent(s): f0336d9

Update src/aibom_generator/api.py

Browse files
Files changed (1) hide show
  1. src/aibom_generator/api.py +95 -23
src/aibom_generator/api.py CHANGED
@@ -13,6 +13,10 @@ from datasets import Dataset, load_dataset, concatenate_datasets
13
  import os
14
  import logging
15
  from urllib.parse import urlparse
 
 
 
 
16
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO)
@@ -45,21 +49,51 @@ class StatusResponse(BaseModel):
45
  version: str
46
  generator_version: str
47
 
48
- # --- Add Model ID Normalization Helper ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def _normalise_model_id(raw_id: str) -> str:
50
  """
51
- Accept either 'owner/model' or a full URL like
52
  'https://huggingface.co/owner/model'. Return 'owner/model'.
 
53
  """
54
  if raw_id.startswith(("http://", "https://") ):
55
  path = urlparse(raw_id).path.lstrip("/")
56
- # path can contain extra segments (e.g. /commit/...), keep first two
57
  parts = path.split("/")
58
- if len(parts) >= 2:
59
- return "/".join(parts[:2])
60
- return path # Fallback if path doesn't have owner/model
61
- return raw_id
62
- # --- End Model ID Normalization Helper ---
 
63
 
64
  # --- Add Counter Helper Functions ---
65
  def log_sbom_generation(model_id: str):
@@ -370,9 +404,45 @@ async def generate_form(
370
  use_best_practices: bool = Form(True)
371
  ):
372
  sbom_count = get_sbom_count() # Get count early for context
373
- # --- Normalize the model ID for display and filename ---
374
- normalized_model_id = _normalise_model_id(model_id)
375
- # --- End Normalization ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  try:
377
  # Try different import paths for AIBOMGenerator
378
  generator = None
@@ -391,32 +461,31 @@ async def generate_form(
391
  logger.error("Could not import AIBOMGenerator from any known location")
392
  raise ImportError("Could not import AIBOMGenerator from any known location")
393
 
394
- # Generate AIBOM
395
  aibom = generator.generate_aibom(
396
- model_id=model_id,
397
  include_inference=include_inference,
398
  use_best_practices=use_best_practices
399
  )
400
  enhancement_report = generator.get_enhancement_report()
401
 
402
- # Save AIBOM to file
403
  # Corrected: Removed unnecessary backslashes around '/' and '_'
404
  # Save AIBOM to file using normalized ID
405
- filename = f"{normalized_model_id.replace('/', '_')}_aibom.json"
406
  filepath = os.path.join(OUTPUT_DIR, filename)
407
 
408
  with open(filepath, "w") as f:
409
  json.dump(aibom, f, indent=2)
410
 
411
  # --- Log Generation Event ---
412
- log_sbom_generation(model_id)
413
  sbom_count = get_sbom_count() # Refresh count after logging
414
  # --- End Log ---
415
 
416
  download_url = f"/output/{filename}"
417
 
418
  # Create download and UI interaction scripts
419
- # Corrected: Removed unnecessary backslashes in script string
420
  download_script = f"""
421
  <script>
422
  function downloadJSON() {{
@@ -461,17 +530,16 @@ async def generate_form(
461
  """
462
 
463
  # Get completeness score or create a comprehensive one if not available
 
464
  completeness_score = None
465
- # Corrected: Removed unnecessary backslash in 'get_completeness_score'
466
  if hasattr(generator, 'get_completeness_score'):
467
  try:
468
- completeness_score = generator.get_completeness_score(model_id)
469
  logger.info("Successfully retrieved completeness_score from generator")
470
  except Exception as e:
471
  logger.error(f"Completeness score error from generator: {str(e)}")
472
 
473
  # If completeness_score is None or doesn't have field_checklist, use comprehensive one
474
- # Corrected: Removed unnecessary backslash in doesn't and 'field_checklist'
475
  if completeness_score is None or not isinstance(completeness_score, dict) or 'field_checklist' not in completeness_score:
476
  logger.info("Using comprehensive completeness_score with field_checklist")
477
  completeness_score = create_comprehensive_completeness_score(aibom)
@@ -537,15 +605,19 @@ async def generate_form(
537
  "display_names": display_names,
538
  "tooltips": tooltips,
539
  "weights": weights,
540
- "sbom_count": sbom_count # Pass count
 
 
 
541
  }
542
  )
 
543
  except Exception as e:
544
  logger.error(f"Error generating AI SBOM: {str(e)}")
545
- # Ensure count is passed to error template as well
546
  sbom_count = get_sbom_count() # Refresh count just in case
 
547
  return templates.TemplateResponse(
548
- "error.html", {"request": request, "error": str(e), "sbom_count": sbom_count, "model_id": normalized_model_id} # Pass count, added normalized model ID
549
  )
550
 
551
  @app.get("/download/{filename}")
 
13
  import os
14
  import logging
15
  from urllib.parse import urlparse
16
+ import re # Import regex module
17
+ import html # Import html module for escaping
18
+ from huggingface_hub import HfApi
19
+ from huggingface_hub.utils import RepositoryNotFoundError # For specific error handling
20
 
21
  # Configure logging
22
  logging.basicConfig(level=logging.INFO)
 
49
  version: str
50
  generator_version: str
51
 
52
+
53
+ # --- Model ID Validation and Normalization Helpers ---
54
+ # Regex for valid Hugging Face ID parts (alphanumeric, -, _, .)
55
+ # Allows owner/model format
56
HF_ID_REGEX = re.compile(r"^[a-zA-Z0-9\.\-\_]+/[a-zA-Z0-9\.\-\_]+$")

# A single path segment of a Hugging Face ID (the owner or the model name).
# Compiled once at import time instead of on every validation call.
_HF_ID_PART_REGEX = re.compile(r"[a-zA-Z0-9._-]+")

# Upper bound on accepted input length; a cheap guard against oversized input.
_MAX_HF_INPUT_LENGTH = 200


def is_valid_hf_input(input_str: str) -> bool:
    """Check whether the input is a Hugging Face model ID or model URL.

    Two shapes are accepted:

    * a bare ``owner/model`` ID, where both parts consist only of ASCII
      letters, digits, ``.``, ``-`` and ``_``;
    * an ``http://`` or ``https://`` URL whose host is exactly
      ``huggingface.co`` and whose first two path segments satisfy the same
      character rule (extra segments such as ``/tree/main`` are allowed).

    Args:
        input_str: Raw user-supplied model identifier or URL.

    Returns:
        True if the input matches one of the accepted shapes, False otherwise
        (including empty input and input longer than 200 characters).
    """
    if not input_str or len(input_str) > _MAX_HF_INPUT_LENGTH:
        return False

    if not input_str.startswith(("http://", "https://")):
        # fullmatch rather than match: the '$' anchor would also match just
        # before a trailing newline, letting "owner/model\n" through.
        return HF_ID_REGEX.fullmatch(input_str) is not None

    try:
        parsed = urlparse(input_str)
    except ValueError:
        # urlparse raises ValueError for malformed URLs (e.g. a bad IPv6 host).
        return False

    if parsed.netloc != "huggingface.co":
        return False

    path_parts = parsed.path.strip("/").split("/")
    if len(path_parts) < 2:
        return False

    # Only owner and model are validated; trailing segments are ignored.
    # The '+' quantifier also rejects empty segments (e.g. "owner//model").
    return all(_HF_ID_PART_REGEX.fullmatch(part) for part in path_parts[:2])
81
+
82
  def _normalise_model_id(raw_id: str) -> str:
83
  """
84
+ Accept either validated 'owner/model' or a validated full URL like
85
  'https://huggingface.co/owner/model'. Return 'owner/model'.
86
+ Assumes input has already been validated by is_valid_hf_input.
87
  """
88
  if raw_id.startswith(("http://", "https://") ):
89
  path = urlparse(raw_id).path.lstrip("/")
 
90
  parts = path.split("/")
91
+ # We know from validation that parts[0] and parts[1] exist
92
+ return f"{parts[0]}/{parts[1]}"
93
+ return raw_id # Already in owner/model format
94
+
95
+ # --- End Model ID Helpers ---
96
+
97
 
98
  # --- Add Counter Helper Functions ---
99
  def log_sbom_generation(model_id: str):
 
404
  use_best_practices: bool = Form(True)
405
  ):
406
  sbom_count = get_sbom_count() # Get count early for context
407
+
408
+
409
+ # --- Input Sanitization ---
410
+ sanitized_model_id = html.escape(model_id)
411
+
412
+ # --- Input Format Validation ---
413
+ if not is_valid_hf_input(sanitized_model_id):
414
+ error_message = "Invalid input format. Please provide a valid Hugging Face model ID (e.g., 'owner/model') or a full model URL (e.g., 'https://huggingface.co/owner/model') ."
415
+ logger.warning(f"Invalid model input format received: {model_id}") # Log original input
416
+ # Try to display sanitized input in error message
417
+ return templates.TemplateResponse(
418
+ "error.html", {"request": request, "error": error_message, "sbom_count": sbom_count, "model_id": sanitized_model_id}
419
+ )
420
+
421
+ # --- Normalize the SANITIZED and VALIDATED model ID ---
422
+ normalized_model_id = _normalise_model_id(sanitized_model_id)
423
+
424
+ # --- Check if the ID corresponds to an actual HF Model ---
425
+ try:
426
+ hf_api = HfApi()
427
+ logger.info(f"Attempting to fetch model info for: {normalized_model_id}")
428
+ model_info = hf_api.model_info(normalized_model_id)
429
+ logger.info(f"Successfully fetched model info for: {normalized_model_id}")
430
+ except RepositoryNotFoundError:
431
+ error_message = f"Error: The provided ID \"{normalized_model_id}\" could not be found on Hugging Face or does not correspond to a model repository."
432
+ logger.warning(f"Repository not found for ID: {normalized_model_id}")
433
+ return templates.TemplateResponse(
434
+ "error.html", {"request": request, "error": error_message, "sbom_count": sbom_count, "model_id": normalized_model_id}
435
+ )
436
+ except Exception as api_err: # Catch other potential API errors
437
+ error_message = f"Error verifying model ID with Hugging Face API: {str(api_err)}"
438
+ logger.error(f"HF API error for {normalized_model_id}: {str(api_err)}")
439
+ return templates.TemplateResponse(
440
+ "error.html", {"request": request, "error": error_message, "sbom_count": sbom_count, "model_id": normalized_model_id}
441
+ )
442
+ # --- End Model Existence Check ---
443
+
444
+
445
+ # --- Main Generation Logic ---
446
  try:
447
  # Try different import paths for AIBOMGenerator
448
  generator = None
 
461
  logger.error("Could not import AIBOMGenerator from any known location")
462
  raise ImportError("Could not import AIBOMGenerator from any known location")
463
 
464
+ # Generate AIBOM (pass SANITIZED ID)
465
  aibom = generator.generate_aibom(
466
+ model_id=sanitized_model_id, # Use sanitized ID
467
  include_inference=include_inference,
468
  use_best_practices=use_best_practices
469
  )
470
  enhancement_report = generator.get_enhancement_report()
471
 
472
+ # Save AIBOM to file, use industry term ai_sbom in file name
473
  # Corrected: Removed unnecessary backslashes around '/' and '_'
474
  # Save AIBOM to file using normalized ID
475
+ filename = f"{normalized_model_id.replace('/', '_')}_ai_sbom.json"
476
  filepath = os.path.join(OUTPUT_DIR, filename)
477
 
478
  with open(filepath, "w") as f:
479
  json.dump(aibom, f, indent=2)
480
 
481
  # --- Log Generation Event ---
482
+ log_sbom_generation(sanitized_model_id) # Use sanitized ID
483
  sbom_count = get_sbom_count() # Refresh count after logging
484
  # --- End Log ---
485
 
486
  download_url = f"/output/{filename}"
487
 
488
  # Create download and UI interaction scripts
 
489
  download_script = f"""
490
  <script>
491
  function downloadJSON() {{
 
530
  """
531
 
532
  # Get completeness score or create a comprehensive one if not available
533
+ # Use sanitized_model_id
534
  completeness_score = None
 
535
  if hasattr(generator, 'get_completeness_score'):
536
  try:
537
+ completeness_score = generator.get_completeness_score(sanitized_model_id)
538
  logger.info("Successfully retrieved completeness_score from generator")
539
  except Exception as e:
540
  logger.error(f"Completeness score error from generator: {str(e)}")
541
 
542
  # If completeness_score is None or doesn't have field_checklist, use comprehensive one
 
543
  if completeness_score is None or not isinstance(completeness_score, dict) or 'field_checklist' not in completeness_score:
544
  logger.info("Using comprehensive completeness_score with field_checklist")
545
  completeness_score = create_comprehensive_completeness_score(aibom)
 
605
  "display_names": display_names,
606
  "tooltips": tooltips,
607
  "weights": weights,
608
+ "sbom_count": sbom_count,
609
+ "display_names": display_names,
610
+ "tooltips": tooltips,
611
+ "weights": weights
612
  }
613
  )
614
+ # --- Main Exception Handling ---
615
  except Exception as e:
616
  logger.error(f"Error generating AI SBOM: {str(e)}")
 
617
  sbom_count = get_sbom_count() # Refresh count just in case
618
+ # Pass count, added normalized model ID
619
  return templates.TemplateResponse(
620
+ "error.html", {"request": request, "error": str(e), "sbom_count": sbom_count, "model_id": normalized_model_id}
621
  )
622
 
623
  @app.get("/download/{filename}")