Boese0601 commited on
Commit
0936b4c
·
verified ·
1 Parent(s): 8207a46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -10,22 +10,25 @@ from image_datasets.dataset import image_resize
10
  args = OmegaConf.load("inference_configs/inference.yaml")
11
  device = torch.device("cuda")
12
  dtype = torch.bfloat16
 
 
13
  @spaces.GPU
14
  def generate(image: Image.Image, edit_prompt: str):
15
 
16
  from src.flux.xflux_pipeline import XFluxSampler
17
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- sampler = XFluxSampler(
20
- device = device,
21
- ip_loaded=False,
22
- spatial_condition=True,
23
- clip_image_processor=None,
24
- image_encoder=None,
25
- improj=None,
26
- share_position_embedding = True,
27
- )
28
-
29
  img = image_resize(image, 512)
30
  w, h = img.size
31
  img = img.resize(((w // 32) * 32, (h // 32) * 32))
 
10
  args = OmegaConf.load("inference_configs/inference.yaml")
11
  device = torch.device("cuda")
12
  dtype = torch.bfloat16
13
+ sampler = None
14
+
15
  @spaces.GPU
16
  def generate(image: Image.Image, edit_prompt: str):
17
 
18
  from src.flux.xflux_pipeline import XFluxSampler
19
 
20
+ global sampler
21
+ if sampler == None:
22
+ sampler = XFluxSampler(
23
+ device = device,
24
+ ip_loaded=False,
25
+ spatial_condition=True,
26
+ clip_image_processor=None,
27
+ image_encoder=None,
28
+ improj=None,
29
+ share_position_embedding = True,
30
+ )
31
 
 
 
 
 
 
 
 
 
 
 
32
  img = image_resize(image, 512)
33
  w, h = img.size
34
  img = img.resize(((w // 32) * 32, (h // 32) * 32))