YuxueYang committed
Commit f01a554 · 1 Parent(s): b9f6606
Files changed (1):
  1. app.py +21 -6
app.py CHANGED
@@ -46,6 +46,23 @@ snapshot_download(
 TEXT_ENCODER = FrozenOpenCLIPEmbedder().eval()
 IMAGE_ENCODER = FrozenOpenCLIPImageEmbedderV2().eval()
 
+default_path = "checkpoints/LayerAnimate-Mix"
+scheduler = DDIMScheduler.from_pretrained(default_path, subfolder="scheduler")
+image_projector = Resampler.from_pretrained(default_path, subfolder="image_projector").eval()
+vae, vae_dualref = None, None
+if "I2V" or "Mix" in default_path:
+    vae = AutoencoderKL.from_pretrained(default_path, subfolder="vae").eval()
+if "Interp" or "Mix" in default_path:
+    vae_dualref = AutoencoderKL_Dualref.from_pretrained(default_path, subfolder="vae_dualref").eval()
+unet = UNetModel.from_pretrained(default_path, subfolder="unet").eval()
+layer_controlnet = LayerControlNet.from_pretrained(default_path, subfolder="layer_controlnet").eval()
+PIPELINE = AnimationPipeline(
+    vae=vae, vae_dualref=vae_dualref, text_encoder=TEXT_ENCODER, image_encoder=IMAGE_ENCODER, image_projector=image_projector,
+    unet=unet, layer_controlnet=layer_controlnet, scheduler=scheduler
+).to(device=DEVICE, dtype=WEIGHT_DTYPE)
+if "Interp" or "Mix" in default_path:
+    PIPELINE.vae_dualref.decoder.to(dtype=torch.float32)
+
 TRANSFORMS = transforms.Compose([
     transforms.Resize(min(HEIGHT, WIDTH)),
     transforms.CenterCrop((HEIGHT, WIDTH)),
@@ -72,17 +89,15 @@ def set_model(pretrained_model_path):
     vae_dualref = AutoencoderKL_Dualref.from_pretrained(pretrained_model_path, subfolder="vae_dualref").eval()
     unet = UNetModel.from_pretrained(pretrained_model_path, subfolder="unet").eval()
     layer_controlnet = LayerControlNet.from_pretrained(pretrained_model_path, subfolder="layer_controlnet").eval()
-
-    PIPELINE = AnimationPipeline(
+    PIPELINE.update(
         vae=vae, vae_dualref=vae_dualref, text_encoder=TEXT_ENCODER, image_encoder=IMAGE_ENCODER, image_projector=image_projector,
         unet=unet, layer_controlnet=layer_controlnet, scheduler=scheduler
-    ).to(device=DEVICE, dtype=WEIGHT_DTYPE)
+    )
+    PIPELINE.to(device=DEVICE, dtype=WEIGHT_DTYPE)
     if "Interp" or "Mix" in pretrained_model_path:
         PIPELINE.vae_dualref.decoder.to(dtype=torch.float32)
     return pretrained_model_path
 
-set_model("checkpoints/LayerAnimate-Mix")
-
 def upload_image(image):
     image = TRANSFORMS(image)
     return image
@@ -92,7 +107,7 @@ def run(input_image, input_image_end, pretrained_model_path, seed,
         prompt, n_prompt, num_inference_steps, guidance_scale,
         *layer_args):
     generator = set_seed(seed, DEVICE)
-    global layer_tracking_points, PIPELINE
+    global layer_tracking_points
     args_layer_tracking_points = [layer_tracking_points[i].value for i in range(LAYER_CAPACITY)]
 
     args_layer_masks = layer_args[:LAYER_CAPACITY]
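
Note: in both the old and new code, conditions like if "I2V" or "Mix" in default_path: parse as if "I2V" or ("Mix" in default_path):, and the non-empty literal "I2V" is always truthy, so these branches run for every checkpoint path. The default path here contains "Mix", so loading all components is harmless in practice, but a membership test over both tags is presumably what was intended. A minimal sketch of that check (the any(...) form is my assumption, not code from this commit):

    # Presumed intent: load the standard VAE only for I2V or Mix checkpoints.
    # any() over explicit membership tests avoids the always-truthy string literal.
    if any(tag in default_path for tag in ("I2V", "Mix")):
        vae = AutoencoderKL.from_pretrained(default_path, subfolder="vae").eval()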
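
The refactor also explains the change in run(): PIPELINE is now built once at import time, and set_model() swaps components on that live object via PIPELINE.update(...) instead of rebinding the global name, so run() only reads PIPELINE and no longer needs it in its global statement. A hypothetical sketch of what such an update() might do (the real method lives in this repo's AnimationPipeline and may differ):

    # Hypothetical in-place component swap, assuming diffusers-style pipelines
    # where submodules are attributes; not the actual code from this repository.
    def update(self, **components):
        for name, module in components.items():
            setattr(self, name, module)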