wjs0725 commited on
Commit
c633291
·
verified ·
1 Parent(s): d73cf56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -35,12 +35,12 @@ import torch
35
  # print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
36
  # print(f"Reserved memory: {reserved_memory / 1024**2:.2f} MB")
37
 
38
- device = "cuda" if torch.cuda.is_available() else "cpu"
39
- name = 'flux-dev'
40
- ae = load_ae(name, device)
41
- t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
42
- clip = load_clip(device)
43
- model = load_flow_model(name, device=device)
44
  print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
45
  print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
46
  print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
@@ -59,17 +59,17 @@ class SamplingOptions:
59
 
60
 
61
 
62
- offload = False
63
- name = "flux-dev"
64
- is_schnell = False
65
- feature_path = 'feature'
66
- output_dir = 'result'
67
- add_sampling_metadata = True
68
 
69
 
70
 
71
  @torch.inference_mode()
72
- def encode(init_image, torch_device, ae):
73
  init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
74
  init_image = init_image.unsqueeze(0)
75
  init_image = init_image.to(torch_device)
@@ -98,7 +98,7 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
98
  width, height = init_image.shape[0], init_image.shape[1]
99
 
100
 
101
- init_image = encode(init_image, device, ae)
102
 
103
  print(init_image.shape)
104
 
@@ -169,7 +169,6 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
169
  else:
170
  idx = 0
171
 
172
- ae = ae.cuda()
173
  with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
174
  x = ae.decode(x)
175
 
 
35
  # print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
36
  # print(f"Reserved memory: {reserved_memory / 1024**2:.2f} MB")
37
 
38
+ global device = "cuda" if torch.cuda.is_available() else "cpu"
39
+ global name = 'flux-dev'
40
+ global ae = load_ae(name, device)
41
+ global t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
42
+ global clip = load_clip(device)
43
+ global model = load_flow_model(name, device=device)
44
  print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
45
  print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
46
  print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
 
59
 
60
 
61
 
62
+ global offload = False
63
+ global name = "flux-dev"
64
+ global is_schnell = False
65
+ global feature_path = 'feature'
66
+ global output_dir = 'result'
67
+ global add_sampling_metadata = True
68
 
69
 
70
 
71
  @torch.inference_mode()
72
+ def encode(init_image, torch_device):
73
  init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
74
  init_image = init_image.unsqueeze(0)
75
  init_image = init_image.to(torch_device)
 
98
  width, height = init_image.shape[0], init_image.shape[1]
99
 
100
 
101
+ init_image = encode(init_image, device)
102
 
103
  print(init_image.shape)
104
 
 
169
  else:
170
  idx = 0
171
 
 
172
  with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
173
  x = ae.decode(x)
174