stazizov commited on
Commit
8e44d92
·
verified ·
1 Parent(s): 7a5edc2

Update src/flux/xflux_pipeline.py

Browse files
Files changed (1) hide show
  1. src/flux/xflux_pipeline.py +6 -3
src/flux/xflux_pipeline.py CHANGED
@@ -31,8 +31,6 @@ from src.flux.util import (
31
  )
32
 
33
  from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
34
- import spaces
35
-
36
 
37
  class XFluxPipeline:
38
  def __init__(self, model_type, device, offload: bool = False):
@@ -57,6 +55,12 @@ class XFluxPipeline:
57
  self.controlnet_loaded = False
58
  self.ip_loaded = False
59
 
 
 
 
 
 
 
60
  def set_ip(self, local_path: str = None, repo_id = None, name: str = None):
61
  self.model.to(self.device)
62
 
@@ -221,7 +225,6 @@ class XFluxPipeline:
221
  neg_ip_scale=neg_ip_scale,
222
  )
223
 
224
- @spaces.GPU()
225
  def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance,
226
  num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
227
  neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
 
31
  )
32
 
33
  from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
 
 
34
 
35
  class XFluxPipeline:
36
  def __init__(self, model_type, device, offload: bool = False):
 
55
  self.controlnet_loaded = False
56
  self.ip_loaded = False
57
 
58
+ def to(*args, **args):
59
+ self.model.to(*args, **kwargs)
60
+ self.clip.to(*args, **kwargs)
61
+ self.t5.to(*args, **kwargs)
62
+ self.ae.to(*args, **kwargs)
63
+
64
  def set_ip(self, local_path: str = None, repo_id = None, name: str = None):
65
  self.model.to(self.device)
66
 
 
225
  neg_ip_scale=neg_ip_scale,
226
  )
227
 
 
228
  def gradio_generate(self, prompt, image_prompt, controlnet_image, width, height, guidance,
229
  num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
230
  neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,