# mcp-deepfake-forensics/agents/ensemble_team.py
# Last change: commit b26b103 "hot-fix: memory (#2)" by LPX55 (verified).
import logging
import time
import torch
import psutil # Ensure psutil is imported here as well
import GPUtil
from datetime import datetime, timedelta
import gc # Import garbage collector
# Module-level logger used by every agent class defined in this file.
logger = logging.getLogger(__name__)
class EnsembleMonitorAgent:
    """Accumulate per-model prediction statistics and record simple alerts.

    For every prediction reported through ``monitor_prediction`` the agent
    updates running totals (prediction count, summed confidence, summed
    inference time) keyed by model id. It also records an alert message when
    the inference time exceeds ``slow_inference_threshold`` or the confidence
    falls below ``low_confidence_threshold``.

    Args:
        slow_inference_threshold: Inference time in seconds above which a
            "slow inference" alert is recorded. Defaults to 5.0.
        low_confidence_threshold: Confidence below which a "low confidence"
            alert is recorded. Defaults to 0.5.
        max_alerts: Maximum number of alert messages retained in
            ``self.alerts``; the oldest are dropped first. ``None`` retains
            every alert without limit. Defaults to 1000.

    Raises:
        ValueError: If ``max_alerts`` is negative.
    """

    def __init__(self, *, slow_inference_threshold=5.0, low_confidence_threshold=0.5, max_alerts=1000):
        logger.info("Initializing EnsembleMonitorAgent.")
        if max_alerts is not None and max_alerts < 0:
            raise ValueError(f"max_alerts must be non-negative or None, got {max_alerts!r}")
        self.slow_inference_threshold = slow_inference_threshold
        self.low_confidence_threshold = low_confidence_threshold
        self.max_alerts = max_alerts
        self.performance_metrics = {}
        self.alerts = []

    def _record_alert(self, message):
        """Store *message* in ``self.alerts`` (bounded by ``max_alerts``) and log it."""
        self.alerts.append(message)
        if self.max_alerts is not None and len(self.alerts) > self.max_alerts:
            # Trim in place so any external reference to the list stays valid,
            # and so a long-running process does not accumulate alerts forever.
            del self.alerts[: len(self.alerts) - self.max_alerts]
        logger.warning(message)

    def monitor_prediction(self, model_id, prediction_label, confidence_score, inference_time):
        """Record one prediction made by *model_id* and raise alerts if needed.

        Args:
            model_id: Key under which statistics are accumulated.
            prediction_label: Predicted label; only logged.
            confidence_score: Numeric confidence; added to the running total
                and compared with ``low_confidence_threshold``.
            inference_time: Inference duration in seconds; added to the
                running total and compared with ``slow_inference_threshold``.
        """
        logger.info(
            "Monitoring prediction for model '%s'. Label: %s, Confidence: %.2f, Time: %.4fs",
            model_id,
            prediction_label,
            confidence_score,
            inference_time,
        )
        if model_id not in self.performance_metrics:
            self.performance_metrics[model_id] = {
                "total_predictions": 0,
                # Would require ground truth, which is not available here; never incremented.
                "correct_predictions": 0,
                "total_confidence": 0.0,
                "total_inference_time": 0.0,
            }
        metrics = self.performance_metrics[model_id]
        metrics["total_predictions"] += 1
        metrics["total_confidence"] += confidence_score
        metrics["total_inference_time"] += inference_time

        if inference_time > self.slow_inference_threshold:
            self._record_alert(
                f"ALERT: Model '{model_id}' inference time exceeded "
                f"{self.slow_inference_threshold}s: {inference_time:.4f}s"
            )
        if confidence_score < self.low_confidence_threshold:
            self._record_alert(
                f"ALERT: Model '{model_id}' returned low confidence: {confidence_score:.2f}"
            )
        logger.info("Updated metrics for '%s': %s", model_id, metrics)

    def get_performance_summary(self):
        """Return per-model averages computed from the accumulated totals.

        Returns:
            dict mapping each model id to a dict with ``avg_confidence``,
            ``avg_inference_time`` and ``total_predictions``. Averages are 0
            for a model with no recorded predictions.
        """
        logger.info("Generating performance summary for all models.")
        summary = {}
        for model_id, metrics in self.performance_metrics.items():
            count = metrics["total_predictions"]
            summary[model_id] = {
                "avg_confidence": metrics["total_confidence"] / count if count > 0 else 0,
                "avg_inference_time": metrics["total_inference_time"] / count if count > 0 else 0,
                "total_predictions": count,
            }
        logger.info("Performance summary: %s", summary)
        return summary
class WeightOptimizationAgent:
    """Hold a weight manager and a 24-hour rolling log of final predictions.

    No optimization is performed yet: ``analyze_performance`` only records
    the prediction and trims the log to ``performance_window``.
    """

    def __init__(self, weight_manager):
        logger.info("Initializing WeightOptimizationAgent.")
        # Entries older than this are discarded each time a prediction is analyzed.
        self.performance_window = timedelta(hours=24)
        self.prediction_history = []
        self.weight_manager = weight_manager

    def analyze_performance(self, final_prediction, ground_truth=None):
        """Append a timestamped prediction record and drop records outside the window.

        Args:
            final_prediction: The ensemble's final prediction; stored as-is.
            ground_truth: Optional true label; stored as-is. It is often
                unavailable at prediction time, hence the ``None`` default.
        """
        logger.info(f"Analyzing performance. Final prediction: {final_prediction}, Ground truth: {ground_truth}")
        now = datetime.now()
        record = {
            "timestamp": now,
            "final_prediction": final_prediction,
            "ground_truth": ground_truth,
        }
        self.prediction_history.append(record)

        def still_fresh(entry):
            # Keep only entries younger than the configured window.
            return now - entry["timestamp"] < self.performance_window

        self.prediction_history = list(filter(still_fresh, self.prediction_history))
        logger.info(f"Prediction history length: {len(self.prediction_history)}")
        # Placeholder: real weight-optimization logic would go here.
class SystemHealthAgent:
    """Sample CPU, RAM and GPU utilization into ``self.health_metrics``.

    Args:
        memory_threshold: RAM usage percentage above which a garbage
            collection is triggered and a second memory reading is stored
            under ``"memory_usage_after_gc"``. Defaults to 90.
        cpu_interval: Value passed as ``interval`` to ``psutil.cpu_percent``;
            ``monitor_system_health`` blocks for this many seconds.
            Defaults to 1.
    """

    def __init__(self, *, memory_threshold=90, cpu_interval=1):
        logger.info("Initializing SystemHealthAgent.")
        self.memory_threshold = memory_threshold
        self.cpu_interval = cpu_interval
        self.health_metrics = {
            "cpu_percent": 0,
            "memory_usage": {"total": 0, "available": 0, "percent": 0},
            "gpu_utilization": [],
        }

    @staticmethod
    def _memory_snapshot(mem):
        """Return the tracked fields of a ``psutil.virtual_memory()`` result."""
        return {"total": mem.total, "available": mem.available, "percent": mem.percent}

    @staticmethod
    def _collect_gpu_info():
        """Return one dict per GPU reported by ``GPUtil.getGPUs()``.

        Best-effort: on any failure the error is logged and an
        ``{"error": <message>}`` entry is appended after any GPUs already
        collected, instead of raising.
        """
        gpu_info = []
        try:
            for gpu in GPUtil.getGPUs():
                gpu_info.append({
                    "id": gpu.id,
                    "name": gpu.name,
                    "load": gpu.load,
                    "memoryUtil": gpu.memoryUtil,
                    "memoryTotal": gpu.memoryTotal,
                    "memoryUsed": gpu.memoryUsed,
                })
        except Exception as e:  # GPU probing must not break health monitoring.
            logger.warning("Could not retrieve GPU information: %s", e)
            gpu_info.append({"error": str(e)})
        return gpu_info

    def monitor_system_health(self):
        """Refresh ``self.health_metrics`` with current CPU, memory and GPU readings.

        Blocks for ``cpu_interval`` seconds while psutil samples CPU usage.
        When memory usage exceeds ``memory_threshold`` percent, runs
        ``gc.collect()`` and stores a second memory reading under
        ``"memory_usage_after_gc"``. Otherwise that key is removed, so a
        reading from an earlier call is not reported as current.
        """
        logger.info("Monitoring system health...")
        self.health_metrics["cpu_percent"] = psutil.cpu_percent(interval=self.cpu_interval)

        mem = psutil.virtual_memory()
        self.health_metrics["memory_usage"] = self._memory_snapshot(mem)
        if mem.percent > self.memory_threshold:
            logger.warning(
                "CRITICAL: System memory usage is at %s%%. Attempting to clear memory cache...",
                mem.percent,
            )
            # gc.collect() only reclaims unreachable Python objects (e.g. reference
            # cycles); freed memory may not be returned to the OS, so the second
            # reading can be unchanged.
            gc.collect()
            logger.info("Garbage collection triggered. Re-checking memory usage...")
            mem_after_gc = psutil.virtual_memory()
            self.health_metrics["memory_usage_after_gc"] = self._memory_snapshot(mem_after_gc)
            logger.info("Memory usage after GC: %s%%", mem_after_gc.percent)
        else:
            # Drop a stale post-GC reading left over from an earlier high-memory check.
            self.health_metrics.pop("memory_usage_after_gc", None)

        gpu_info = self._collect_gpu_info()
        self.health_metrics["gpu_utilization"] = gpu_info
        logger.info(
            "System health metrics: CPU: %s%%, Memory: %s%%, GPU: %s",
            self.health_metrics["cpu_percent"],
            self.health_metrics["memory_usage"]["percent"],
            gpu_info,
        )