AItool commited on
Commit
25a2142
·
verified ·
1 Parent(s): 7e1f9a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -29
app.py CHANGED
@@ -8,32 +8,12 @@ from pathlib import Path
8
  import gradio as gr
9
  from PIL import Image
10
 
11
- # Config
12
- RIFE_REPO_URL = "https://github.com/hzwer/arXiv2020-RIFE"
13
- RIFE_DIR = Path("arXiv2020-RIFE")
14
- TRAIN_LOG_DIR = RIFE_DIR / "train_log"
15
  OUTPUT_DIR = RIFE_DIR / "output"
16
- WEIGHTS_ID = "1APIzVeI-4ZZCEuIRE1m6WYfSCaOsi_7_" # RIFE v3.6
17
-
18
  LOCK = threading.Lock()
19
- RIFE_READY = False
20
 
21
  def interpolate(a: Image.Image, b: Image.Image, fps: int = 14, exp: int = 2):
22
- global RIFE_READY
23
-
24
  with LOCK:
25
- # Lazy one-time setup
26
- if not RIFE_READY:
27
- if not RIFE_DIR.exists():
28
- subprocess.run(["git", "clone", "--depth", "1", RIFE_REPO_URL, str(RIFE_DIR)], check=True)
29
- TRAIN_LOG_DIR.mkdir(parents=True, exist_ok=True)
30
- has_weights = any(p.suffix in {".pkl", ".pth"} for p in TRAIN_LOG_DIR.glob("*"))
31
- if not has_weights:
32
- zip_path = TRAIN_LOG_DIR / "RIFE_trained_model_v3.6.zip"
33
- subprocess.run(["gdown", "--id", WEIGHTS_ID, "-O", str(zip_path)], check=True)
34
- subprocess.run(["unzip", "-o", str(zip_path), "-d", str(TRAIN_LOG_DIR)], check=True)
35
- RIFE_READY = True
36
-
37
  # Normalize inputs
38
  a = a.convert("RGB")
39
  b = b.convert("RGB")
@@ -46,17 +26,18 @@ def interpolate(a: Image.Image, b: Image.Image, fps: int = 14, exp: int = 2):
46
  a.save(p1, "PNG")
47
  b.save(p2, "PNG")
48
 
49
- # Clean RIFE output and run inference
50
  if OUTPUT_DIR.exists():
51
  shutil.rmtree(OUTPUT_DIR)
52
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
53
 
54
- cmd = ["python3", "inference_img.py", "--img", str(p1), str(p2)]
 
55
  if isinstance(exp, int) and exp >= 1:
56
  cmd += ["--exp", str(exp)]
57
  subprocess.run(cmd, cwd=str(RIFE_DIR), check=True)
58
 
59
- # Collect frames in natural sequence: img1.png → imgN.png
60
  frames = []
61
  i = 1
62
  while True:
@@ -66,9 +47,9 @@ def interpolate(a: Image.Image, b: Image.Image, fps: int = 14, exp: int = 2):
66
  frames.append(fp)
67
  i += 1
68
  if not frames:
69
- raise RuntimeError("No frames produced by RIFE.")
70
 
71
- # Assemble GIF
72
  images = [Image.open(p).convert("RGBA") for p in frames]
73
  duration_ms = max(1, int(1000 / max(1, fps)))
74
  gif_path = work_dir / "interpolation.gif"
@@ -82,7 +63,7 @@ def interpolate(a: Image.Image, b: Image.Image, fps: int = 14, exp: int = 2):
82
  disposal=2,
83
  )
84
 
85
- # Optional: cleanup RIFE output
86
  try:
87
  shutil.rmtree(OUTPUT_DIR)
88
  except Exception:
@@ -90,7 +71,8 @@ def interpolate(a: Image.Image, b: Image.Image, fps: int = 14, exp: int = 2):
90
 
91
  return str(gif_path)
92
 
93
- TITLE = "RIFE Interpolation → GIF (PyTorch, minimal)"
 
94
  with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
95
  gr.Markdown(f"# {TITLE}")
96
  with gr.Row():
@@ -99,7 +81,7 @@ with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
99
  img_b = gr.Image(type="pil", label="Image B")
100
  with gr.Column():
101
  fps = gr.Slider(6, 30, value=14, step=1, label="FPS")
102
- exp = gr.Slider(1, 4, value=2, step=1, label="Interpolation exponent (2^exp frames)")
103
  run = gr.Button("Interpolate", variant="primary")
104
  gif_out = gr.Image(type="filepath", label="Result GIF")
105
  run.click(interpolate, inputs=[img_a, img_b, fps, exp], outputs=[gif_out])
 
8
  import gradio as gr
9
  from PIL import Image
10
 
11
+ RIFE_DIR = Path("rife") # Local bundled repo
 
 
 
12
  OUTPUT_DIR = RIFE_DIR / "output"
 
 
13
  LOCK = threading.Lock()
 
14
 
15
  def interpolate(a: Image.Image, b: Image.Image, fps: int = 14, exp: int = 2):
 
 
16
  with LOCK:
 
 
 
 
 
 
 
 
 
 
 
 
17
  # Normalize inputs
18
  a = a.convert("RGB")
19
  b = b.convert("RGB")
 
26
  a.save(p1, "PNG")
27
  b.save(p2, "PNG")
28
 
29
+ # Clean previous outputs
30
  if OUTPUT_DIR.exists():
31
  shutil.rmtree(OUTPUT_DIR)
32
  OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
33
 
34
+ # Run RIFE inference
35
+ cmd = ["python3", str(RIFE_DIR / "inference_img.py"), "--img", str(p1), str(p2)]
36
  if isinstance(exp, int) and exp >= 1:
37
  cmd += ["--exp", str(exp)]
38
  subprocess.run(cmd, cwd=str(RIFE_DIR), check=True)
39
 
40
+ # Collect interpolated frames
41
  frames = []
42
  i = 1
43
  while True:
 
47
  frames.append(fp)
48
  i += 1
49
  if not frames:
50
+ raise RuntimeError("No frames generated.")
51
 
52
+ # Build GIF
53
  images = [Image.open(p).convert("RGBA") for p in frames]
54
  duration_ms = max(1, int(1000 / max(1, fps)))
55
  gif_path = work_dir / "interpolation.gif"
 
63
  disposal=2,
64
  )
65
 
66
+ # Optional cleanup
67
  try:
68
  shutil.rmtree(OUTPUT_DIR)
69
  except Exception:
 
71
 
72
  return str(gif_path)
73
 
74
+ # Gradio UI
75
+ TITLE = "🔥 RIFE Interpolation Demo (PyTorch, Local)"
76
  with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
77
  gr.Markdown(f"# {TITLE}")
78
  with gr.Row():
 
81
  img_b = gr.Image(type="pil", label="Image B")
82
  with gr.Column():
83
  fps = gr.Slider(6, 30, value=14, step=1, label="FPS")
84
+ exp = gr.Slider(1, 4, value=2, step=1, label="Interpolation exponent")
85
  run = gr.Button("Interpolate", variant="primary")
86
  gif_out = gr.Image(type="filepath", label="Result GIF")
87
  run.click(interpolate, inputs=[img_a, img_b, fps, exp], outputs=[gif_out])