import torch
import streamlit as st
from aurora import rollout, Aurora


def run_inference(selected_model, model, batch, device, *, steps=2):
    """Run inference with the selected model and return its output.

    Args:
        selected_model: Name of the model chosen in the UI. ``"Prithvi"`` and
            ``"Aurora"`` are supported. Any other value reports an error via
            ``st.error`` and returns ``None``.
        model: The loaded model object. It is put into eval mode before use.
        batch: Input handed directly to ``model`` (Prithvi) or to
            ``aurora.rollout`` (Aurora).
        device: Currently unused by this function.
            NOTE(review): presumably callers move ``model``/``batch`` to the
            target device beforehand. Confirm, or drop this parameter.
        steps: Number of rollout steps for Aurora. Defaults to 2, the value
            that was previously hard-coded. Ignored for Prithvi.

    Returns:
        Prithvi: whatever ``model(batch)`` returns, unchanged.
        Aurora: a list with one prediction per rollout step, each moved to
        the CPU via ``.to("cpu")``.
        Unsupported model: ``None``.
    """
    if selected_model == "Prithvi":
        model.eval()
        with torch.no_grad():
            return model(batch)

    if selected_model == "Aurora":
        model.eval()
        with torch.inference_mode():
            # Consume the whole rollout inside the inference_mode context so
            # every step runs without autograd tracking.
            return [pred.to("cpu") for pred in rollout(model, batch, steps=steps)]

    st.error("Inference not implemented for this model.")
    return None