AItool commited on
Commit
682d5a5
·
verified ·
1 Parent(s): 112aa35

Update inference_img.py

Browse files
Files changed (1) hide show
  1. inference_img.py +3 -4
inference_img.py CHANGED
@@ -5,7 +5,6 @@ import argparse
5
  from torch.nn import functional as F
6
  import warnings
7
 
8
-
9
  def main():
10
 
11
  warnings.filterwarnings("ignore")
@@ -18,7 +17,7 @@ def main():
18
 
19
  parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
20
  parser.add_argument('--img', dest='img', nargs=2, required=True)
21
- parser.add_argument('--exp', default=4, type=int, required=True)
22
  parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
23
  parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
24
  parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
@@ -29,10 +28,10 @@ def main():
29
  try:
30
  try:
31
  try:
32
- from model.RIFE_HDv2 import Model
33
  model = Model()
34
  model.load_model(args.modelDir, -1)
35
- print("Loaded v2.x HD model.")
36
  except:
37
  from train_log.RIFE_HDv3 import Model
38
  model = Model()
 
5
  from torch.nn import functional as F
6
  import warnings
7
 
 
8
  def main():
9
 
10
  warnings.filterwarnings("ignore")
 
17
 
18
  parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
19
  parser.add_argument('--img', dest='img', nargs=2, required=True)
20
+ parser.add_argument('--exp', default=4, type=int)
21
  parser.add_argument('--ratio', default=0, type=float, help='inference ratio between two images with 0 - 1 range')
22
  parser.add_argument('--rthreshold', default=0.02, type=float, help='returns image when actual ratio falls in given range threshold')
23
  parser.add_argument('--rmaxcycles', default=8, type=int, help='limit max number of bisectional cycles')
 
28
  try:
29
  try:
30
  try:
31
+ from train_log.RIFE_HDv3 import Model
32
  model = Model()
33
  model.load_model(args.modelDir, -1)
34
+ print("Loaded v3.x HD model.")
35
  except:
36
  from train_log.RIFE_HDv3 import Model
37
  model = Model()