Update src/flux/generate.py
Browse files- src/flux/generate.py +2 -2
src/flux/generate.py
CHANGED
@@ -51,7 +51,7 @@ def prepare_params(
|
|
51 |
prompt_2: Optional[Union[str, List[str]]] = None,
|
52 |
height: Optional[int] = 512,
|
53 |
width: Optional[int] = 512,
|
54 |
-
num_inference_steps: int =
|
55 |
timesteps: List[int] = None,
|
56 |
guidance_scale: float = 3.5,
|
57 |
num_images_per_prompt: Optional[int] = 1,
|
@@ -708,7 +708,7 @@ def generate_from_test_sample(
|
|
708 |
return delta_emb, delta_emb_pblock, delta_emb_mask, \
|
709 |
text_cond_mask, delta_start_ends, condition_latents, condition_ids
|
710 |
|
711 |
-
num_inference_steps =
|
712 |
num_channels_latents = pipe.transformer.config.in_channels // 4
|
713 |
|
714 |
# set timesteps
|
|
|
51 |
prompt_2: Optional[Union[str, List[str]]] = None,
|
52 |
height: Optional[int] = 512,
|
53 |
width: Optional[int] = 512,
|
54 |
+
num_inference_steps: int = 4,
|
55 |
timesteps: List[int] = None,
|
56 |
guidance_scale: float = 3.5,
|
57 |
num_images_per_prompt: Optional[int] = 1,
|
|
|
708 |
return delta_emb, delta_emb_pblock, delta_emb_mask, \
|
709 |
text_cond_mask, delta_start_ends, condition_latents, condition_ids
|
710 |
|
711 |
+
num_inference_steps = 4 # FIXME: harcoded here
|
712 |
num_channels_latents = pipe.transformer.config.in_channels // 4
|
713 |
|
714 |
# set timesteps
|