LanguageBind commited on
Commit
ebabd3d
·
verified ·
1 Parent(s): bfcee45

Update opensora/serve/gradio_web_server.py

Browse files
opensora/serve/gradio_web_server.py CHANGED
@@ -63,16 +63,16 @@ if __name__ == '__main__':
63
  device = torch.device('cuda:0')
64
 
65
  # Load model:
66
- transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, torch_dtype=torch.float16, cache_dir='cache_dir').to(device)
67
 
68
- vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16)
69
  vae.vae.enable_tiling()
70
  image_size = int(args.version.split('x')[1])
71
  latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2])
72
  vae.latent_size = latent_size
73
  transformer_model.force_images = args.force_images
74
- tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir")
75
- text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir",
76
  torch_dtype=torch.float16).to(device)
77
 
78
  # set eval mode
 
63
  device = torch.device('cuda:0')
64
 
65
  # Load model:
66
+ transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, torch_dtype=torch.float16).to(device)
67
 
68
+ vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae").to(device, dtype=torch.float16)
69
  vae.vae.enable_tiling()
70
  image_size = int(args.version.split('x')[1])
71
  latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2])
72
  vae.latent_size = latent_size
73
  transformer_model.force_images = args.force_images
74
+ tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
75
+ text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name,
76
  torch_dtype=torch.float16).to(device)
77
 
78
  # set eval mode