linoyts (HF Staff) committed
Commit afe56ea · verified
1 Parent(s): 56732fb

Update kontext_pipeline.py

Files changed (1)
  1. kontext_pipeline.py +45 -19
kontext_pipeline.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import inspect
  from typing import Any, Callable, Dict, List, Optional, Union

@@ -13,12 +27,7 @@ from transformers import (
  )

  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
- from diffusers.loaders import (
-     FluxIPAdapterMixin,
-     FluxLoraLoaderMixin,
-     FromSingleFileMixin,
-     TextualInversionLoaderMixin,
- )
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
  from diffusers.models import AutoencoderKL, FluxTransformer2DModel
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
  from diffusers.utils import (
@@ -29,14 +38,11 @@ from diffusers.utils import (
      scale_lora_layers,
      unscale_lora_layers,
  )
-
  from diffusers.utils.torch_utils import randn_tensor
- from diffusers import DiffusionPipeline
-
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
  from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput


-
  if is_torch_xla_available():
      import torch_xla.core.xla_model as xm

@@ -50,11 +56,27 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  EXAMPLE_DOC_STRING = """
      Examples:
          ```py
-         # TODO
+         >>> import torch
+         >>> from diffusers import FluxKontextPipeline
+         >>> from diffusers.utils import load_image
+
+         >>> pipe = FluxKontextPipeline.from_pretrained(
+         ...     "black-forest-labs/FLUX.1-kontext", transformer=transformer, torch_dtype=torch.bfloat16
+         ... )
+         >>> pipe.to("cuda")
+
+         >>> image = load_image("inputs/yarn-art-pikachu.png").convert("RGB")
+         >>> prompt = "Make Pikachu hold a sign that says 'Hugging Face is awesome', yarn art style, detailed, vibrant colors"
+         >>> image = pipe(
+         ...     image=image,
+         ...     prompt=prompt,
+         ...     guidance_scale=2.5,
+         ...     generator=torch.Generator().manual_seed(42),
+         ... ).images[0]
+         >>> image.save("output.png")
          ```
  """

-
  PREFERRED_KONTEXT_RESOLUTIONS = [
      (672, 1568),
      (688, 1504),
@@ -718,6 +740,7 @@ class FluxKontextPipeline(
          callback_on_step_end_tensor_inputs: List[str] = ["latents"],
          max_sequence_length: int = 512,
          max_area: int = 1024**2,
+         _auto_resize: bool = True,
      ):
          r"""
          Function invoked when calling the pipeline for generation.
@@ -915,13 +938,16 @@ class FluxKontextPipeline(

          # 3. Preprocess image
          if not torch.is_tensor(image) or image.size(1) == self.latent_channels:
-             image_width, image_height = self.image_processor.get_default_height_width(image)
+             if isinstance(image, list):
+                 image_width, image_height = self.image_processor.get_default_height_width(image[0])
+             else:
+                 image_width, image_height = self.image_processor.get_default_height_width(image)
              aspect_ratio = image_width / image_height
-
-             # Kontext is trained on specific resolutions, using one of them is recommended
-             _, image_width, image_height = min(
-                 (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
-             )
+             if _auto_resize:
+                 # Kontext is trained on specific resolutions, using one of them is recommended
+                 _, image_width, image_height = min(
+                     (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+                 )
              image_width = image_width // multiple_of * multiple_of
              image_height = image_height // multiple_of * multiple_of
              image = self.image_processor.resize(image, image_height, image_width)
@@ -1085,4 +1111,4 @@ class FluxKontextPipeline(
          if not return_dict:
              return (image,)

-         return FluxPipelineOutput(images=image)
+         return FluxPipelineOutput(images=image)
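
For reference, the aspect-ratio snapping gated by the new `_auto_resize` flag (see the `@@ -915,13 +938,16 @@` hunk) can be sketched as a standalone function. This is a minimal sketch, not the pipeline's code: only the two `PREFERRED_KONTEXT_RESOLUTIONS` entries visible in this diff are listed, and `multiple_of=16` is an assumed stand-in for the pipeline's alignment factor, which the diff does not show.

```py
# Minimal sketch of the resolution snapping gated by `_auto_resize` (not the pipeline code).
PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    # ... remaining entries are truncated in this diff and omitted here
]


def snap_to_preferred_resolution(image_width: int, image_height: int, multiple_of: int = 16):
    # Assumption: multiple_of=16 stands in for the pipeline's alignment factor,
    # which this diff does not show.
    aspect_ratio = image_width / image_height
    # Pick the preferred resolution whose aspect ratio is closest to the input's,
    # mirroring the `min(...)` expression in the hunk above.
    _, width, height = min(
        (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
    )
    # Round both sides down to the alignment multiple, as the pipeline does after snapping.
    return width // multiple_of * multiple_of, height // multiple_of * multiple_of


print(snap_to_preferred_resolution(1024, 1024))  # -> (688, 1504) with this truncated list
```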
 
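The docstring example added in this commit references a `transformer` variable that it never defines. A corrected usage sketch, keeping the model id and input path from that docstring (which may not match the public release) and assuming `FluxKontextPipeline` is importable as the docstring shows, could look like this; it also exercises the new `_auto_resize` parameter:

```py
# Hedged usage sketch based on the docstring example added in this commit.
# The checkpoint id and input path come from that docstring and may not exist publicly.
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-kontext", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("inputs/yarn-art-pikachu.png").convert("RGB")
prompt = "Make Pikachu hold a sign that says 'Hugging Face is awesome', yarn art style, detailed, vibrant colors"

result = pipe(
    image=image,
    prompt=prompt,
    guidance_scale=2.5,
    _auto_resize=True,  # set to False to skip snapping to a preferred Kontext resolution
    generator=torch.Generator().manual_seed(42),
).images[0]
result.save("output.png")
```

Passing `_auto_resize=False` keeps the input image's own resolution (still rounded down to the pipeline's alignment multiple) instead of snapping it to one of `PREFERRED_KONTEXT_RESOLUTIONS`.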