File size: 595 Bytes
100edb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
import streamlit as st
from aurora import rollout, Aurora

def run_inference(selected_model, model, batch, device, *, steps=2):
    """Run inference for the model family selected in the Streamlit UI.

    Args:
        selected_model: Name of the model family. ``"Prithvi"`` and
            ``"Aurora"`` are handled; anything else reports a Streamlit error.
        model: The loaded model. It is switched to eval mode before use.
        batch: Input passed unchanged to ``model`` (Prithvi) or to
            ``aurora.rollout`` (Aurora).
        device: Not used by this function.
            NOTE(review): presumably meant for moving ``batch``/outputs to a
            device; confirm with callers before relying on it.
        steps: Number of rollout steps for Aurora. Ignored for Prithvi.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        For Prithvi, the raw model output (not moved off its device).
        For Aurora, a list of the predictions produced by ``rollout``, each
        moved to CPU. ``None`` for any other model name, after showing a
        Streamlit error message.
    """
    if selected_model == "Prithvi":
        model.eval()
        # no_grad (rather than inference_mode) preserves the original behaviour
        # for Prithvi outputs; the with-block still exits cleanly on return.
        with torch.no_grad():
            return model(batch)

    if selected_model == "Aurora":
        model.eval()
        with torch.inference_mode():
            # Move every prediction yielded by rollout to CPU.
            return [pred.to("cpu") for pred in rollout(model, batch, steps=steps)]

    st.error("Inference not implemented for this model.")
    return None