daetheris commited on
Commit
c8175b6
·
verified ·
1 Parent(s): d6599d0

Attempt fixing the purl

Browse files
Files changed (1) hide show
  1. src/aibom_generator/generator.py +21 -0
src/aibom_generator/generator.py CHANGED
@@ -3,7 +3,9 @@ import uuid
3
  import datetime
4
  from typing import Dict, Optional, Any, List
5
 
 
6
  from huggingface_hub import HfApi, ModelCard
 
7
  from .utils import calculate_completeness_score
8
 
9
 
@@ -31,6 +33,7 @@ class AIBOMGenerator:
31
  use_best_practices: Optional[bool] = None, # Added parameter for industry-neutral scoring
32
  ) -> Dict[str, Any]:
33
  try:
 
34
  use_inference = include_inference if include_inference is not None else self.use_inference
35
  # Use method parameter if provided, otherwise use instance variable
36
  use_best_practices = use_best_practices if use_best_practices is not None else self.use_best_practices
@@ -159,6 +162,23 @@ class AIBOMGenerator:
159
  print(f"Error fetching model info for {model_id}: {e}")
160
  return {}
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  def _fetch_model_card(self, model_id: str) -> Optional[ModelCard]:
163
  try:
164
  return ModelCard.load(model_id)
@@ -588,3 +608,4 @@ class AIBOMGenerator:
588
  }
589
 
590
  return license_urls.get(license_id, "https://spdx.org/licenses/")
 
 
3
  import datetime
4
  from typing import Dict, Optional, Any, List
5
 
6
+
7
  from huggingface_hub import HfApi, ModelCard
8
+ from urllib.parse import urlparse
9
  from .utils import calculate_completeness_score
10
 
11
 
 
33
  use_best_practices: Optional[bool] = None, # Added parameter for industry-neutral scoring
34
  ) -> Dict[str, Any]:
35
  try:
36
+ model_id = self._normalise_model_id(model_id)
37
  use_inference = include_inference if include_inference is not None else self.use_inference
38
  # Use method parameter if provided, otherwise use instance variable
39
  use_best_practices = use_best_practices if use_best_practices is not None else self.use_best_practices
 
162
  print(f"Error fetching model info for {model_id}: {e}")
163
  return {}
164
 
165
+ # ---- new helper ---------------------------------------------------------
166
+ @staticmethod
167
+ def _normalise_model_id(raw_id: str) -> str:
168
+ """
169
+ Accept either 'owner/model' or a full URL like
170
+ 'https://huggingface.co/owner/model'. Return 'owner/model'.
171
+ """
172
+ if raw_id.startswith(("http://", "https://")):
173
+ path = urlparse(raw_id).path.lstrip("/")
174
+ # path can contain extra segments (e.g. /commit/...), keep first two
175
+ parts = path.split("/")
176
+ if len(parts) >= 2:
177
+ return "/".join(parts[:2])
178
+ return path
179
+ return raw_id
180
+ # -------------------------------------------------------------------------
181
+
182
  def _fetch_model_card(self, model_id: str) -> Optional[ModelCard]:
183
  try:
184
  return ModelCard.load(model_id)
 
608
  }
609
 
610
  return license_urls.get(license_id, "https://spdx.org/licenses/")
611
+