"""
Integration with the main generator class to use the inference model.
"""
import logging | |
from typing import Dict, List, Optional, Any | |
from huggingface_hub import ModelCard | |
from aibom_generator.inference import MetadataExtractor | |
from aibom_generator.utils import merge_metadata | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class InferenceModelIntegration:
    """
    Integration with the inference model for metadata extraction.

    Wraps a ``MetadataExtractor`` and uses it to fill in metadata fields
    that are missing or empty in the structured metadata.
    """

    # Fields that are looked for in the model card text whenever the
    # structured metadata does not provide a usable value for them.
    _IMPORTANT_FIELDS = (
        "description",
        "license",
        "model_parameters",
        "datasets",
        "evaluation_results",
        "limitations",
        "ethical_considerations",
    )

    def __init__(
        self,
        inference_url: Optional[str] = None,
        use_inference: bool = True,
    ):
        """
        Initialize the inference model integration.

        Args:
            inference_url: URL of the inference model service
            use_inference: Whether to use the inference model
        """
        self.extractor = MetadataExtractor(inference_url, use_inference)

    def extract_metadata_from_model_card(
        self,
        model_card: ModelCard,
        structured_metadata: Optional[Dict[str, Any]] = None,
        fields: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Extract metadata from a model card using the inference model.

        Args:
            model_card: The ModelCard object
            structured_metadata: Optional structured metadata to provide context
            fields: Optional list of specific fields to extract

        Returns:
            Extracted metadata as a dictionary. An empty dictionary is
            returned when no model card is given or the card has no text.
        """
        if not model_card:
            logger.warning("No model card provided for inference extraction")
            return {}

        # A card without a ``text`` attribute is handled like an empty card.
        model_card_text = getattr(model_card, "text", "")
        if not model_card_text:
            logger.warning("Model card has no text content for inference extraction")
            return {}

        return self.extractor.extract_metadata(
            model_card_text, structured_metadata, fields
        )

    def enhance_metadata(
        self,
        structured_metadata: Optional[Dict[str, Any]],
        model_card: ModelCard,
    ) -> Dict[str, Any]:
        """
        Enhance structured metadata with information extracted from the model card.

        Args:
            structured_metadata: Structured metadata from API. ``None`` is
                treated as empty metadata.
            model_card: The ModelCard object

        Returns:
            Enhanced metadata as a dictionary. When no important field is
            missing, the structured metadata is returned as-is.
        """
        # Only replace None; a caller-supplied dict (even an empty one) is
        # passed through untouched so its identity is preserved.
        if structured_metadata is None:
            structured_metadata = {}

        missing_fields = self._identify_missing_fields(structured_metadata)
        if not missing_fields:
            logger.info("No missing fields to extract from unstructured text")
            return structured_metadata

        # Ask the inference model only for the fields that are missing.
        extracted_metadata = self.extract_metadata_from_model_card(
            model_card, structured_metadata, missing_fields
        )

        # Structured metadata is intended to take precedence over extracted
        # values; the precedence itself is implemented by merge_metadata.
        return merge_metadata(structured_metadata, extracted_metadata)

    @staticmethod
    def _is_empty(value: Any) -> bool:
        """
        Tell whether a metadata value carries no usable information.

        Args:
            value: The metadata value to check

        Returns:
            True for falsy values (None, "", [], {}, 0, ...) and for
            dictionaries whose values are all falsy; False otherwise.
        """
        if not value:
            return True
        return isinstance(value, dict) and not any(value.values())

    def _identify_missing_fields(
        self, metadata: Optional[Dict[str, Any]]
    ) -> List[str]:
        """
        Identify fields that are missing or incomplete in the metadata.

        Args:
            metadata: The metadata to check. ``None`` or an empty dictionary
                means every important field is missing.

        Returns:
            List of missing field names, in the order of the important
            fields list
        """
        if not metadata:
            return list(self._IMPORTANT_FIELDS)

        return [
            field
            for field in self._IMPORTANT_FIELDS
            if self._is_empty(metadata.get(field))
        ]