Upload app.py
app.py
ADDED
@@ -0,0 +1,819 @@
#!/usr/bin/env python3
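"""
Vessel trajectory inference web app (Flask + Flask-SocketIO).

Loads an LSTM-with-attention model, runs the inference pipeline adapted from
Final_inference_maginet.py on an uploaded CSV of vessel tracks, and streams
progress and results to the browser over SocketIO.
"""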

import os
import json
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from flask import Flask, render_template, request, jsonify, send_file
from flask_socketio import SocketIO, emit
import tempfile
import threading
from pathlib import Path
from werkzeug.utils import secure_filename

app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key-here'
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024  # 500MB max file size
socketio = SocketIO(app, cors_allowed_origins="*")

# Ensure upload directory exists
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# Global variables for progress tracking
current_progress = {'step': 'idle', 'progress': 0, 'details': ''}

########################################
#           MODEL DEFINITION           #
########################################

class LSTMWithAttentionWithResid(nn.Module):
    def __init__(self, in_dim, hidden_dim, forecast_horizon, n_layers=10, dropout=0.2):
        super(LSTMWithAttentionWithResid, self).__init__()
        self.hidden_dim = hidden_dim
        self.forecast_horizon = forecast_horizon

        # Embedding layer
        self.embedding = nn.Linear(in_dim, hidden_dim)

        # LSTM layers
        self.lstm = nn.LSTM(
            hidden_dim, hidden_dim, num_layers=n_layers, dropout=dropout, batch_first=True
        )

        # Layer normalization after residual connection
        self.layer_norm = nn.LayerNorm(hidden_dim)

        # Attention mechanism
        self.attention = nn.Linear(hidden_dim, hidden_dim)
        self.context_vector = nn.Linear(hidden_dim, 1, bias=False)  # Linear layer for scoring

        # Fully connected layer to map attention context to output
        self.fc = nn.Linear(hidden_dim, forecast_horizon * 2)

    def forward(self, x):
        # x: [batch_size, seq_len, in_dim]

        # Pass through embedding layer
        x_embed = self.embedding(x)  # [batch_size, seq_len, hidden_dim]

        # Pass through LSTM
        lstm_output, (hidden, cell) = self.lstm(x_embed)  # [batch_size, seq_len, hidden_dim]

        # Add residual connection (out-of-place)
        lstm_output = lstm_output + x_embed  # [batch_size, seq_len, hidden_dim]

        # Apply layer normalization
        lstm_output = self.layer_norm(lstm_output)  # [batch_size, seq_len, hidden_dim]

        # Compute attention scores
        attention_weights = torch.tanh(self.attention(lstm_output))  # [batch_size, seq_len, hidden_dim]
        attention_scores = self.context_vector(attention_weights).squeeze(-1)  # [batch_size, seq_len]

        # Apply softmax to normalize scores
        attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len]

        # Compute the context vector as a weighted sum of LSTM outputs
        context_vector = torch.bmm(
            attention_weights.unsqueeze(1), lstm_output
        )  # [batch_size, 1, hidden_dim]
        context_vector = context_vector.squeeze(1)  # [batch_size, hidden_dim]

        # Pass context vector through fully connected layer for forecasting
        output = self.fc(context_vector)  # [batch_size, forecast_horizon * 2]

        # Reshape output to match the expected shape
        output = output.view(-1, self.forecast_horizon, 2)  # [batch_size, forecast_horizon, 2]

        return output

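# Quick shape-check sketch (not executed by the app; uses the hyperparameters that
# run_inference_pipeline below assumes for the Atlantic weights):
#
#   _m = LSTMWithAttentionWithResid(in_dim=7, hidden_dim=250, forecast_horizon=1, n_layers=7)
#   _x = torch.randn(4, 12, 7)        # [batch, seq_len=12, in_dim=7]
#   assert _m(_x).shape == (4, 1, 2)  # [batch, forecast_horizon, 2]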
########################################
#          UTILITY FUNCTIONS           #
########################################

def update_progress(step, progress, details=""):
    """Update global progress state"""
    global current_progress
    current_progress = {
        'step': step,
        'progress': progress,
        'details': details
    }
    socketio.emit('progress_update', current_progress)

def create_sequences_grouped_by_segment_lat_long_veloc(df_scaled, seq_len=12, forecast_horizon=1, features_to_scale=None):
    """
    For each segment, creates overlapping sequences of length seq_len.
    Returns:
    - Xs: input sequences,
    - ys: target outputs (future latitude and longitude velocities),
    - segments: corresponding segment IDs,
    - last_positions: last known positions from each sequence.
    """
    update_progress('Creating sequences', 10, f'Processing {len(df_scaled)} data points...')

    Xs, ys, segments, last_positions = [], [], [], []

    if features_to_scale is None:
        # CRITICAL: Match the notebook's exact inference logic (segment first, removed before the model)
        features_to_scale = [
            "segment",                # Index 0 - will be removed before model
            "latitude_velocity_km",   # Index 1 -> 0 after segment removal
            "longitude_velocity_km",  # Index 2 -> 1 after segment removal
            "latitude_degrees",       # Index 3 -> 2 after segment removal
            "longitude_degrees",      # Index 4 -> 3 after segment removal
            "time_difference_hours",  # Index 5 -> 4 after segment removal
            "time_scalar"             # Index 6 -> 5 after segment removal
        ]

    # Verify all required features exist
    missing_features = [f for f in features_to_scale if f not in df_scaled.columns]
    if missing_features:
        raise ValueError(f"Missing required features: {missing_features}")

    grouped = df_scaled.groupby('segment')
    total_segments = len(grouped)

    for i, (segment_id, group) in enumerate(grouped):
        group = group.reset_index(drop=True)
        L = len(group)

        # Progress update
        if i % max(1, total_segments // 20) == 0:
            progress = 10 + (i / total_segments) * 30  # 10-40% range
            update_progress('Creating sequences', progress,
                            f'Processing segment {i+1}/{total_segments}')

        if L >= seq_len + forecast_horizon:
            for j in range(L - seq_len - forecast_horizon + 1):
                # Get sequence features
                seq = group.iloc[j:(j+seq_len)][features_to_scale].to_numpy()

                # Get future time scalar for the forecast horizon
                future_time = group['time_scalar'].iloc[j + seq_len + forecast_horizon - 1]
                future_time_feature = np.full((seq_len, 1), future_time)

                # Augment sequence with future time
                seq_aug = np.hstack((seq, future_time_feature))
                Xs.append(seq_aug)

                # Target: future velocity
                target = group[['latitude_velocity_km', 'longitude_velocity_km']].iloc[j + seq_len + forecast_horizon - 1].to_numpy()
                ys.append(target)

                segments.append(segment_id)

                # Last known position
                last_pos = group[['latitude_degrees', 'longitude_degrees']].iloc[j + seq_len - 1].to_numpy()
                last_positions.append(last_pos)

    return (np.array(Xs, dtype=np.float32),
            np.array(ys, dtype=np.float32),
            np.array(segments),
            np.array(last_positions, dtype=np.float32))

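# With the default arguments above, each element of Xs is a [seq_len=12, 8] window
# (the 7 listed features plus the appended future-time column), ys holds the [2]
# target velocities, and last_positions holds the [2] last known lat/lon per window.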
def load_normalization_params(json_path):
    """Load normalization parameters from JSON file"""
    with open(json_path, "r") as f:
        normalization_params = json.load(f)
    return normalization_params["feature_mins"], normalization_params["feature_maxs"]

def minmax_denormalize(scaled_series, feature_min, feature_max):
    """Denormalize data using min-max scaling"""
    return scaled_series * (feature_max - feature_min) + feature_min

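# Min-max scaling round trip used throughout the pipeline:
#   scaled   = (x - feature_min) / (feature_max - feature_min)
#   original = scaled * (feature_max - feature_min) + feature_min   # minmax_denormalize
# e.g. with feature_min=-100 and feature_max=100, a scaled value of 0.75 maps back to 50.0.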
########################################
#          INFERENCE PIPELINE          #
########################################

def run_inference_pipeline(csv_file_path, model_path, normalization_path):
    """Complete inference pipeline following Final_inference_maginet.py logic"""

    try:
        # Step 1: Load and validate data
        update_progress('Loading data', 5, 'Reading CSV file...')

        # Enhanced CSV parsing with error handling
        try:
            # Determine separator by reading the first line
            with open(csv_file_path, 'r') as f:
                first_line = f.readline()
                separator = ';' if ';' in first_line else ','

            # Try reading with detected separator
            df = pd.read_csv(csv_file_path, sep=separator, on_bad_lines='skip')
            update_progress('Loading data', 8, f'Loaded {len(df)} rows with separator "{separator}"')

            # Debug: print actual column names
            print(f"CSV COLUMNS FOUND: {list(df.columns)}")
            update_progress('Loading data', 8.5, f'Columns: {list(df.columns)}')

        except Exception as e:
            print(f"CSV PARSING ERROR: {e}")
            # Try alternative parsing methods
            try:
                df = pd.read_csv(csv_file_path, sep=',', on_bad_lines='skip')
                update_progress('Loading data', 8, f'Loaded {len(df)} rows with comma separator (fallback)')
                print(f"CSV COLUMNS FOUND (fallback): {list(df.columns)}")
            except Exception as e2:
                try:
                    df = pd.read_csv(csv_file_path, sep=';', on_bad_lines='skip')
                    update_progress('Loading data', 8, f'Loaded {len(df)} rows with semicolon separator (fallback)')
                    print(f"CSV COLUMNS FOUND (fallback): {list(df.columns)}")
                except Exception as e3:
                    raise ValueError(f"Could not parse CSV file. Tried multiple separators. Errors: {e}, {e2}, {e3}")

        # CRITICAL: Create time_scalar (was missing from the inference dataset!)
        if 'time_scalar' not in df.columns:
            if 'datetime' in df.columns:
                # Convert datetime to time_scalar (preferred method)
                df['datetime'] = pd.to_datetime(df['datetime'], errors='coerce')
                reference_date = pd.Timestamp('2023-01-01')
                df['time_scalar'] = ((df['datetime'] - reference_date) / pd.Timedelta(days=1)).round(8)
                update_progress('Loading data', 9, 'Created time_scalar from datetime column')
            elif 'time_decimal' in df.columns:
                # Use time_decimal directly as time_scalar (alternative method)
                df['time_scalar'] = df['time_decimal'].copy()
                update_progress('Loading data', 9, 'Created time_scalar from time_decimal column')
            elif all(col in df.columns for col in ['day', 'month', 'time_decimal']):
                # Create datetime from components and then time_scalar
                df['year'] = df.get('year', 2024)  # Default year if not present
                df['datetime'] = pd.to_datetime(df[['year', 'month', 'day']], errors='coerce')
                df['datetime'] += pd.to_timedelta(df['time_decimal'], unit='h')
                reference_date = pd.Timestamp('2023-01-01')
                df['time_scalar'] = ((df['datetime'] - reference_date) / pd.Timedelta(days=1)).round(8)
                update_progress('Loading data', 9, 'Created time_scalar from day/month/time_decimal')
            else:
                # Create a simple sequential time_scalar based on row order
                df['time_scalar'] = df.index / len(df)
                update_progress('Loading data', 9, 'Created sequential time_scalar')

        # Validate required columns with detailed error reporting
        required_columns = [
            'segment', 'latitude_velocity_km', 'longitude_velocity_km',
            'latitude_degrees', 'longitude_degrees', 'time_difference_hours', 'time_scalar'
        ]

        print(f"REQUIRED COLUMNS: {required_columns}")
        print(f"ACTUAL COLUMNS: {list(df.columns)}")

        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            available_cols = list(df.columns)
            error_msg = f"""
            COLUMN VALIDATION ERROR:
            Missing required columns: {missing_columns}
            Available columns: {available_cols}

            Column mapping suggestions:
            - Check for extra spaces or different naming
            - Verify CSV file format and encoding
            - Ensure time_scalar column exists or can be created
            """
            print(error_msg)
            raise ValueError(f"Missing required columns: {missing_columns}. Available: {available_cols}")

        # CRITICAL: Apply the SAME data filtering as training/notebook
        update_progress('Filtering data', 10, 'Applying quality filters...')
        original_count = len(df)

        # 1. Calculate speed column if missing (CRITICAL!)
        if 'speed_km_h' not in df.columns:
            df['speed_km_h'] = np.sqrt(df['latitude_velocity_km']**2 + df['longitude_velocity_km']**2)
            update_progress('Filtering data', 10.5, 'Calculated speed_km_h column')

        # 2. Speed filtering - exactly like training
        df = df[(df['speed_km_h'] >= 2) & (df['speed_km_h'] <= 60)].copy()
        update_progress('Filtering data', 11, f'Speed filter: {original_count} -> {len(df)} rows')

        # 3. Velocity filtering - CRITICAL for performance!
        velocity_mask = (
            (np.abs(df['latitude_velocity_km']) <= 100) &
            (np.abs(df['longitude_velocity_km']) <= 100) &
            (df['time_difference_hours'] > 0) &
            (df['time_difference_hours'] <= 24)  # Max 24 hours between points
        )
        df = df[velocity_mask].copy()
        update_progress('Filtering data', 12, f'Velocity filter: -> {len(df)} rows')

        # 4. Segment length filtering - remove segments with < 20 points
        segment_counts = df['segment'].value_counts()
        segments_to_remove = segment_counts[segment_counts < 20].index
        before_segment_filter = len(df)
        df = df[~df['segment'].isin(segments_to_remove)].copy()
        update_progress('Filtering data', 13, f'Segment filter: {before_segment_filter} -> {len(df)} rows')

        # 5. Remove NaN and infinite values
        df = df.dropna().copy()
        numeric_cols = ['latitude_velocity_km', 'longitude_velocity_km', 'time_difference_hours']
        for col in numeric_cols:
            if col in df.columns:
                df = df[~np.isinf(df[col])].copy()

        # DEBUGGING: detailed filtering statistics
        filtered_count = len(df)
        filter_percent = ((original_count - filtered_count) / original_count) * 100
        update_progress('Filtering data', 14, f'Final filtered data: {filtered_count} rows ({original_count - filtered_count} removed = {filter_percent:.1f}%)')

        # Debug info for analysis
        print("FILTERING SUMMARY:")
        print(f"   Original: {original_count:,} rows")
        print(f"   Final: {filtered_count:,} rows")
        print(f"   Removed: {original_count - filtered_count:,} ({filter_percent:.1f}%)")

        if len(df) == 0:
            raise ValueError("No data remaining after quality filtering. Check your input data quality.")

        # Step 2: Load normalization parameters
        update_progress('Loading normalization', 12, 'Loading normalization parameters...')
        feature_mins, feature_maxs = load_normalization_params(normalization_path)

        # Step 2.5: CRITICAL - normalize the test data (the missing step that caused the 3373 km error!)
        update_progress('Normalizing data', 15, 'Applying normalization to test data...')
        features_to_normalize = ['latitude_velocity_km', 'longitude_velocity_km',
                                 'latitude_degrees', 'longitude_degrees',
                                 'time_difference_hours', 'time_scalar']

        for feature in features_to_normalize:
            if feature in df.columns and feature in feature_mins:
                min_val = feature_mins[feature]
                max_val = feature_maxs[feature]
                rng = max_val - min_val if max_val != min_val else 1
                df[feature] = (df[feature] - min_val) / rng
                update_progress('Normalizing data', 18, f'Normalized {feature}')

        # Step 3: Create sequences
        SEQ_LENGTH = 12
        FORECAST_HORIZON = 1

        X_test, y_test, test_segments, last_known_positions_scaled = create_sequences_grouped_by_segment_lat_long_veloc(
            df, seq_len=SEQ_LENGTH, forecast_horizon=FORECAST_HORIZON
        )

        update_progress('Preparing model', 45, f'Created {len(X_test)} sequences')

        if len(X_test) == 0:
            raise ValueError("No valid sequences could be created. Check your data and sequence length requirements.")

        # Step 4: Prepare data for model
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        X_test_tensor = torch.from_numpy(X_test).float().to(device)
        y_test_tensor = torch.from_numpy(y_test).float().to(device)
        test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

        # Step 5: Load model
        update_progress('Loading model', 50, 'Loading trained model...')

        # CRITICAL: Model expects 6 features (segment removed) + 1 future_time = 7 total
        in_dim = X_test.shape[2] - 1  # Remove segment column dimension
        # CRITICAL: Match the exact architecture of the Atlantic model weights
        hidden_dim = 250  # From best_model.pth
        n_layers = 7      # From best_model.pth (CRITICAL: not 10!)
        dropout = 0.2

        model = LSTMWithAttentionWithResid(
            in_dim, hidden_dim, FORECAST_HORIZON,
            n_layers=n_layers, dropout=dropout
        ).to(device)

        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        # Step 6: Run inference
        update_progress('Running inference', 60, 'Making predictions...')

        # CRITICAL: Extract features batch-by-batch like the notebook
        all_preds = []
        segments_extracted = []
        time_scalars_extracted = []
        time_diff_hours_extracted = []

        with torch.no_grad():
            for i, batch in enumerate(test_loader):
                x_batch, _ = batch

                # CRITICAL: Extract features exactly like the notebook
                segment_batch = x_batch[:, 0, 0].cpu().numpy()          # Segment from the first timestep
                time_scalar_batch = x_batch[:, -1, 6].cpu().numpy()     # LAST timestep, index 6 = time_scalar
                time_diff_hours_batch = x_batch[:, 0, 5].cpu().numpy()  # First timestep, index 5

                segments_extracted.extend(segment_batch)
                time_scalars_extracted.extend(time_scalar_batch)
                time_diff_hours_extracted.extend(time_diff_hours_batch)

                # Remove segment column before model input
                x_batch_no_segment = x_batch[:, :, 1:]  # Remove segment (index 0) but keep all other features
                preds = model(x_batch_no_segment)
                all_preds.append(preds.cpu().numpy())

                # Progress update
                progress = 60 + (i / len(test_loader)) * 20  # 60-80% range
                update_progress('Running inference', progress,
                                f'Processing batch {i+1}/{len(test_loader)}')

        all_preds = np.concatenate(all_preds, axis=0)

        # Step 7: Process results
        update_progress('Processing results', 80, 'Processing predictions...')

        # CRITICAL: Reshape predictions exactly like the notebook
        yhat = torch.from_numpy(all_preds)
        yhat = yhat.view(-1, 2)  # Reshape to [batch_size, 2] - exactly like the notebook

        # Extract predictions exactly like the notebook
        predicted_lat_vel = yhat[:, 0].numpy()  # Predicted lat velocity
        predicted_lon_vel = yhat[:, 1].numpy()  # Predicted lon velocity

        # Extract actual values exactly like the notebook
        y_real = y_test_tensor.cpu()
        actual_lat_vel = y_real[:, 0].numpy()  # Actual lat velocity
        actual_lon_vel = y_real[:, 1].numpy()  # Actual lon velocity

        # CRITICAL: Use the features extracted from batches (matching the notebook exactly)
        # Ensure all arrays have consistent length
        num_samples = len(predicted_lat_vel)
        segments_extracted = segments_extracted[:num_samples]
        time_scalars_extracted = time_scalars_extracted[:num_samples]
        time_diff_hours_extracted = time_diff_hours_extracted[:num_samples]
        last_known_positions_scaled = last_known_positions_scaled[:num_samples]

        # Create results dataframe exactly like the notebook
        results_df = pd.DataFrame({
            'segment': segments_extracted,                        # From batch extraction
            'time_difference_hours': time_diff_hours_extracted,   # From batch extraction (first timestep)
            'Time Scalar': time_scalars_extracted,                # From batch extraction (LAST timestep)
            'Last Known Latitude': [pos[0] for pos in last_known_positions_scaled],
            'Last Known Longitude': [pos[1] for pos in last_known_positions_scaled],
            'predicted_lat_km': predicted_lat_vel,
            'predicted_lon_km': predicted_lon_vel,
            'actual_lat_km': actual_lat_vel,
            'actual_lon_km': actual_lon_vel
        })

        # Step 8: Denormalize results
        update_progress('Denormalizing results', 85, 'Converting to real units...')

        # Column to feature mapping (complete mapping for all denormalizable columns)
        column_to_feature = {
            "predicted_lat_km": "latitude_velocity_km",
            "predicted_lon_km": "longitude_velocity_km",
            "actual_lat_km": "latitude_velocity_km",
            "actual_lon_km": "longitude_velocity_km",
            "Last Known Latitude": "latitude_degrees",
            "Last Known Longitude": "longitude_degrees",
            "time_difference_hours": "time_difference_hours",
            "Time Scalar": "time_scalar"
        }

        # Denormalize relevant columns
        for col, feat in column_to_feature.items():
            if col in results_df.columns and feat in feature_mins:
                fmin = feature_mins[feat]
                fmax = feature_maxs[feat]
                results_df[col + "_unscaled"] = minmax_denormalize(results_df[col], fmin, fmax)
                update_progress('Denormalizing results', 85, f'Denormalized {col}')

        # Ensure all required _unscaled columns exist
        required_unscaled_cols = [
            'predicted_lat_km_unscaled', 'predicted_lon_km_unscaled',
            'actual_lat_km_unscaled', 'actual_lon_km_unscaled',
            'Last Known Latitude_unscaled', 'Last Known Longitude_unscaled',
            'time_difference_hours_unscaled'
        ]

        for col in required_unscaled_cols:
            if col not in results_df.columns:
                base_col = col.replace('_unscaled', '')
                if base_col in results_df.columns:
                    # If base column exists but wasn't denormalized, copy it
                    results_df[col] = results_df[base_col]
                    update_progress('Denormalizing results', 87, f'Created missing {col}')
                else:
                    results_df[col] = 0.0
                    update_progress('Denormalizing results', 87, f'Defaulted missing {col} to 0')

        # ---------------------------
        # NEW: Clip predicted velocities to realistic physical bounds to avoid huge errors
        # ---------------------------
        VELOCITY_RANGE_KM_H = (-100, 100)  # Same limits used during input filtering
        results_df["predicted_lat_km_unscaled"] = results_df["predicted_lat_km_unscaled"].clip(*VELOCITY_RANGE_KM_H)
        results_df["predicted_lon_km_unscaled"] = results_df["predicted_lon_km_unscaled"].clip(*VELOCITY_RANGE_KM_H)
        update_progress('Denormalizing results', 88, 'Clipped predicted velocities to realistic range')

        # Step 9: Calculate final positions and errors (exact column structure matching the notebook)
        update_progress('Calculating errors', 90, 'Computing prediction errors...')

        # Compute displacement components (in km)
        results_df["pred_final_lat_km_component"] = (
            results_df["predicted_lat_km_unscaled"] * results_df["time_difference_hours_unscaled"]
        )
        results_df["pred_final_lon_km_component"] = (
            results_df["predicted_lon_km_unscaled"] * results_df["time_difference_hours_unscaled"]
        )
        results_df["actual_final_lat_km_component"] = (
            results_df["actual_lat_km_unscaled"] * results_df["time_difference_hours_unscaled"]
        )
        results_df["actual_final_lon_km_component"] = (
            results_df["actual_lon_km_unscaled"] * results_df["time_difference_hours_unscaled"]
        )

        # Calculate total displacement magnitudes (previously missing columns)
        results_df["pred_final_km"] = np.sqrt(
            results_df["pred_final_lat_km_component"]**2 + results_df["pred_final_lon_km_component"]**2
        )
        results_df["actual_final_km"] = np.sqrt(
            results_df["actual_final_lat_km_component"]**2 + results_df["actual_final_lon_km_component"]**2
        )

        # Calculate Euclidean distance error (in km)
        results_df["error_km"] = np.sqrt(
            (results_df["pred_final_lat_km_component"] - results_df["actual_final_lat_km_component"])**2 +
            (results_df["pred_final_lon_km_component"] - results_df["actual_final_lon_km_component"])**2
        )

        # Compute final positions in degrees
        km_per_deg_lat = 111  # approximate conversion for latitude
        results_df["pred_final_lat_deg"] = results_df["Last Known Latitude_unscaled"] + (
            results_df["predicted_lat_km_unscaled"] * results_df["time_difference_hours_unscaled"]
        ) / km_per_deg_lat
        results_df["actual_final_lat_deg"] = results_df["Last Known Latitude_unscaled"] + (
            results_df["actual_lat_km_unscaled"] * results_df["time_difference_hours_unscaled"]
        ) / km_per_deg_lat

        # Account for longitude scaling by latitude
        results_df["Last_Known_Lat_rad"] = np.deg2rad(results_df["Last Known Latitude_unscaled"])
        results_df["pred_final_lon_deg"] = results_df["Last Known Longitude_unscaled"] + (
            results_df["predicted_lon_km_unscaled"] * results_df["time_difference_hours_unscaled"]
        ) / (km_per_deg_lat * np.cos(results_df["Last_Known_Lat_rad"]))
        results_df["actual_final_lon_deg"] = results_df["Last Known Longitude_unscaled"] + (
            results_df["actual_lon_km_unscaled"] * results_df["time_difference_hours_unscaled"]
        ) / (km_per_deg_lat * np.cos(results_df["Last_Known_Lat_rad"]))

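        # Worked example of the degree conversion above (equirectangular approximation):
        # one degree of latitude spans roughly 111 km everywhere, while one degree of
        # longitude spans roughly 111 * cos(lat) km. So a 10 km northward displacement adds
        # about 10 / 111 = 0.09 deg of latitude, and a 10 km eastward displacement at
        # latitude 60 deg adds about 10 / (111 * cos(60 deg)) = 10 / 55.5 = 0.18 deg of longitude.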
        # Step 10: Reorder columns to match the exact specification
        update_progress('Finalizing results', 93, 'Reordering columns to match notebook format...')

        # Exact column order as specified
        column_order = [
            'segment', 'time_difference_hours', 'Time Scalar', 'Last Known Latitude', 'Last Known Longitude',
            'predicted_lat_km', 'predicted_lon_km', 'actual_lat_km', 'actual_lon_km',
            'predicted_lat_km_unscaled', 'predicted_lon_km_unscaled', 'actual_lat_km_unscaled', 'actual_lon_km_unscaled',
            'Last Known Latitude_unscaled', 'Last Known Longitude_unscaled', 'time_difference_hours_unscaled',
            'pred_final_km', 'actual_final_km',
            'pred_final_lat_km_component', 'pred_final_lon_km_component',
            'actual_final_lat_km_component', 'actual_final_lon_km_component',
            'error_km', 'pred_final_lat_deg', 'actual_final_lat_deg', 'Last_Known_Lat_rad',
            'pred_final_lon_deg', 'actual_final_lon_deg'
        ]

        # Validate all required columns exist - add missing ones with defaults if needed
        missing_columns = [col for col in column_order if col not in results_df.columns]
        if missing_columns:
            update_progress('Finalizing results', 94, f'Adding missing columns: {missing_columns}')
            for col in missing_columns:
                # Add default values for any missing columns
                if '_unscaled' in col:
                    # For unscaled columns, try to find the original scaled column
                    base_col = col.replace('_unscaled', '')
                    if base_col in results_df.columns and base_col in column_to_feature:
                        # Use the same denormalization process
                        feat = column_to_feature[base_col]
                        if feat in feature_mins:
                            fmin = feature_mins[feat]
                            fmax = feature_maxs[feat]
                            results_df[col] = minmax_denormalize(results_df[base_col], fmin, fmax)
                        else:
                            results_df[col] = results_df[base_col]  # No denormalization available
                    else:
                        results_df[col] = 0.0  # Default to 0
                else:
                    results_df[col] = 0.0  # Default to 0 for any other missing columns

        # Reorder columns to match the exact specification
        results_df = results_df[column_order]

        # Step 11: Save results
        update_progress('Saving results', 95, 'Saving inference results...')

        # Create results directory
        results_dir = Path('results/inference_atlantic')
        results_dir.mkdir(parents=True, exist_ok=True)

        # Save to results directory
        timestamp = pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')
        results_file = results_dir / f'inference_results_{timestamp}.csv'
        results_df.to_csv(results_file, index=False)

        # Also save to a temporary file for compatibility
        output_file = tempfile.NamedTemporaryFile(
            mode='w', suffix='_inference_results.csv', delete=False
        )
        results_df.to_csv(output_file.name, index=False)

        # CRITICAL: Calculate the SAME regression metrics as the notebook
        # Convert predictions and actuals to tensors for metric calculation
        yhat_tensor = torch.from_numpy(np.column_stack([predicted_lat_vel, predicted_lon_vel])).float()
        y_real_tensor = torch.from_numpy(np.column_stack([actual_lat_vel, actual_lon_vel])).float()

        # Calculate regression metrics exactly like the notebook
        def calc_metrics_like_notebook(preds, labels):
            """Calculate metrics exactly like the notebook's calc_metrics function"""
            EPS = 1e-8
            mse = torch.mean((preds - labels) ** 2)
            mae = torch.mean(torch.abs(preds - labels))
            rmse = torch.sqrt(mse)
            mape = torch.mean(torch.abs((preds - labels) / (labels + EPS))) * 100  # Convert to percentage
            rse = torch.sum((preds - labels) ** 2) / torch.sum((labels + EPS) ** 2)
            return rse.item(), mae.item(), mse.item(), mape.item(), rmse.item()

        # Calculate regression metrics on velocity predictions
        rse, mae, mse, mape, rmse = calc_metrics_like_notebook(yhat_tensor, y_real_tensor)

        # Calculate summary statistics
        error_stats = {
            # Distance-based metrics (web app specific)
            'mean_error_km': float(results_df["error_km"].mean()),
            'median_error_km': float(results_df["error_km"].median()),
            'std_error_km': float(results_df["error_km"].std()),
            'min_error_km': float(results_df["error_km"].min()),
            'max_error_km': float(results_df["error_km"].max()),

            # Regression metrics (matching the notebook)
            'rse': rse,
            'mae': mae,
            'mse': mse,
            'mape': mape,
            'rmse': rmse,

            # General stats
            'total_predictions': len(results_df),
            'total_segments': len(results_df['segment'].unique()),
            'columns_generated': list(results_df.columns),
            'total_columns': len(results_df.columns)
        }

        # NEW: Create histogram of error distribution (30 bins by default)
        hist_counts, bin_edges = np.histogram(results_df["error_km"], bins=30)
        histogram_data = {
            'bins': bin_edges.tolist(),
            'counts': hist_counts.tolist()
        }

        update_progress('Complete', 100,
                        f'Inference complete! Distance: {error_stats["mean_error_km"]:.2f} km | MAE: {error_stats["mae"]:.2f} | MAPE: {error_stats["mape"]:.2f}%')

        # Emit inference_complete with full statistics and histogram for the frontend chart
        try:
            socketio.emit('inference_complete', {
                'success': True,
                'stats': error_stats,
                'histogram': histogram_data
            })
        except Exception:
            pass  # In case we are in a CLI context without SocketIO

        return {
            'success': True,
            'results_file': output_file.name,
            'stats': error_stats,
            'histogram': histogram_data,
            'message': f'Successfully processed {len(results_df)} predictions'
        }

    except Exception as e:
        error_msg = f"Error during inference: {str(e)}"
        update_progress('Error', 0, error_msg)
        return {
            'success': False,
            'error': error_msg
        }

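# Example of running the pipeline directly, without the web UI (a sketch; the CSV path
# is a placeholder, the other two files are the defaults the /upload route expects):
#
#   result = run_inference_pipeline(
#       'uploads/test_vessels.csv',
#       'best_model.pth',
#       'normalization_params_1_atlanttic_regular_intervals_with_lat_lon_velocity_and_time_difference_filter_outlier_segment_min_20_points.json',
#   )
#   print(result['stats']['mean_error_km'] if result['success'] else result['error'])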
########################################
#              WEB ROUTES              #
########################################

@app.route('/')
def index():
    return render_template('vessel_inference.html')

@app.route('/upload', methods=['POST'])
def upload_file():
    try:
        # Check if files were uploaded
        if 'csv_file' not in request.files:
            return jsonify({'success': False, 'error': 'No CSV file uploaded'})

        csv_file = request.files['csv_file']
        if csv_file.filename == '':
            return jsonify({'success': False, 'error': 'No CSV file selected'})

        # Default model and normalization files
        model_path = 'best_model.pth'
        normalization_path = 'normalization_params_1_atlanttic_regular_intervals_with_lat_lon_velocity_and_time_difference_filter_outlier_segment_min_20_points.json'

        # Check for optional uploads
        if 'model_file' in request.files and request.files['model_file'].filename != '':
            model_file = request.files['model_file']
            model_filename = secure_filename(model_file.filename)
            model_path = os.path.join(app.config['UPLOAD_FOLDER'], model_filename)
            model_file.save(model_path)

        if 'normalization_file' in request.files and request.files['normalization_file'].filename != '':
            norm_file = request.files['normalization_file']
            norm_filename = secure_filename(norm_file.filename)
            normalization_path = os.path.join(app.config['UPLOAD_FOLDER'], norm_filename)
            norm_file.save(normalization_path)

        # Check if required files exist
        if not os.path.exists(model_path):
            return jsonify({'success': False, 'error': f'Model file not found: {model_path}'})

        if not os.path.exists(normalization_path):
            return jsonify({'success': False, 'error': f'Normalization file not found: {normalization_path}'})

        # Save CSV file
        csv_filename = secure_filename(csv_file.filename)
        csv_path = os.path.join(app.config['UPLOAD_FOLDER'], csv_filename)
        csv_file.save(csv_path)

        # Start inference in a background thread
        def run_inference_background():
            return run_inference_pipeline(csv_path, model_path, normalization_path)

        thread = threading.Thread(target=run_inference_background)
        thread.start()

        return jsonify({'success': True, 'message': 'Files uploaded successfully. Inference started.'})

    except Exception as e:
        return jsonify({'success': False, 'error': str(e)})

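# Example request against the /upload route above (a sketch; the file name is a placeholder):
#
#   import requests
#   with open('my_tracks.csv', 'rb') as f:
#       r = requests.post('http://localhost:7860/upload', files={'csv_file': f})
#   print(r.json())
#
# Progress then streams over the 'progress_update' SocketIO event and the /progress route.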
@app.route('/progress')
def get_progress():
    return jsonify(current_progress)

@app.route('/download_results')
def download_results():
    # Find the most recent results file
    upload_dir = app.config['UPLOAD_FOLDER']
    temp_dir = tempfile.gettempdir()

    # Look for results files in both directories
    for directory in [upload_dir, temp_dir]:
        if os.path.exists(directory):
            files = [f for f in os.listdir(directory) if f.endswith('_inference_results.csv')]
            if files:
                latest_file = max(files, key=lambda x: os.path.getctime(os.path.join(directory, x)))
                return send_file(
                    os.path.join(directory, latest_file),
                    as_attachment=True,
                    download_name='vessel_inference_results.csv'
                )

    return jsonify({'error': 'No results file found'}), 404

########################################
#           SOCKETIO EVENTS            #
########################################

@socketio.on('connect')
def handle_connect():
    emit('progress_update', current_progress)

@socketio.on('start_inference')
def handle_start_inference(data):
    """Handle inference request via WebSocket"""
    try:
        csv_path = data.get('csv_path')
        model_path = data.get('model_path', 'best_model.pth')
        norm_path = data.get('normalization_path', 'normalization_params_1_atlanttic_regular_intervals_with_lat_lon_velocity_and_time_difference_filter_outlier_segment_min_20_points.json')

        def run_inference_background():
            result = run_inference_pipeline(csv_path, model_path, norm_path)
            # Use socketio.emit here: the context-bound emit() from flask_socketio
            # is not available inside a plain background thread.
            socketio.emit('inference_complete', result)

        thread = threading.Thread(target=run_inference_background)
        thread.start()

    except Exception as e:
        emit('inference_complete', {'success': False, 'error': str(e)})

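# SocketIO events used by this app (for reference when wiring up the frontend):
#   client -> server: 'start_inference'   with {csv_path, model_path, normalization_path}
#   server -> client: 'progress_update'   with {step, progress, details}
#   server -> client: 'inference_complete' with {success, stats, histogram} or {success, error}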
if __name__ == '__main__':
    print("Vessel Trajectory Inference Web App")
    print("Using Final_inference_maginet.py logic")

    # Get port from environment variable (Hugging Face Spaces uses 7860)
    port = int(os.environ.get('PORT', 7860))
    print(f"Starting server at http://0.0.0.0:{port}")
    print("Make sure you have:")
    print("   - best_model.pth")
    print("   - normalization_params_1_atlanttic_regular_intervals_...json")
    print("   - Your test dataset CSV")

    socketio.run(app, host='0.0.0.0', port=port, debug=False)