""" | |
Inference model implementation for metadata extraction from model cards. | |
This module provides a fine-tuned model for extracting structured metadata | |
from unstructured text in Hugging Face model cards. | |
""" | |
import json | |
import logging | |
import os | |
import re | |
import torch | |
from typing import Dict, List, Optional, Any, Union | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline | |
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer | |
logger = logging.getLogger(__name__) | |


class ModelCardExtractor:
    """
    Fine-tuned model for extracting metadata from model card text.
    """

    def __init__(
        self,
        model_name: str = "distilbert-base-uncased",
        device: str = "cpu",
        max_length: int = 512,
        cache_dir: Optional[str] = None,
    ):
        """
        Initialize the model card extractor.

        Args:
            model_name: Name or path of the pre-trained model
            device: Device to run the model on ('cpu' or 'cuda')
            max_length: Maximum sequence length for tokenization
            cache_dir: Directory to cache models
        """
        self.model_name = model_name
        self.device = device
        self.max_length = max_length
        self.cache_dir = cache_dir

        # Section classification components (tokenizer, model, and pipeline)
        self.tokenizer = None
        self.model = None
        self.section_classifier = None

        # Seq2seq extraction components, populated by _load_models()
        self.extraction_tokenizer = None
        self.extraction_model = None

        # Load models
        self._load_models()

    def _load_models(self):
        """Load the required models for extraction."""
        try:
            # Load section classifier
            logger.info(f"Loading section classifier model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir,
            )
            self.model.to(self.device)

            # Create section classification pipeline
            self.section_classifier = pipeline(
                "text-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if self.device == "cuda" else -1,
            )

            # For demonstration purposes, we use a T5-based model for extraction.
            # In a real implementation, this would be a fine-tuned model specific to the task.
            logger.info("Loading metadata extraction model")
            extraction_model_name = "t5-small"  # Placeholder for fine-tuned model
            self.extraction_tokenizer = T5Tokenizer.from_pretrained(
                extraction_model_name,
                cache_dir=self.cache_dir,
            )
            self.extraction_model = AutoModelForSeq2SeqLM.from_pretrained(
                extraction_model_name,
                cache_dir=self.cache_dir,
            )
            self.extraction_model.to(self.device)

            logger.info("Models loaded successfully")
        except Exception as e:
            logger.error(f"Error loading models: {e}")
            raise

    def extract_metadata(
        self,
        text: str,
        fields: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Extract metadata from model card text.

        Args:
            text: The model card text
            fields: Optional list of specific fields to extract

        Returns:
            Extracted metadata as a dictionary
        """
        # Split text into sections
        sections = self._split_into_sections(text)

        # Classify sections
        classified_sections = self._classify_sections(sections)

        # Extract metadata from each section
        metadata = {}
        for section_type, section_text in classified_sections.items():
            if fields and section_type not in fields:
                continue
            extracted = self._extract_from_section(section_type, section_text)
            if extracted:
                metadata[section_type] = extracted

        return metadata

    def _split_into_sections(self, text: str) -> List[Dict[str, str]]:
        """
        Split the model card text into sections.

        Args:
            text: The model card text

        Returns:
            List of sections with title and content
        """
        # Simple section splitting based on headers.
        # In a real implementation, this would be more sophisticated.
        sections = []

        # Match markdown headers (# Header, ## Header, etc.)
        header_pattern = r"(?:^|\n)(#+)\s+(.*?)(?:\n|$)"

        # Find all headers
        headers = list(re.finditer(header_pattern, text))

        for i, match in enumerate(headers):
            header_level = len(match.group(1))
            header_text = match.group(2).strip()
            start = match.end()

            # Find the end of the section (next header or end of text)
            if i < len(headers) - 1:
                end = headers[i + 1].start()
            else:
                end = len(text)

            # Extract the section content
            content = text[start:end].strip()
            sections.append({
                "title": header_text,
                "level": header_level,
                "content": content,
            })

        # If no sections were found, treat the entire text as one section
        if not sections:
            sections.append({
                "title": "Main",
                "level": 1,
                "content": text.strip(),
            })

        return sections
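
    # Illustrative sketch (not in the original module): for a small markdown
    # card, the header splitting above yields one dict per header, e.g.
    #
    #     text = "# My Model\nIntro text.\n## Training data\nTrained on SQuAD.\n"
    #     self._split_into_sections(text)
    #     # -> [{"title": "My Model", "level": 1, "content": "Intro text."},
    #     #     {"title": "Training data", "level": 2, "content": "Trained on SQuAD."}]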

    def _classify_sections(self, sections: List[Dict[str, str]]) -> Dict[str, str]:
        """
        Classify sections into metadata categories.

        Args:
            sections: List of sections with title and content

        Returns:
            Dictionary mapping section types to section content
        """
        classified = {}

        # Map common section titles to metadata fields
        title_mappings = {
            "model description": "description",
            "description": "description",
            "model details": "model_parameters",
            "model architecture": "model_parameters",
            "parameters": "model_parameters",
            "training data": "datasets",
            "dataset": "datasets",
            "datasets": "datasets",
            "training": "training_procedure",
            "evaluation": "evaluation_results",
            "results": "evaluation_results",
            "performance": "evaluation_results",
            "metrics": "evaluation_results",
            "limitations": "limitations",
            "biases": "ethical_considerations",
            "bias": "ethical_considerations",
            "ethical considerations": "ethical_considerations",
            "ethics": "ethical_considerations",
            "risks": "ethical_considerations",
            "license": "license",
            "citation": "citation",
            "references": "citation",
        }

        for section in sections:
            title = section["title"].lower()
            content = section["content"]

            # Check for direct title matches
            matched = False
            for key, value in title_mappings.items():
                if key in title:
                    if value not in classified:
                        classified[value] = content
                    else:
                        classified[value] += "\n\n" + content
                    matched = True
                    break

            # If no match by title, use the classifier
            if not matched and self.section_classifier and len(content.split()) > 5:
                try:
                    # This is a placeholder for actual classification.
                    # In a real implementation, this would use the fine-tuned classifier.
                    section_type = self._classify_text(content)
                    if section_type and section_type not in classified:
                        classified[section_type] = content
                    elif section_type:
                        classified[section_type] += "\n\n" + content
                except Exception as e:
                    logger.error(f"Error classifying section: {e}")

        return classified

    def _classify_text(self, text: str) -> Optional[str]:
        """
        Classify text into a metadata category.

        Args:
            text: The text to classify

        Returns:
            Metadata category or None if classification fails
        """
        # This is a placeholder for actual classification.
        # In a real implementation, this would use the fine-tuned classifier.
        # Simple keyword-based classification for demonstration.
        keywords = {
            "description": ["is a", "this model", "based on", "pretrained"],
            "model_parameters": ["parameters", "layers", "hidden", "dimension", "architecture"],
            "datasets": ["dataset", "corpus", "trained on", "fine-tuned on"],
            "evaluation_results": ["accuracy", "f1", "precision", "recall", "performance"],
            "limitations": ["limitation", "limited", "does not", "cannot", "fails to"],
            "ethical_considerations": ["bias", "ethical", "fairness", "gender", "race"],
        }

        # Count keyword occurrences
        counts = {category: 0 for category in keywords}
        for category, words in keywords.items():
            for word in words:
                counts[category] += len(re.findall(r'\b' + re.escape(word) + r'\b', text.lower()))

        # Return the category with the most keyword matches
        if counts:
            max_category = max(counts.items(), key=lambda x: x[1])
            if max_category[1] > 0:
                return max_category[0]

        return None
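
    # Illustrative sketch (not in the original module): the keyword counting above
    # routes a sentence such as
    #
    #     self._classify_text("The model reaches 92% accuracy and an F1 of 0.88.")
    #
    # to "evaluation_results", since "accuracy" and "f1" each match once while no
    # keywords from the other categories appear.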

    def _extract_from_section(self, section_type: str, text: str) -> Any:
        """
        Extract structured metadata from a section.

        Args:
            section_type: The type of section
            text: The section text

        Returns:
            Extracted metadata
        """
        # This is a placeholder for actual extraction.
        # In a real implementation, this would use the fine-tuned extraction model.
        if section_type == "description":
            # Simply return the text for description
            return text.strip()

        elif section_type == "model_parameters":
            # Extract model parameters using regex
            params = {}

            # Extract architecture
            arch_match = re.search(r'(?:architecture|model type|based on)[:\s]+([A-Za-z0-9\-]+)', text, re.IGNORECASE)
            if arch_match:
                params["architecture"] = arch_match.group(1).strip()

            # Extract parameter count
            param_match = re.search(r'(\d+(?:\.\d+)?)\s*(?:B|M|K)?\s*(?:billion|million|thousand)?\s*parameters', text, re.IGNORECASE)
            if param_match:
                params["parameter_count"] = param_match.group(1).strip()

            return params

        elif section_type == "datasets":
            # Extract dataset names
            datasets = []
            dataset_patterns = [
                r'trained on\s+(?:the\s+)?([A-Za-z0-9\-\s]+)(?:\s+dataset)?',
                r'dataset[:\s]+([A-Za-z0-9\-\s]+)',
                r'using\s+(?:the\s+)?([A-Za-z0-9\-\s]+)(?:\s+dataset)',
            ]
            for pattern in dataset_patterns:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    dataset = match.group(1).strip()
                    if dataset and dataset.lower() not in ["this", "these", "those"]:
                        datasets.append(dataset)

            return list(set(datasets))

        elif section_type == "evaluation_results":
            # Extract evaluation metrics
            results = {}

            # Extract accuracy
            acc_match = re.search(r'accuracy[:\s]+(\d+(?:\.\d+)?)\s*%?', text, re.IGNORECASE)
            if acc_match:
                results["accuracy"] = float(acc_match.group(1))

            # Extract F1 score
            f1_match = re.search(r'f1(?:\s*[\-_]?score)?[:\s]+(\d+(?:\.\d+)?)', text, re.IGNORECASE)
            if f1_match:
                results["f1"] = float(f1_match.group(1))

            # Extract precision
            prec_match = re.search(r'precision[:\s]+(\d+(?:\.\d+)?)', text, re.IGNORECASE)
            if prec_match:
                results["precision"] = float(prec_match.group(1))

            # Extract recall
            recall_match = re.search(r'recall[:\s]+(\d+(?:\.\d+)?)', text, re.IGNORECASE)
            if recall_match:
                results["recall"] = float(recall_match.group(1))

            return results

        elif section_type == "limitations":
            # Simply return the text for limitations
            return text.strip()

        elif section_type == "ethical_considerations":
            # Simply return the text for ethical considerations
            return text.strip()

        elif section_type == "license":
            # Extract license information
            license_match = re.search(r'(?:license|licensing)[:\s]+([A-Za-z0-9\-\s]+)', text, re.IGNORECASE)
            if license_match:
                return license_match.group(1).strip()
            return text.strip()

        elif section_type == "citation":
            # Simply return the text for citation
            return text.strip()

        # Default case
        return text.strip()
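
# Illustrative usage sketch (not part of the original module). Instantiating the
# extractor downloads the underlying base models on first use, and the model card
# text below is made up for demonstration:
#
#     extractor = ModelCardExtractor(model_name="distilbert-base-uncased", device="cpu")
#     card = "# My Model\n\n## Training data\nTrained on the SQuAD dataset.\n\n## License\nLicense: MIT\n"
#     metadata = extractor.extract_metadata(card, fields=["datasets", "license"])
#     # metadata is a dict keyed by section type, e.g. {"datasets": [...], "license": "MIT"}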


class InferenceModelServer:
    """
    Server for the inference model.

    This class provides a server for the inference model that can be deployed
    as a standalone service with a REST API.
    """

    def __init__(
        self,
        model_name: str = "distilbert-base-uncased",
        device: str = "cpu",
        cache_dir: Optional[str] = None,
    ):
        """
        Initialize the inference model server.

        Args:
            model_name: Name or path of the pre-trained model
            device: Device to run the model on ('cpu' or 'cuda')
            cache_dir: Directory to cache models
        """
        self.extractor = ModelCardExtractor(
            model_name=model_name,
            device=device,
            cache_dir=cache_dir,
        )

    def extract_metadata(
        self,
        text: str,
        structured_metadata: Optional[Dict[str, Any]] = None,
        fields: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Extract metadata from model card text.

        Args:
            text: The model card text
            structured_metadata: Optional structured metadata to provide context
            fields: Optional list of specific fields to extract

        Returns:
            Extracted metadata as a dictionary
        """
        try:
            # Extract metadata using the extractor
            metadata = self.extractor.extract_metadata(text, fields)

            # Enhance with structured metadata if provided
            if structured_metadata:
                # Use structured metadata for fields not extracted
                for key, value in structured_metadata.items():
                    if key not in metadata or not metadata[key]:
                        metadata[key] = value

            return {"metadata": metadata, "success": True}
        except Exception as e:
            logger.error(f"Error extracting metadata: {e}")
            return {"metadata": {}, "success": False, "error": str(e)}


def create_app(model_name: str = "distilbert-base-uncased", device: str = "cpu"):
    """
    Create a Flask app for the inference model server.

    Args:
        model_name: Name or path of the pre-trained model
        device: Device to run the model on ('cpu' or 'cuda')

    Returns:
        Flask app
    """
    from flask import Flask, request, jsonify

    app = Flask(__name__)
    server = InferenceModelServer(model_name=model_name, device=device)

    @app.route("/extract", methods=["POST"])
    def extract():
        # Guard against missing or empty JSON bodies
        data = request.json or {}
        text = data.get("text", "")
        structured_metadata = data.get("structured_metadata", {})
        fields = data.get("fields", [])
        result = server.extract_metadata(text, structured_metadata, fields)
        return jsonify(result)

    @app.route("/health", methods=["GET"])
    def health():
        return jsonify({"status": "healthy"})

    return app
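
# Illustrative request sketch (not part of the original module), assuming the app
# is served on localhost:5000 with the /extract route registered above:
#
#     curl -X POST http://localhost:5000/extract \
#          -H "Content-Type: application/json" \
#          -d '{"text": "# My Model\n## License\nLicense: MIT", "fields": ["license"]}'
#
# The response is the envelope built by InferenceModelServer.extract_metadata,
# e.g. {"metadata": {"license": "MIT"}, "success": true}.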


def main():
    """Main entry point for the inference model server."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Start the inference model server for AIBOM metadata extraction."
    )
    parser.add_argument(
        "--model",
        help="Name or path of the pre-trained model",
        default="distilbert-base-uncased",
    )
    parser.add_argument(
        "--device",
        help="Device to run the model on ('cpu' or 'cuda')",
        choices=["cpu", "cuda"],
        default="cpu",
    )
    parser.add_argument(
        "--host",
        help="Host to bind the server to",
        default="0.0.0.0",
    )
    parser.add_argument(
        "--port",
        help="Port to bind the server to",
        type=int,
        default=5000,
    )
    parser.add_argument(
        "--debug",
        help="Enable debug mode",
        action="store_true",
    )
    args = parser.parse_args()

    # Create and run the app
    app = create_app(model_name=args.model, device=args.device)
    app.run(host=args.host, port=args.port, debug=args.debug)


if __name__ == "__main__":
    main()