Added a note that image generation may take time
app.py CHANGED
@@ -19,7 +19,7 @@ from ldm.models.diffusion.plms import PLMSSampler
 
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
-    pl_sd = torch.load(ckpt, map_location="cpu")
+    pl_sd = torch.load(ckpt, map_location="cpu")  # TODO: change for GPU resources
     # pl_sd = torch.load(ckpt)#, map_location="cpu")
     sd = pl_sd["state_dict"]
     model = instantiate_from_config(config.model)
@@ -31,7 +31,7 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)
 
-    # model.cuda()
+    # model.cuda()  # TODO: change for GPU resources
    model.eval()
    return model
 
@@ -50,8 +50,8 @@ def masking_embed(embedding, levels=1):
 # LOAD MODEL GLOBALLY
 ckpt_path = './model_files/fishes/epoch=000119.ckpt'
 config_path = './model_files/fishes/2024-03-01T23-15-36-project.yaml'
-config = OmegaConf.load(config_path)
-model = load_model_from_config(config, ckpt_path)
+config = OmegaConf.load(config_path)
+model = load_model_from_config(config, ckpt_path)
 
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 model = model.to(device)
@@ -78,12 +78,7 @@ def generate_image(fish_name, masking_level_input,
         return key
 
 
-    if opt.plms:
-        sampler = PLMSSampler(model)
-    else:
-        sampler = DDIMSampler(model)
-
-
+    sampler = DDIMSampler(model)
 
     prompt = opt.prompt
     all_images = []
@@ -169,12 +164,6 @@ if __name__ == "__main__":
         help="number of ddim sampling steps",
     )
 
-    parser.add_argument(
-        "--plms",
-        action='store_true',
-        help="use plms sampling",
-    )
-
     parser.add_argument(
         "--ddim_eta",
         type=float,
@@ -205,8 +194,6 @@ if __name__ == "__main__":
     opt = parser.parse_args()
 
     title = "🎞️ Phylo Diffusion - Generating Fish Images Tool"
-    description = "Write the Species name to generate an image for.\n For Trait Masking: Specify the Level information as well"
-
 
     def load_example(prompt, level, option, components):
         components['prompt_input'].value = prompt
@@ -214,14 +201,16 @@ if __name__ == "__main__":
 
     def setup_interface():
         with gr.Blocks() as demo:
-
+
+            gr.Markdown("# Phylo-Diffusion: Generating Fish Images Tool")
             gr.Markdown("### Write the Species name to generate a fish image")
-            gr.Markdown("###
+            gr.Markdown("### 1. Trait Masking: Specify the Level information to mask")
+            gr.Markdown("### 2. Trait Swapping: Specify the species name to swap trait with and at what level")
 
             with gr.Row():
                 with gr.Column():
-                    gr.Markdown("## Generate Images Based on Prompts")
-                    gr.Markdown("
+                    # gr.Markdown("## Generate Images Based on Prompts")
+                    gr.Markdown("**NOTE:** The demo is currently running on free CPU resources provided by Hugging Face, so it may take up to 10 minutes to generate an image. We're working on securing additional resources to speed up the process. Thank you for your patience!")
                     prompt_input = gr.Textbox(label="Species Name")
 
                     # Radio button to select experiment type, with no default selection
@@ -248,7 +237,6 @@ if __name__ == "__main__":
             gr.Markdown("## Select an example:")
            examples = [
                ("Gambusia Affinis", "None", "", "Level 3"),
-               ("Lepomis Auritus", "None", "", "Level 3"),
                ("Lepomis Auritus", "Level 3", "", "Level 3"),
                ("Noturus nocturnus", "None", "Notropis dorsalis", "Level 2")
            ]
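For reference, this is the CPU-friendly loading pattern the commit leaves in place, condensed into a standalone sketch. It assumes the ldm package from the latent-diffusion codebase and the checkpoint/config paths shown above; the ldm.util import path for instantiate_from_config is the conventional one and is not confirmed by this diff.

# Hedged sketch: load a latent-diffusion checkpoint on CPU and only move it to
# a GPU when one is available (e.g. not on a free Hugging Face CPU Space).
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # assumed import path


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    # map_location="cpu" lets the checkpoint load without any GPU present
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and (missing or unexpected):
        print("missing keys:", missing)
        print("unexpected keys:", unexpected)
    # No model.cuda() call here, so CPU-only hosts still work
    model.eval()
    return model


config = OmegaConf.load('./model_files/fishes/2024-03-01T23-15-36-project.yaml')
model = load_model_from_config(config, './model_files/fishes/epoch=000119.ckpt')

# Fall back to CPU when CUDA is unavailable
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

Loading with map_location="cpu" and deferring the device move keeps the Space usable on free CPU hardware, while still picking up a GPU automatically if one is ever attached.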
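With the --plms flag and its sampler branch removed, generation always goes through DDIMSampler. The fragment below is a generic sketch of how that sampler is usually driven in latent-diffusion apps, not code from this Space: model and opt are assumed to come from the loading and argparse code above, and the text-prompt conditioning via get_learned_conditioning, the latent shape, and the guidance scale are illustrative placeholders rather than whatever conditioning generate_image actually builds from the species name.

# Hedged sketch of a DDIM-only sampling path, following the usual
# CompVis latent-diffusion conventions.
import torch
from ldm.models.diffusion.ddim import DDIMSampler

sampler = DDIMSampler(model)  # DDIM is now the only sampling path

with torch.no_grad():
    # Classifier-free guidance needs a conditional and an unconditional embedding;
    # text-conditioned LatentDiffusion models expose get_learned_conditioning for this.
    c = model.get_learned_conditioning([opt.prompt])
    uc = model.get_learned_conditioning([""])

    samples, _ = sampler.sample(
        S=opt.ddim_steps,                  # DDIM steps from the CLI option kept above
        batch_size=1,
        shape=[4, 64, 64],                 # latent [C, H/8, W/8]; depends on the trained model
        conditioning=c,
        eta=opt.ddim_eta,                  # 0.0 gives deterministic DDIM
        unconditional_guidance_scale=5.0,  # guidance strength (illustrative value)
        unconditional_conditioning=uc,
    )

    # Decode latents back to image space, roughly in [-1, 1]
    x_samples = model.decode_first_stage(samples)

Setting eta to 0.0 makes DDIM deterministic for a fixed starting noise, which is the common choice exposed through --ddim_eta.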