Maximofn commited on
Commit
6ed4aaa
·
1 Parent(s): 49f3202

docs(src): :rocket: space

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -15,30 +15,32 @@ def generate_video(
15
  seed,
16
  num_inference_steps,
17
  guidance_scale,
18
- flow_shift,
19
- embedded_guidance_scale
20
  ):
21
  seed = None if seed == -1 else seed
22
  width, height = resolution.split("x")
23
  width, height = int(width), int(height)
24
 
25
  model = "hunyuanvideo-community/HunyuanVideo"
26
- transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
27
  model,
28
  subfolder="transformer",
29
  device_map="balanced",
30
- torch_dtype=torch.float16,
31
  )
32
- print(f"transformer_3bit device: {transformer_3bit.device}")
33
 
34
  # Cargar el pipeline
35
  pipeline = HunyuanVideoPipeline.from_pretrained(
36
  model,
37
- transformer=transformer_8bit,
38
- torch_dtype=torch.float16,
39
  device_map="balanced",
40
  )
41
  print(f"pipeline device: {pipeline.device}")
 
 
42
 
43
  # Generar el video usando el pipeline
44
  video = pipeline(
 
15
  seed,
16
  num_inference_steps,
17
  guidance_scale,
18
+ flow_shift, # TODO: change to flow_shift
19
+ embedded_guidance_scale # TODO: change to embedded_guidance_scale
20
  ):
21
  seed = None if seed == -1 else seed
22
  width, height = resolution.split("x")
23
  width, height = int(width), int(height)
24
 
25
  model = "hunyuanvideo-community/HunyuanVideo"
26
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
27
  model,
28
  subfolder="transformer",
29
  device_map="balanced",
30
+ torch_dtype=torch.float16, # TODO: change to bfloat16
31
  )
32
+ print(f"transformer device: {transformer.device}")
33
 
34
  # Cargar el pipeline
35
  pipeline = HunyuanVideoPipeline.from_pretrained(
36
  model,
37
+ transformer=transformer,
38
+ torch_dtype=torch.float16, # TODO: change to bfloat16
39
  device_map="balanced",
40
  )
41
  print(f"pipeline device: {pipeline.device}")
42
+ # TODO: pipeline.vae.enable_tiling()
43
+ # TODO: pipeline.to("cuda")
44
 
45
  # Generar el video usando el pipeline
46
  video = pipeline(