headscratchertm commited on
Commit
95e206f
·
1 Parent(s): 79a8e73

loaded FC and AT

Browse files
Files changed (2) hide show
  1. __pycache__/model.cpython-310.pyc +0 -0
  2. main.py +19 -2
__pycache__/model.cpython-310.pyc ADDED
Binary file (9.81 kB). View file
 
main.py CHANGED
@@ -1,8 +1,25 @@
1
  import torch
 
 
 
 
 
 
 
 
 
 
2
  def solve():
3
  checkpoint=torch.load("SuperSloMo.ckpt")
4
- checkpoint.eval()
5
- print(checkpoint)
 
 
 
 
 
 
 
6
  def main():
7
  solve()
8
  if __name__=="__main__":
 
1
  import torch
2
+ from model import UNet
3
+ from PIL import Image
4
+ from torchvision.transforms import transforms,ToTensor
5
+
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+
8
+ def load_frames(path):
9
+ image=Image.open(path).convert('RGB')
10
+ tensor=ToTensor()
11
+ return tensor(image).unsqueeze(0).to(device)
12
  def solve():
13
  checkpoint=torch.load("SuperSloMo.ckpt")
14
+ model_FC=UNet(6,4) # initialize ARCH
15
+ model_FC=model_FC.to(device)# reassign model tensors
16
+ model_FC.load_state_dict(checkpoint["state_dictFC"]) # loading all weights from model
17
+ model_AT=UNet(20,5)
18
+ model_AT.load_state_dict(checkpoint["state_dictAT"])
19
+ model_AT=model_AT.to(device)
20
+ model_AT.eval()
21
+ model_FC.eval()
22
+
23
  def main():
24
  solve()
25
  if __name__=="__main__":