jadechoghari commited on
Commit
43b1627
1 Parent(s): 483dce0

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +49 -11
pipeline.py CHANGED
@@ -3,6 +3,7 @@ from .invert import Inverter
3
  from .generate import Generator
4
  from .utils import init_model, seed_everything, get_frame_ids
5
  import torch
 
6
 
7
  class VidToMePipeline(DiffusionPipeline):
8
  # def __init__(self, device="cuda", sd_version="2.1", float_precision="fp16", height=512, width=512):
@@ -57,16 +58,52 @@ class VidToMePipeline(DiffusionPipeline):
57
  config['generation']['output_path'], frame_ids=frame_ids)
58
  print(f"Output generated at: {config['generation']['output_path']}")
59
 
60
- def _build_config(self, video_path, video_prompt, edit_prompt, control_type,
61
- n_timesteps, guidance_scale, negative_prompt, frame_range,
62
- use_lora, seed, local_merge_ratio, global_merge_ratio):
63
- # constructing config dictionary from user prompts
64
- config = {
65
- 'sd_version': self.sd_version,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  'input_path': video_path,
67
  'work_dir': "outputs/",
68
- 'height': self.height,
69
- 'width': self.width,
70
  'inversion': {
71
  'prompt': video_prompt or "Default video prompt.",
72
  'save_path': "outputs/latents",
@@ -88,9 +125,10 @@ class VidToMePipeline(DiffusionPipeline):
88
  },
89
  'seed': seed,
90
  'device': "cuda",
91
- 'float_precision': self.float_precision
92
- }
93
- return config
 
94
 
95
  # # Sample usage
96
  # pipeline = VidToMePipeline(device="cuda", sd_version="2.1", float_precision="fp16")
 
3
  from .generate import Generator
4
  from .utils import init_model, seed_everything, get_frame_ids
5
  import torch
6
+ from omegaconf import OmegaConf
7
 
8
  class VidToMePipeline(DiffusionPipeline):
9
  # def __init__(self, device="cuda", sd_version="2.1", float_precision="fp16", height=512, width=512):
 
58
  config['generation']['output_path'], frame_ids=frame_ids)
59
  print(f"Output generated at: {config['generation']['output_path']}")
60
 
61
+ # def _build_config(self, video_path, video_prompt, edit_prompt, control_type,
62
+ # n_timesteps, guidance_scale, negative_prompt, frame_range,
63
+ # use_lora, seed, local_merge_ratio, global_merge_ratio):
64
+ # # constructing config dictionary from user prompts
65
+ # config = {
66
+ # 'sd_version': self.sd_version,
67
+ # 'input_path': video_path,
68
+ # 'work_dir': "outputs/",
69
+ # 'height': self.height,
70
+ # 'width': self.width,
71
+ # 'inversion': {
72
+ # 'prompt': video_prompt or "Default video prompt.",
73
+ # 'save_path': "outputs/latents",
74
+ # 'steps': 50,
75
+ # 'save_intermediate': False
76
+ # },
77
+ # 'generation': {
78
+ # 'control': control_type,
79
+ # 'guidance_scale': guidance_scale,
80
+ # 'n_timesteps': n_timesteps,
81
+ # 'negative_prompt': negative_prompt,
82
+ # 'prompt': edit_prompt or "Default edit prompt.",
83
+ # 'latents_path': "outputs/latents",
84
+ # 'output_path': "outputs/final",
85
+ # 'frame_range': frame_range or [0, 32],
86
+ # 'use_lora': use_lora,
87
+ # 'local_merge_ratio': local_merge_ratio,
88
+ # 'global_merge_ratio': global_merge_ratio
89
+ # },
90
+ # 'seed': seed,
91
+ # 'device': "cuda",
92
+ # 'float_precision': self.float_precision
93
+ # }
94
+ # return config
95
+ from omegaconf import OmegaConf
96
+
97
+ def build_config(video_path, video_prompt, edit_prompt, control_type,
98
+ n_timesteps, guidance_scale, negative_prompt, frame_range,
99
+ use_lora, seed, local_merge_ratio, global_merge_ratio):
100
+ # Create a config using OmegaConf
101
+ config = OmegaConf.create({
102
+ 'sd_version': '1.5',
103
  'input_path': video_path,
104
  'work_dir': "outputs/",
105
+ 'height': 512,
106
+ 'width': 512,
107
  'inversion': {
108
  'prompt': video_prompt or "Default video prompt.",
109
  'save_path': "outputs/latents",
 
125
  },
126
  'seed': seed,
127
  'device': "cuda",
128
+ 'float_precision': "fp16"
129
+ })
130
+
131
+ return config
132
 
133
  # # Sample usage
134
  # pipeline = VidToMePipeline(device="cuda", sd_version="2.1", float_precision="fp16")