File size: 676 Bytes
8474315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from src.config.config import setup_logging
from src.pipeline import NYCDataLoader, VanillaLSTM, Transformer, VAE, AnomalyDetector


def inference():

    seq_length = 48

    setup_logging()

    # Load the preprocessed data
    data_loader = NYCDataLoader(batch_size=32)
    train_loader, _, test_loader = data_loader.load_data()

    # Get the true anomalies
    true_anomalies = data_loader.get_true_anomalies()

    # Initialize the AnomalyDetector
    detector = AnomalyDetector()

    # Load the trained models
    detector.load_data(test_loader=test_loader)
    detector.load_trained_model("transformer_model.pth", model_type="transformer")