"""
Inference model implementation for metadata extraction from model cards.

This module provides a fine-tuned model for extracting structured metadata
from unstructured text in Hugging Face model cards.
"""

import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

logger = logging.getLogger(__name__)

# Markdown ATX header occupying one line ("# Title", "## Title", ...).
# MULTILINE anchors keep the newline out of the match, so a header that
# directly follows another header is still found, and [ \t]+ keeps the
# separator from running onto the next line.
_HEADER_PATTERN = re.compile(r"^(#+)[ \t]+(.*?)[ \t]*$", re.MULTILINE)

# Section-title substrings mapped to metadata fields. Order matters: the first
# key contained in a title wins, so longer phrases precede their prefixes
# (e.g. "training data" before "training").
_TITLE_MAPPINGS = {
    "model description": "description",
    "description": "description",
    "model details": "model_parameters",
    "model architecture": "model_parameters",
    "parameters": "model_parameters",
    "training data": "datasets",
    "dataset": "datasets",
    "datasets": "datasets",
    "training": "training_procedure",
    "evaluation": "evaluation_results",
    "results": "evaluation_results",
    "performance": "evaluation_results",
    "metrics": "evaluation_results",
    "limitations": "limitations",
    "biases": "ethical_considerations",
    "bias": "ethical_considerations",
    "ethical considerations": "ethical_considerations",
    "ethics": "ethical_considerations",
    "risks": "ethical_considerations",
    "license": "license",
    "citation": "citation",
    "references": "citation",
}

# Keywords used by the placeholder keyword-count classifier. Dict order is the
# tie-breaker: on equal counts the earlier category is returned.
_CATEGORY_KEYWORDS = {
    "description": ["is a", "this model", "based on", "pretrained"],
    "model_parameters": ["parameters", "layers", "hidden", "dimension", "architecture"],
    "datasets": ["dataset", "corpus", "trained on", "fine-tuned on"],
    "evaluation_results": ["accuracy", "f1", "precision", "recall", "performance"],
    "limitations": ["limitation", "limited", "does not", "cannot", "fails to"],
    "ethical_considerations": ["bias", "ethical", "fairness", "gender", "race"],
}

# Regexes capturing a dataset name in group 1.
_DATASET_PATTERNS = [
    r'trained on\s+(?:the\s+)?([A-Za-z0-9\-\s]+)(?:\s+dataset)?',
    r'dataset[:\s]+([A-Za-z0-9\-\s]+)',
    r'using\s+(?:the\s+)?([A-Za-z0-9\-\s]+)(?:\s+dataset)',
]

# Regexes capturing a numeric metric value in group 1, keyed by result name.
_METRIC_PATTERNS = {
    "accuracy": r'accuracy[:\s]+(\d+(?:\.\d+)?)\s*%?',
    "f1": r'f1(?:\s*[\-_]?score)?[:\s]+(\d+(?:\.\d+)?)',
    "precision": r'precision[:\s]+(\d+(?:\.\d+)?)',
    "recall": r'recall[:\s]+(\d+(?:\.\d+)?)',
}


class ModelCardExtractor:
    """
    Fine-tuned model for extracting metadata from model card text.
    """

    def __init__(
        self,
        model_name: str = "distilbert-base-uncased",
        device: str = "cpu",
        max_length: int = 512,
        cache_dir: Optional[str] = None,
    ):
        """
        Initialize the model card extractor and load its models.

        Args:
            model_name: Name or path of the pre-trained model
            device: Device to run the model on ('cpu' or 'cuda')
            max_length: Maximum sequence length for tokenization
            cache_dir: Directory to cache models

        Raises:
            Exception: Whatever the underlying model loading raises.
        """
        self.model_name = model_name
        self.device = device
        self.max_length = max_length
        self.cache_dir = cache_dir

        # Section classifier tokenizer and model
        self.tokenizer = None
        self.model = None

        # Extraction (seq2seq) tokenizer and model
        self.extraction_tokenizer = None
        self.extraction_model = None

        # Extraction pipelines
        self.section_classifier = None
        self.metadata_extractor = None

        # Load models
        self._load_models()

    def _load_models(self):
        """
        Load the required models for extraction.

        Populates the tokenizer/model attributes and the section
        classification pipeline. Any loading error is logged and re-raised.
        """
        try:
            # Load section classifier
            logger.info("Loading section classifier model: %s", self.model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )
            self.model.to(self.device)

            # Translate the device string into the integer index the pipeline
            # takes: -1 for any non-CUDA device, otherwise the CUDA ordinal
            # ("cuda" -> 0, "cuda:1" -> 1) so the pipeline stays on the same
            # GPU the model was just moved to.
            torch_device = torch.device(self.device)
            if torch_device.type == "cuda":
                pipeline_device = torch_device.index if torch_device.index is not None else 0
            else:
                pipeline_device = -1

            # Create section classification pipeline
            self.section_classifier = pipeline(
                "text-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                device=pipeline_device,
            )

            # For demonstration purposes, we'll use a T5-based model for extraction
            # In a real implementation, this would be a fine-tuned model specific to the task
            logger.info("Loading metadata extraction model")
            extraction_model_name = "t5-small"  # Placeholder for fine-tuned model
            self.extraction_tokenizer = T5Tokenizer.from_pretrained(
                extraction_model_name,
                cache_dir=self.cache_dir,
            )
            self.extraction_model = AutoModelForSeq2SeqLM.from_pretrained(
                extraction_model_name,
                cache_dir=self.cache_dir,
            )
            self.extraction_model.to(self.device)

            logger.info("Models loaded successfully")
        except Exception as e:
            logger.exception("Error loading models: %s", e)
            raise

    def extract_metadata(
        self,
        text: str,
        fields: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Extract metadata from model card text.

        Args:
            text: The model card text
            fields: Optional list of specific fields to extract; when empty or
                None, every classified section is extracted

        Returns:
            Extracted metadata as a dictionary keyed by section type. Section
            types whose extraction result is empty are omitted.
        """
        # Split text into sections
        sections = self._split_into_sections(text)

        # Classify sections
        classified_sections = self._classify_sections(sections)

        # Extract metadata from each section
        metadata = {}
        for section_type, section_text in classified_sections.items():
            if fields and section_type not in fields:
                continue

            extracted = self._extract_from_section(section_type, section_text)
            if extracted:
                metadata[section_type] = extracted

        return metadata

    def _split_into_sections(self, text: str) -> List[Dict[str, Any]]:
        """
        Split the model card text into sections.

        Sections are delimited by markdown headers (# Header, ## Header, etc.).
        Text that precedes the first header is not part of any section.

        Args:
            text: The model card text

        Returns:
            List of sections, each a dict with "title" (str), "level" (number
            of '#' characters) and "content" (str). If the text has no
            headers, a single level-1 section titled "Main" holds all of it.
        """
        # Simple section splitting based on headers
        # In a real implementation, this would be more sophisticated
        headers = list(_HEADER_PATTERN.finditer(text))

        sections = []
        for index, match in enumerate(headers):
            # A section runs until the next header, or to the end of the text
            if index + 1 < len(headers):
                end = headers[index + 1].start()
            else:
                end = len(text)

            sections.append({
                "title": match.group(2).strip(),
                "level": len(match.group(1)),
                "content": text[match.end():end].strip(),
            })

        # If no sections were found, treat the entire text as one section
        if not sections:
            sections.append({
                "title": "Main",
                "level": 1,
                "content": text.strip(),
            })

        return sections

    def _classify_sections(self, sections: List[Dict[str, Any]]) -> Dict[str, str]:
        """
        Classify sections into metadata categories.

        A section is classified by its title first; if no known title
        substring matches and the content has more than five words, the
        content is classified by keyword. Sections that resolve to the same
        category are joined with a blank line.

        Args:
            sections: List of sections with title and content

        Returns:
            Dictionary mapping section types to section content
        """
        classified: Dict[str, str] = {}

        for section in sections:
            title = section["title"].lower()
            content = section["content"]

            # Check for direct title matches (first mapping key found in the title)
            section_type = next(
                (field for key, field in _TITLE_MAPPINGS.items() if key in title),
                None,
            )

            # If no match by title, use the classifier
            if section_type is None and self.section_classifier and len(content.split()) > 5:
                try:
                    # This is a placeholder for actual classification
                    # In a real implementation, this would use the fine-tuned classifier
                    section_type = self._classify_text(content)
                except Exception as e:
                    # Best effort: one unclassifiable section must not abort the rest
                    logger.error("Error classifying section: %s", e)
                    continue

            if not section_type:
                continue

            if section_type in classified:
                classified[section_type] += "\n\n" + content
            else:
                classified[section_type] = content

        return classified

    def _classify_text(self, text: str) -> Optional[str]:
        """
        Classify text into a metadata category.

        Args:
            text: The text to classify

        Returns:
            The category with the most whole-word keyword occurrences, or
            None if no keyword occurs at all
        """
        # This is a placeholder for actual classification
        # In a real implementation, this would use the fine-tuned classifier

        # Simple keyword-based classification for demonstration
        lowered = text.lower()
        counts = {
            category: sum(
                len(re.findall(r'\b' + re.escape(word) + r'\b', lowered))
                for word in words
            )
            for category, words in _CATEGORY_KEYWORDS.items()
        }

        # Return the category with the most keyword matches
        best_category, best_count = max(counts.items(), key=lambda item: item[1])
        return best_category if best_count > 0 else None

    def _extract_from_section(self, section_type: str, text: str) -> Any:
        """
        Extract structured metadata from a section.

        Args:
            section_type: The type of section
            text: The section text

        Returns:
            A dict for "model_parameters" and "evaluation_results", a list of
            names for "datasets", and a string for every other section type
            (the matched license name for "license" when one is found,
            otherwise the stripped section text).
        """
        # This is a placeholder for actual extraction
        # In a real implementation, this would use the fine-tuned extraction model
        if section_type == "model_parameters":
            return self._extract_model_parameters(text)
        if section_type == "datasets":
            return self._extract_datasets(text)
        if section_type == "evaluation_results":
            return self._extract_evaluation_results(text)
        if section_type == "license":
            return self._extract_license(text)

        # description, limitations, ethical_considerations, citation and any
        # unknown section type: simply return the text
        return text.strip()

    @staticmethod
    def _extract_model_parameters(text: str) -> Dict[str, str]:
        """Return the architecture name and parameter count found in text, if any."""
        params = {}

        # Extract architecture
        arch_match = re.search(
            r'(?:architecture|model type|based on)[:\s]+([A-Za-z0-9\-]+)',
            text,
            re.IGNORECASE,
        )
        if arch_match:
            params["architecture"] = arch_match.group(1).strip()

        # Extract parameter count; the magnitude suffix or word is captured
        # separately so that "110M parameters" is not reduced to a bare "110"
        param_match = re.search(
            r'(\d+(?:\.\d+)?)\s*(B|M|K)?\s*(billion|million|thousand)?\s*parameters',
            text,
            re.IGNORECASE,
        )
        if param_match:
            params["parameter_count"] = param_match.group(1).strip()
            unit = param_match.group(3) or param_match.group(2)
            if unit:
                params["parameter_count_unit"] = unit

        return params

    @staticmethod
    def _extract_datasets(text: str) -> List[str]:
        """Return the distinct dataset names found in text, in first-seen order."""
        datasets = []
        for pattern in _DATASET_PATTERNS:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                dataset = match.group(1).strip()
                if dataset and dataset.lower() not in ["this", "these", "those"]:
                    datasets.append(dataset)

        # dict.fromkeys de-duplicates while keeping a deterministic order
        return list(dict.fromkeys(datasets))

    @staticmethod
    def _extract_evaluation_results(text: str) -> Dict[str, float]:
        """Return the accuracy, F1, precision and recall values found in text."""
        results = {}
        for metric, pattern in _METRIC_PATTERNS.items():
            metric_match = re.search(pattern, text, re.IGNORECASE)
            if metric_match:
                results[metric] = float(metric_match.group(1))
        return results

    @staticmethod
    def _extract_license(text: str) -> str:
        """Return the license name following a "license:" label, else the stripped text."""
        license_match = re.search(
            r'(?:license|licensing)[:\s]+([A-Za-z0-9\-\s]+)',
            text,
            re.IGNORECASE,
        )
        if license_match:
            return license_match.group(1).strip()
        return text.strip()


class InferenceModelServer:
    """
    Server for the inference model.

    This class provides a server for the inference model that can be deployed
    as a standalone service with a REST API.
    """

    def __init__(
        self,
        model_name: str = "distilbert-base-uncased",
        device: str = "cpu",
        cache_dir: Optional[str] = None,
    ):
        """
        Initialize the inference model server.

        Args:
            model_name: Name or path of the pre-trained model
            device: Device to run the model on ('cpu' or 'cuda')
            cache_dir: Directory to cache models
        """
        self.extractor = ModelCardExtractor(
            model_name=model_name,
            device=device,
            cache_dir=cache_dir,
        )

    def extract_metadata(
        self,
        text: str,
        structured_metadata: Optional[Dict[str, Any]] = None,
        fields: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Extract metadata from model card text.

        Args:
            text: The model card text
            structured_metadata: Optional structured metadata to provide context
            fields: Optional list of specific fields to extract

        Returns:
            {"metadata": ..., "success": True} on success, or
            {"metadata": {}, "success": False, "error": ...} if extraction
            raised; exceptions are logged rather than propagated.
        """
        try:
            # Extract metadata using the extractor
            metadata = self.extractor.extract_metadata(text, fields)

            # Enhance with structured metadata if provided
            if structured_metadata:
                # Use structured metadata for fields not extracted
                for key, value in structured_metadata.items():
                    if key not in metadata or not metadata[key]:
                        metadata[key] = value

            return {"metadata": metadata, "success": True}
        except Exception as e:
            logger.exception("Error extracting metadata: %s", e)
            return {"metadata": {}, "success": False, "error": str(e)}


def create_app(model_name: str = "distilbert-base-uncased", device: str = "cpu"):
    """
    Create a Flask app for the inference model server.

    The app exposes POST /extract (JSON object with "text" and optional
    "structured_metadata" and "fields") and GET /health.

    Args:
        model_name: Name or path of the pre-trained model
        device: Device to run the model on ('cpu' or 'cuda')

    Returns:
        Flask app
    """
    from flask import Flask, request, jsonify

    app = Flask(__name__)
    server = InferenceModelServer(model_name=model_name, device=device)

    @app.route("/extract", methods=["POST"])
    def extract():
        # silent=True yields None for a missing or malformed JSON body, so it
        # can be rejected below together with non-object payloads instead of
        # failing on data.get() with an unhandled AttributeError.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            error_body = {
                "metadata": {},
                "success": False,
                "error": "Request body must be a JSON object",
            }
            return jsonify(error_body), 400

        text = data.get("text", "")
        structured_metadata = data.get("structured_metadata", {})
        fields = data.get("fields", [])

        result = server.extract_metadata(text, structured_metadata, fields)
        return jsonify(result)

    @app.route("/health", methods=["GET"])
    def health():
        return jsonify({"status": "healthy"})

    return app


def main():
    """Main entry point for the inference model server."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Start the inference model server for AIBOM metadata extraction."
    )
    parser.add_argument(
        "--model",
        help="Name or path of the pre-trained model",
        default="distilbert-base-uncased",
    )
    parser.add_argument(
        "--device",
        help="Device to run the model on ('cpu' or 'cuda')",
        choices=["cpu", "cuda"],
        default="cpu",
    )
    parser.add_argument(
        "--host",
        help="Host to bind the server to",
        default="0.0.0.0",
    )
    parser.add_argument(
        "--port",
        help="Port to bind the server to",
        type=int,
        default=5000,
    )
    parser.add_argument(
        "--debug",
        help="Enable debug mode",
        action="store_true",
    )

    args = parser.parse_args()

    # NOTE(security): Flask debug mode enables the interactive debugger, which
    # allows code execution; flag it when the server is reachable off-host.
    if args.debug and args.host not in ("127.0.0.1", "localhost", "::1"):
        logger.warning(
            "Debug mode is enabled while binding to %s; "
            "do not expose this server to untrusted networks.",
            args.host,
        )

    # Create and run the app
    app = create_app(model_name=args.model, device=args.device)
    app.run(host=args.host, port=args.port, debug=args.debug)


if __name__ == "__main__":
    main()