stazizov commited on
Commit
d15adcf
·
verified ·
1 Parent(s): a19af98

Update src/flux/xflux_pipeline.py

Browse files
Files changed (1) hide show
  1. src/flux/xflux_pipeline.py +3 -1
src/flux/xflux_pipeline.py CHANGED
@@ -29,7 +29,8 @@ from src.flux.util import (
29
  )
30
 
31
  from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
32
-
 
33
  class XFluxPipeline:
34
  def __init__(self, model_type, device, offload: bool = False):
35
  self.device = torch.device(device)
@@ -160,6 +161,7 @@ class XFluxPipeline:
160
  image_proj = self.improj(image_prompt_embeds)
161
  return image_proj
162
 
 
163
  def __call__(self,
164
  prompt: str,
165
  image_prompt: Image = None,
 
29
  )
30
 
31
  from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
32
+ import spaces
33
+ `
34
  class XFluxPipeline:
35
  def __init__(self, model_type, device, offload: bool = False):
36
  self.device = torch.device(device)
 
161
  image_proj = self.improj(image_prompt_embeds)
162
  return image_proj
163
 
164
+ @spaces.GPU
165
  def __call__(self,
166
  prompt: str,
167
  image_prompt: Image = None,