File size: 595 Bytes
100edb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
import streamlit as st
from aurora import rollout, Aurora
def run_inference(selected_model, model, batch, device, *, steps: int = 2):
    """Run inference for the selected weather/climate model.

    Args:
        selected_model: Name of the model family. ``"Prithvi"`` and
            ``"Aurora"`` are handled; any other value reports an error in
            the Streamlit UI.
        model: The model object. It is switched to eval mode before use.
        batch: Input passed unchanged to the model (Prithvi) or to
            ``aurora.rollout`` (Aurora).
        device: Not used by this function; kept for interface
            compatibility with existing callers.
        steps: Number of autoregressive rollout steps for Aurora
            (default 2, matching the previous hard-coded value). Ignored
            for Prithvi.

    Returns:
        - Prithvi: the output of a single forward pass ``model(batch)``.
        - Aurora: a list with one prediction per rollout step, each moved
          to CPU.
        - Any other model name: ``None``, after ``st.error`` is shown.
    """
    if selected_model == "Prithvi":
        model.eval()
        # no_grad (rather than inference_mode) keeps the returned tensors
        # usable in later autograd-aware code, preserving prior behavior.
        with torch.no_grad():
            out = model(batch)
        return out
    elif selected_model == "Aurora":
        model.eval()
        with torch.inference_mode():
            # Consume the rollout while inference mode is still active, and
            # move each step's prediction to CPU so results do not hold on
            # to accelerator memory after this call returns.
            out = [pred.to("cpu") for pred in rollout(model, batch, steps=steps)]
        return out
    else:
        st.error("Inference not implemented for this model.")
        return None
|