Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -35,12 +35,12 @@ import torch
|
|
35 |
# print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
|
36 |
# print(f"Reserved memory: {reserved_memory / 1024**2:.2f} MB")
|
37 |
|
38 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
39 |
-
name = 'flux-dev'
|
40 |
-
ae = load_ae(name, device)
|
41 |
-
t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
|
42 |
-
clip = load_clip(device)
|
43 |
-
model = load_flow_model(name, device=device)
|
44 |
print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
|
45 |
print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
|
46 |
print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
|
@@ -59,17 +59,17 @@ class SamplingOptions:
|
|
59 |
|
60 |
|
61 |
|
62 |
-
offload = False
|
63 |
-
name = "flux-dev"
|
64 |
-
is_schnell = False
|
65 |
-
feature_path = 'feature'
|
66 |
-
output_dir = 'result'
|
67 |
-
add_sampling_metadata = True
|
68 |
|
69 |
|
70 |
|
71 |
@torch.inference_mode()
|
72 |
-
def encode(init_image, torch_device
|
73 |
init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
|
74 |
init_image = init_image.unsqueeze(0)
|
75 |
init_image = init_image.to(torch_device)
|
@@ -98,7 +98,7 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
|
|
98 |
width, height = init_image.shape[0], init_image.shape[1]
|
99 |
|
100 |
|
101 |
-
init_image = encode(init_image, device
|
102 |
|
103 |
print(init_image.shape)
|
104 |
|
@@ -169,7 +169,6 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
|
|
169 |
else:
|
170 |
idx = 0
|
171 |
|
172 |
-
ae = ae.cuda()
|
173 |
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
|
174 |
x = ae.decode(x)
|
175 |
|
|
|
35 |
# print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
|
36 |
# print(f"Reserved memory: {reserved_memory / 1024**2:.2f} MB")
|
37 |
|
38 |
+
global device = "cuda" if torch.cuda.is_available() else "cpu"
|
39 |
+
global name = 'flux-dev'
|
40 |
+
global ae = load_ae(name, device)
|
41 |
+
global t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
|
42 |
+
global clip = load_clip(device)
|
43 |
+
global model = load_flow_model(name, device=device)
|
44 |
print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
|
45 |
print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
|
46 |
print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
|
|
|
59 |
|
60 |
|
61 |
|
62 |
+
global offload = False
|
63 |
+
global name = "flux-dev"
|
64 |
+
global is_schnell = False
|
65 |
+
global feature_path = 'feature'
|
66 |
+
global output_dir = 'result'
|
67 |
+
global add_sampling_metadata = True
|
68 |
|
69 |
|
70 |
|
71 |
@torch.inference_mode()
|
72 |
+
def encode(init_image, torch_device):
|
73 |
init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
|
74 |
init_image = init_image.unsqueeze(0)
|
75 |
init_image = init_image.to(torch_device)
|
|
|
98 |
width, height = init_image.shape[0], init_image.shape[1]
|
99 |
|
100 |
|
101 |
+
init_image = encode(init_image, device)
|
102 |
|
103 |
print(init_image.shape)
|
104 |
|
|
|
169 |
else:
|
170 |
idx = 0
|
171 |
|
|
|
172 |
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
|
173 |
x = ae.decode(x)
|
174 |
|