"""
Inference model integration for extracting metadata from unstructured text.
"""
import json | |
import logging | |
import re | |
import requests | |
from typing import Dict, List, Optional, Any, Union | |
logger = logging.getLogger(__name__) | |
class InferenceModelClient:
    """
    Client for interacting with the inference model service to extract
    metadata from unstructured text in model cards.
    """

    def __init__(
        self,
        inference_url: str,
        timeout: int = 30,
        max_retries: int = 3,
    ):
        """
        Initialize the inference model client.

        Args:
            inference_url: URL of the inference model service
            timeout: Request timeout in seconds
            max_retries: Maximum number of retries for failed requests
        """
        self.inference_url = inference_url
        self.timeout = timeout
        self.max_retries = max_retries

    def extract_metadata(
        self,
        model_card_text: str,
        structured_metadata: Optional[Dict[str, Any]] = None,
        fields: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Extract metadata from unstructured text using the inference model.

        Args:
            model_card_text: The text content of the model card
            structured_metadata: Optional structured metadata to provide context
            fields: Optional list of specific fields to extract

        Returns:
            Extracted metadata as a dictionary; empty dict when no inference
            URL is configured or the request ultimately fails.
        """
        if not self.inference_url:
            logger.warning("No inference model URL provided, skipping extraction")
            return {}

        # Prepare the request payload
        payload = {
            "text": model_card_text,
            "structured_metadata": structured_metadata or {},
            "fields": fields or [],
        }

        # Make the request to the inference model; extraction is best-effort,
        # so any failure degrades to an empty result for the caller.
        try:
            response = self._make_request(payload)
            return response.get("metadata", {})
        except Exception as e:
            logger.error(f"Error extracting metadata with inference model: {e}")
            return {}

    def _make_request(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Make a request to the inference model service, retrying on failure.

        Args:
            payload: Request payload

        Returns:
            Parsed JSON response from the inference model

        Raises:
            Exception: If the request fails after max_retries attempts
        """
        headers = {"Content-Type": "application/json"}
        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    self.inference_url,
                    headers=headers,
                    json=payload,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                # A 2xx response with an invalid-JSON body raises ValueError
                # (json.JSONDecodeError is a subclass), which previously
                # escaped the retry loop unlogged; treat it as a failed
                # attempt like any transport error.
                return response.json()
            except (requests.exceptions.RequestException, ValueError) as e:
                logger.warning(f"Request failed (attempt {attempt+1}/{self.max_retries}): {e}")
                if attempt == self.max_retries - 1:
                    raise
        # Only reachable when max_retries <= 0 (the loop body never runs);
        # otherwise the final failed attempt re-raises above.
        raise RuntimeError("Failed to make request to inference model")
class FallbackExtractor:
    """
    Fallback extractor for extracting metadata using regex and heuristics
    when the inference model is not available or fails.
    """

    # Metric key -> pattern matching "name[:/ of / is ]<number>".
    # Declared once at class level so the four metric extractions share
    # one loop instead of four copy-pasted search blocks.
    _METRIC_PATTERNS = {
        "accuracy": r"accuracy(?:\s*:\s*|\s+of\s+|\s+is\s+)(\d+(?:\.\d+)?)\s*%?",
        "f1": r"f1(?:\s*[\-_]?score)?(?:\s*:\s*|\s+of\s+|\s+is\s+)(\d+(?:\.\d+)?)",
        "precision": r"precision(?:\s*:\s*|\s+of\s+|\s+is\s+)(\d+(?:\.\d+)?)",
        "recall": r"recall(?:\s*:\s*|\s+of\s+|\s+is\s+)(\d+(?:\.\d+)?)",
    }

    def extract_metadata(
        self,
        model_card_text: str,
        structured_metadata: Optional[Dict[str, Any]] = None,
        fields: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Extract metadata using regex and heuristics.

        Args:
            model_card_text: The text content of the model card
            structured_metadata: Accepted for interface parity with the
                inference client; currently unused by the heuristics
            fields: Accepted for interface parity; currently unused — all
                supported fields are always extracted

        Returns:
            Extracted metadata as a dictionary
        """
        metadata: Dict[str, Any] = {}
        metadata.update(self._extract_model_parameters(model_card_text))
        metadata.update(self._extract_considerations(model_card_text))
        metadata.update(self._extract_datasets(model_card_text))
        metadata.update(self._extract_evaluation_results(model_card_text))
        return metadata

    def _extract_model_parameters(self, text: str) -> Dict[str, Any]:
        """Extract model architecture and parameter count from text."""
        params: Dict[str, Any] = {}

        # Architecture, e.g. "Model type: BERT" or "based on the GPT-2 architecture".
        architecture_patterns = [
            r"(?:model|architecture)(?:\s+type)?(?:\s*:\s*|\s+is\s+)([A-Za-z0-9\-]+)",
            r"based\s+on\s+(?:the\s+)?([A-Za-z0-9\-]+)(?:\s+architecture)?",
        ]
        for pattern in architecture_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                params["architecture"] = match.group(1).strip()
                break

        # Parameter count, e.g. "110 million parameters" or "parameters: 7B".
        param_patterns = [
            r"(\d+(?:\.\d+)?)\s*(?:B|M|K)?\s*(?:billion|million|thousand)?\s*parameters",
            r"parameters\s*:\s*(\d+(?:\.\d+)?)\s*(?:B|M|K)?",
        ]
        for pattern in param_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                params["parameters"] = match.group(1).strip()
                # TODO: Normalize to a standard unit
                break

        return {"model_parameters": params} if params else {}

    def _extract_considerations(self, text: str) -> Dict[str, Any]:
        """Extract limitations and ethical considerations sections from text."""
        considerations: Dict[str, Any] = {}

        limitations_section = self._extract_section(text, ["limitations", "limits", "shortcomings"])
        if limitations_section:
            considerations["limitations"] = limitations_section

        ethics_section = self._extract_section(
            text, ["ethical considerations", "ethics", "bias", "fairness", "risks"]
        )
        if ethics_section:
            considerations["ethical_considerations"] = ethics_section

        return {"considerations": considerations} if considerations else {}

    def _extract_datasets(self, text: str) -> Dict[str, Any]:
        """Extract dataset mentions from text."""
        datasets: List[str] = []

        dataset_patterns = [
            r"trained\s+on\s+(?:the\s+)?([A-Za-z0-9\-\s]+)(?:\s+dataset)?",
            r"dataset(?:\s*:\s*|\s+is\s+)([A-Za-z0-9\-\s]+)",
            r"using\s+(?:the\s+)?([A-Za-z0-9\-\s]+)(?:\s+dataset)",
        ]
        for pattern in dataset_patterns:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                dataset = match.group(1).strip()
                if dataset and dataset.lower() not in ["this", "these", "those"]:
                    datasets.append(dataset)

        # De-duplicate while preserving first-seen order; list(set(...))
        # would make the result order nondeterministic across runs.
        unique_datasets = list(dict.fromkeys(datasets))
        return {"datasets": unique_datasets} if unique_datasets else {}

    def _extract_evaluation_results(self, text: str) -> Dict[str, Any]:
        """Extract numeric evaluation metrics (accuracy, f1, precision, recall)."""
        results: Dict[str, float] = {}
        for key, pattern in self._METRIC_PATTERNS.items():
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                results[key] = float(match.group(1))
        return {"evaluation_results": results} if results else {}

    def _extract_section(self, text: str, section_names: List[str]) -> Optional[str]:
        """
        Extract a section from the text based on section names.

        Args:
            text: The text to extract from
            section_names: Possible names for the section

        Returns:
            The extracted section text, or None if not found
        """
        # Match markdown-style (#), numbered (1.), or "UPPERCASE:" headers
        # followed by one of the section names. Names are escaped so any
        # regex metacharacters in them are matched literally.
        header_pattern = r"(?:^|\n)(?:#+\s*|[0-9]+\.\s*|[A-Z\s]+:\s*)(?:{})(?:\s*:)?(?:\s*\n|\s*$)".format(
            "|".join(re.escape(name) for name in section_names)
        )
        headers = list(re.finditer(header_pattern, text, re.IGNORECASE))
        for i, match in enumerate(headers):
            start = match.end()
            # The section runs to the next matching header, or end of text.
            if i < len(headers) - 1:
                end = headers[i + 1].start()
            else:
                end = len(text)
            section = text[start:end].strip()
            if section:
                return section
        return None
class MetadataExtractor:
    """
    Metadata extractor that combines inference model and fallback extraction.
    """

    def __init__(
        self,
        inference_url: Optional[str] = None,
        use_inference: bool = True,
    ):
        """
        Initialize the metadata extractor.

        Args:
            inference_url: URL of the inference model service
            use_inference: Whether to use the inference model
        """
        # Inference is only usable when a URL was actually supplied.
        self.use_inference = use_inference and inference_url is not None
        self.inference_client = InferenceModelClient(inference_url) if self.use_inference else None
        self.fallback_extractor = FallbackExtractor()

    def extract_metadata(
        self,
        model_card_text: str,
        structured_metadata: Optional[Dict[str, Any]] = None,
        fields: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Extract metadata from model card text.

        Args:
            model_card_text: The text content of the model card
            structured_metadata: Optional structured metadata to provide context
            fields: Optional list of specific fields to extract

        Returns:
            Extracted metadata as a dictionary
        """
        metadata: Dict[str, Any] = {}

        # Try inference model first if enabled; failures fall through to
        # the heuristic fallback below.
        if self.use_inference and self.inference_client:
            try:
                inference_metadata = self.inference_client.extract_metadata(
                    model_card_text, structured_metadata, fields
                )
                metadata.update(inference_metadata)
            except Exception as e:
                logger.error(f"Inference model extraction failed: {e}")

        # Use the fallback extractor when inference produced nothing, or when
        # specific fields were requested and some are still missing.
        if not metadata or (fields and not all(field in metadata for field in fields)):
            missing_fields = (
                [field for field in fields if field not in metadata] if fields else None
            )
            fallback_metadata = self.fallback_extractor.extract_metadata(
                model_card_text, structured_metadata, missing_fields
            )
            # Fallback values never overwrite non-empty inference results.
            for key, value in fallback_metadata.items():
                if key not in metadata or not metadata[key]:
                    metadata[key] = value

        return metadata