a1c00l's picture
Upload 9 files
8819832 verified
raw
history blame
4.42 kB
"""
Integration with the main generator class to use the inference model.
"""
import logging
from typing import Dict, List, Optional, Any
from huggingface_hub import ModelCard
from aibom_generator.inference import MetadataExtractor
from aibom_generator.utils import merge_metadata
logger = logging.getLogger(__name__)
class InferenceModelIntegration:
    """Bridge between the generator and the inference-based metadata extractor.

    The class wraps a ``MetadataExtractor`` and offers two operations:
    extracting metadata from a model card's free text, and filling the gaps
    of already-known structured metadata with whatever can be extracted.
    """

    # Fields considered worth recovering from unstructured model card text
    # when they are absent or empty in the structured metadata.
    _IMPORTANT_FIELDS = (
        "description",
        "license",
        "model_parameters",
        "datasets",
        "evaluation_results",
        "limitations",
        "ethical_considerations",
    )

    def __init__(
        self,
        inference_url: Optional[str] = None,
        use_inference: bool = True,
    ):
        """Create the integration and its underlying extractor.

        Args:
            inference_url: URL of the inference model service.
            use_inference: Whether to use the inference model.
        """
        # Both arguments are forwarded positionally, unchanged.
        self.extractor = MetadataExtractor(inference_url, use_inference)

    def extract_metadata_from_model_card(
        self,
        model_card: ModelCard,
        structured_metadata: Optional[Dict[str, Any]] = None,
        fields: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Extract metadata from a model card using the inference model.

        Args:
            model_card: The ModelCard object whose ``text`` is analysed.
            structured_metadata: Optional structured metadata passed to the
                extractor as context.
            fields: Optional list of specific fields to extract.

        Returns:
            The extractor's result, or an empty dict when no model card is
            given or the card carries no text.
        """
        if not model_card:
            logger.warning("No model card provided for inference extraction")
            return {}

        # A card without a ``text`` attribute is treated like an empty card.
        model_card_text = getattr(model_card, "text", "")
        if not model_card_text:
            logger.warning("Model card has no text content for inference extraction")
            return {}

        return self.extractor.extract_metadata(
            model_card_text, structured_metadata, fields
        )

    def enhance_metadata(
        self,
        structured_metadata: Dict[str, Any],
        model_card: ModelCard,
    ) -> Dict[str, Any]:
        """Fill gaps in structured metadata using the model card text.

        Args:
            structured_metadata: Structured metadata from the API. ``None``
                is tolerated and treated as empty metadata.
            model_card: The ModelCard object.

        Returns:
            ``structured_metadata`` itself when no important field is
            missing; otherwise the result of ``merge_metadata`` applied to
            the structured and the extracted metadata.
        """
        if structured_metadata is None:
            structured_metadata = {}

        missing_fields = self._identify_missing_fields(structured_metadata)
        if not missing_fields:
            logger.info("No missing fields to extract from unstructured text")
            return structured_metadata

        # Only the missing fields are requested from the extractor.
        extracted_metadata = self.extract_metadata_from_model_card(
            model_card, structured_metadata, missing_fields
        )

        # NOTE(review): structured metadata is expected to take precedence in
        # the merge -- confirm against merge_metadata's implementation.
        return merge_metadata(structured_metadata, extracted_metadata)

    def _identify_missing_fields(self, metadata: Dict[str, Any]) -> List[str]:
        """Identify important fields that are missing or empty.

        A field counts as missing when it is absent, holds a falsy value
        (``None``, ``""``, ``[]``, ``{}``, ``0`` ...), or is a dict whose
        values are all falsy.

        Args:
            metadata: The metadata to check. ``None`` is treated as empty.

        Returns:
            Missing field names, in the order of ``_IMPORTANT_FIELDS``.
        """
        if metadata is None:
            metadata = {}

        missing_fields = []
        for field in self._IMPORTANT_FIELDS:
            value = metadata.get(field)
            if not value:
                # Absent, None, or an empty container/string.
                missing_fields.append(field)
            elif isinstance(value, dict) and not any(value.values()):
                # Non-empty dict that carries no actual information.
                missing_fields.append(field)
        return missing_fields