# src/aibom_generator/inference_model.py
"""
Inference model implementation for metadata extraction from model cards.
This module provides a fine-tuned model for extracting structured metadata
from unstructured text in Hugging Face model cards.
"""
import logging
import re
from typing import Any, Dict, List, Optional

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    T5Tokenizer,
    pipeline,
)
logger = logging.getLogger(__name__)
class ModelCardExtractor:
"""
Fine-tuned model for extracting metadata from model card text.
"""
def __init__(
self,
model_name: str = "distilbert-base-uncased",
device: str = "cpu",
max_length: int = 512,
cache_dir: Optional[str] = None,
):
"""
Initialize the model card extractor.
Args:
model_name: Name or path of the pre-trained model
device: Device to run the model on ('cpu' or 'cuda')
max_length: Maximum sequence length for tokenization
cache_dir: Directory to cache models
"""
        self.model_name = model_name
        self.device = device
        self.max_length = max_length
        self.cache_dir = cache_dir

        # Placeholders populated by _load_models()
        self.tokenizer = None
        self.model = None
        self.section_classifier = None
        self.extraction_tokenizer = None
        self.extraction_model = None

        # Load models
        self._load_models()
def _load_models(self):
"""Load the required models for extraction."""
try:
# Load section classifier
logger.info(f"Loading section classifier model: {self.model_name}")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
cache_dir=self.cache_dir,
)
self.model = AutoModelForSequenceClassification.from_pretrained(
self.model_name,
cache_dir=self.cache_dir,
)
self.model.to(self.device)
# Create section classification pipeline
self.section_classifier = pipeline(
"text-classification",
model=self.model,
tokenizer=self.tokenizer,
device=0 if self.device == "cuda" else -1,
)
# For demonstration purposes, we'll use a T5-based model for extraction
# In a real implementation, this would be a fine-tuned model specific to the task
logger.info("Loading metadata extraction model")
extraction_model_name = "t5-small" # Placeholder for fine-tuned model
self.extraction_tokenizer = T5Tokenizer.from_pretrained(
extraction_model_name,
cache_dir=self.cache_dir,
)
self.extraction_model = AutoModelForSeq2SeqLM.from_pretrained(
extraction_model_name,
cache_dir=self.cache_dir,
)
self.extraction_model.to(self.device)
logger.info("Models loaded successfully")
except Exception as e:
logger.error(f"Error loading models: {e}")
raise
def extract_metadata(
self,
text: str,
fields: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Extract metadata from model card text.
Args:
text: The model card text
fields: Optional list of specific fields to extract
Returns:
Extracted metadata as a dictionary
"""
# Split text into sections
sections = self._split_into_sections(text)
# Classify sections
classified_sections = self._classify_sections(sections)
# Extract metadata from each section
metadata = {}
for section_type, section_text in classified_sections.items():
if fields and section_type not in fields:
continue
extracted = self._extract_from_section(section_type, section_text)
if extracted:
metadata[section_type] = extracted
return metadata
    def _split_into_sections(self, text: str) -> List[Dict[str, Any]]:
"""
Split the model card text into sections.
Args:
text: The model card text
Returns:
List of sections with title and content
"""
# Simple section splitting based on headers
# In a real implementation, this would be more sophisticated
sections = []
# Match markdown headers (# Header, ## Header, etc.)
header_pattern = r"(?:^|\n)(#+)\s+(.*?)(?:\n|$)"
# Find all headers
headers = list(re.finditer(header_pattern, text))
for i, match in enumerate(headers):
header_level = len(match.group(1))
header_text = match.group(2).strip()
start = match.end()
# Find the end of the section (next header or end of text)
if i < len(headers) - 1:
end = headers[i + 1].start()
else:
end = len(text)
# Extract the section content
content = text[start:end].strip()
sections.append({
"title": header_text,
"level": header_level,
"content": content,
})
# If no sections were found, treat the entire text as one section
if not sections:
sections.append({
"title": "Main",
"level": 1,
"content": text.strip(),
})
return sections
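    # A small illustration of the splitting above (a sketch; the output
    # depends only on the header regex, not on any model):
    #
    #     extractor._split_into_sections("# Intro\nHello\n## Data\nWikipedia")
    #     # -> [{"title": "Intro", "level": 1, "content": "Hello"},
    #     #     {"title": "Data", "level": 2, "content": "Wikipedia"}]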
    def _classify_sections(self, sections: List[Dict[str, Any]]) -> Dict[str, str]:
"""
Classify sections into metadata categories.
Args:
sections: List of sections with title and content
Returns:
Dictionary mapping section types to section content
"""
classified = {}
# Map common section titles to metadata fields
title_mappings = {
"model description": "description",
"description": "description",
"model details": "model_parameters",
"model architecture": "model_parameters",
"parameters": "model_parameters",
"training data": "datasets",
"dataset": "datasets",
"datasets": "datasets",
"training": "training_procedure",
"evaluation": "evaluation_results",
"results": "evaluation_results",
"performance": "evaluation_results",
"metrics": "evaluation_results",
"limitations": "limitations",
"biases": "ethical_considerations",
"bias": "ethical_considerations",
"ethical considerations": "ethical_considerations",
"ethics": "ethical_considerations",
"risks": "ethical_considerations",
"license": "license",
"citation": "citation",
"references": "citation",
}
for section in sections:
title = section["title"].lower()
content = section["content"]
# Check for direct title matches
matched = False
for key, value in title_mappings.items():
if key in title:
if value not in classified:
classified[value] = content
else:
classified[value] += "\n\n" + content
matched = True
break
# If no match by title, use the classifier
if not matched and self.section_classifier and len(content.split()) > 5:
try:
# This is a placeholder for actual classification
# In a real implementation, this would use the fine-tuned classifier
section_type = self._classify_text(content)
if section_type and section_type not in classified:
classified[section_type] = content
elif section_type:
classified[section_type] += "\n\n" + content
except Exception as e:
logger.error(f"Error classifying section: {e}")
return classified
def _classify_text(self, text: str) -> Optional[str]:
"""
Classify text into a metadata category.
Args:
text: The text to classify
Returns:
Metadata category or None if classification fails
"""
# This is a placeholder for actual classification
# In a real implementation, this would use the fine-tuned classifier
# Simple keyword-based classification for demonstration
keywords = {
"description": ["is a", "this model", "based on", "pretrained"],
"model_parameters": ["parameters", "layers", "hidden", "dimension", "architecture"],
"datasets": ["dataset", "corpus", "trained on", "fine-tuned on"],
"evaluation_results": ["accuracy", "f1", "precision", "recall", "performance"],
"limitations": ["limitation", "limited", "does not", "cannot", "fails to"],
"ethical_considerations": ["bias", "ethical", "fairness", "gender", "race"],
}
# Count keyword occurrences
counts = {category: 0 for category in keywords}
for category, words in keywords.items():
for word in words:
counts[category] += len(re.findall(r'\b' + re.escape(word) + r'\b', text.lower()))
# Return the category with the most keyword matches
if counts:
max_category = max(counts.items(), key=lambda x: x[1])
if max_category[1] > 0:
return max_category[0]
return None
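    # Worked example of the keyword scoring above: in "Accuracy: 92.5 and
    # F1: 90.1 on the test set", the evaluation_results keywords "accuracy"
    # and "f1" each match once while every other category scores zero, so
    # _classify_text returns "evaluation_results".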
def _extract_from_section(self, section_type: str, text: str) -> Any:
"""
Extract structured metadata from a section.
Args:
section_type: The type of section
text: The section text
Returns:
Extracted metadata
"""
# This is a placeholder for actual extraction
# In a real implementation, this would use the fine-tuned extraction model
if section_type == "description":
# Simply return the text for description
return text.strip()
elif section_type == "model_parameters":
# Extract model parameters using regex
params = {}
# Extract architecture
arch_match = re.search(r'(?:architecture|model type|based on)[:\s]+([A-Za-z0-9\-]+)', text, re.IGNORECASE)
if arch_match:
params["architecture"] = arch_match.group(1).strip()
            # Extract parameter count, keeping the scale suffix so that
            # "1.5B parameters" is recorded as "1.5B" rather than "1.5"
            param_match = re.search(
                r'(\d+(?:\.\d+)?\s*(?:B|M|K|billion|million|thousand)?)\s*parameters',
                text,
                re.IGNORECASE,
            )
            if param_match:
                params["parameter_count"] = param_match.group(1).strip()
return params
elif section_type == "datasets":
# Extract dataset names
datasets = []
dataset_patterns = [
r'trained on\s+(?:the\s+)?([A-Za-z0-9\-\s]+)(?:\s+dataset)?',
r'dataset[:\s]+([A-Za-z0-9\-\s]+)',
r'using\s+(?:the\s+)?([A-Za-z0-9\-\s]+)(?:\s+dataset)',
]
            for pattern in dataset_patterns:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    dataset = match.group(1).strip()
                    # The greedy character class also swallows a trailing
                    # "dataset" token; trim it so "C4 dataset" becomes "C4"
                    dataset = re.sub(r'\s+dataset$', '', dataset, flags=re.IGNORECASE)
                    if dataset and dataset.lower() not in ["this", "these", "those"]:
                        datasets.append(dataset)
            # Deduplicate while preserving first-seen order
            return list(dict.fromkeys(datasets))
elif section_type == "evaluation_results":
# Extract evaluation metrics
results = {}
# Extract accuracy
acc_match = re.search(r'accuracy[:\s]+(\d+(?:\.\d+)?)\s*%?', text, re.IGNORECASE)
if acc_match:
results["accuracy"] = float(acc_match.group(1))
# Extract F1 score
f1_match = re.search(r'f1(?:\s*[\-_]?score)?[:\s]+(\d+(?:\.\d+)?)', text, re.IGNORECASE)
if f1_match:
results["f1"] = float(f1_match.group(1))
# Extract precision
prec_match = re.search(r'precision[:\s]+(\d+(?:\.\d+)?)', text, re.IGNORECASE)
if prec_match:
results["precision"] = float(prec_match.group(1))
# Extract recall
recall_match = re.search(r'recall[:\s]+(\d+(?:\.\d+)?)', text, re.IGNORECASE)
if recall_match:
results["recall"] = float(recall_match.group(1))
return results
elif section_type == "limitations":
# Simply return the text for limitations
return text.strip()
elif section_type == "ethical_considerations":
# Simply return the text for ethical considerations
return text.strip()
elif section_type == "license":
# Extract license information
license_match = re.search(r'(?:license|licensing)[:\s]+([A-Za-z0-9\-\s]+)', text, re.IGNORECASE)
if license_match:
return license_match.group(1).strip()
return text.strip()
elif section_type == "citation":
# Simply return the text for citation
return text.strip()
# Default case
return text.strip()
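# Minimal usage sketch for the extractor (assumes the placeholder models can
# be downloaded on first use; the output depends on the card text):
#
#     extractor = ModelCardExtractor(device="cpu")
#     card = "# Model Description\nThis model is a distilled BERT variant."
#     extractor.extract_metadata(card)
#     # -> {"description": "This model is a distilled BERT variant."}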
class InferenceModelServer:
"""
Server for the inference model.
This class provides a server for the inference model that can be deployed
as a standalone service with a REST API.
"""
def __init__(
self,
model_name: str = "distilbert-base-uncased",
device: str = "cpu",
cache_dir: Optional[str] = None,
):
"""
Initialize the inference model server.
Args:
model_name: Name or path of the pre-trained model
device: Device to run the model on ('cpu' or 'cuda')
cache_dir: Directory to cache models
"""
self.extractor = ModelCardExtractor(
model_name=model_name,
device=device,
cache_dir=cache_dir,
)
def extract_metadata(
self,
text: str,
structured_metadata: Optional[Dict[str, Any]] = None,
fields: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
Extract metadata from model card text.
Args:
text: The model card text
structured_metadata: Optional structured metadata to provide context
fields: Optional list of specific fields to extract
Returns:
Extracted metadata as a dictionary
"""
try:
# Extract metadata using the extractor
metadata = self.extractor.extract_metadata(text, fields)
# Enhance with structured metadata if provided
if structured_metadata:
# Use structured metadata for fields not extracted
for key, value in structured_metadata.items():
if key not in metadata or not metadata[key]:
metadata[key] = value
return {"metadata": metadata, "success": True}
except Exception as e:
logger.error(f"Error extracting metadata: {e}")
return {"metadata": {}, "success": False, "error": str(e)}
def create_app(model_name: str = "distilbert-base-uncased", device: str = "cpu"):
"""
Create a Flask app for the inference model server.
Args:
model_name: Name or path of the pre-trained model
device: Device to run the model on ('cpu' or 'cuda')
Returns:
Flask app
"""
from flask import Flask, request, jsonify
app = Flask(__name__)
server = InferenceModelServer(model_name=model_name, device=device)
@app.route("/extract", methods=["POST"])
def extract():
        # Tolerate a missing or malformed JSON body instead of raising
        data = request.get_json(silent=True) or {}
text = data.get("text", "")
structured_metadata = data.get("structured_metadata", {})
fields = data.get("fields", [])
result = server.extract_metadata(text, structured_metadata, fields)
return jsonify(result)
@app.route("/health", methods=["GET"])
def health():
return jsonify({"status": "healthy"})
return app
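# Example request against the app above (a sketch; the URL assumes the
# default host and port from main(), and the response key order may vary):
#
#     curl -X POST http://localhost:5000/extract \
#          -H "Content-Type: application/json" \
#          -d '{"text": "# Training Data\nTrained on the C4 dataset."}'
#     # -> {"metadata": {"datasets": ["C4"]}, "success": true}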
def main():
"""Main entry point for the inference model server."""
import argparse
parser = argparse.ArgumentParser(
description="Start the inference model server for AIBOM metadata extraction."
)
parser.add_argument(
"--model",
help="Name or path of the pre-trained model",
default="distilbert-base-uncased",
)
parser.add_argument(
"--device",
help="Device to run the model on ('cpu' or 'cuda')",
choices=["cpu", "cuda"],
default="cpu",
)
parser.add_argument(
"--host",
help="Host to bind the server to",
default="0.0.0.0",
)
parser.add_argument(
"--port",
help="Port to bind the server to",
type=int,
default=5000,
)
parser.add_argument(
"--debug",
help="Enable debug mode",
action="store_true",
)
args = parser.parse_args()
# Create and run the app
app = create_app(model_name=args.model, device=args.device)
app.run(host=args.host, port=args.port, debug=args.debug)
if __name__ == "__main__":
main()