xywwww commited on
Commit
e65f030
·
verified ·
1 Parent(s): ebae053

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from annotator.util import resize_image, HWC3
4
+ from annotator.canny import CannyDetector
5
+ from cldm.model import create_model, load_state_dict
6
+ from cldm.ddim_hacked import DDIMSampler
7
+
8
+ # Initialize the model and other components
9
+ apply_canny = CannyDetector()
10
+ model = create_model('./models/cldm_v21_512_latctrl_coltrans.yaml').cpu()
11
+ model.load_state_dict(load_state_dict('xywwww/scene_diffusion/checkpoints/epoch=25-step=112553.ckpt', location='cuda'), strict=False)
12
+ model = model.cuda()
13
+ ddim_sampler = DDIMSampler(model)
14
+
15
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
16
+ with torch.no_grad():
17
+ img = resize_image(HWC3(input_image), image_resolution)
18
+ H, W, C = img.shape
19
+ # detected_map = apply_canny(img, low_threshold, high_threshold)
20
+ # detected_map = HWC3(detected_map)
21
+ # Add the rest of the processing logic here
22
+
23
+ def create_demo(process):
24
+ with gr.Blocks() as demo:
25
+ with gr.Row():
26
+ with gr.Column():
27
+ input_image = gr.Image()
28
+ prompt = gr.Textbox(label="Prompt", submit_btn=True)
29
+ a_prompt = gr.Textbox(label="Additional Prompt")
30
+ n_prompt = gr.Textbox(label="Negative Prompt")
31
+ with gr.Accordion("Advanced options", open=False):
32
+ num_samples = gr.Slider(label="Number of images", minimum=1, maximum=10, value=1, step=1)
33
+ image_resolution = gr.Slider(label="Image resolution", minimum=256, maximum=1024, value=512, step=256)
34
+ ddim_steps = gr.Slider(label="DDIM Steps", minimum=1, maximum=100, value=50, step=1)
35
+ guess_mode = gr.Checkbox(label="Guess Mode")
36
+ strength = gr.Slider(label="Strength", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
37
+ scale = gr.Slider(label="Scale", minimum=0.1, maximum=30.0, value=10.0, step=0.1)
38
+ seed = gr.Slider(label="Seed", minimum=0, maximum=10000, value=42, step=1)
39
+ eta = gr.Slider(label="ETA", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
40
+ low_threshold = gr.Slider(label="Canny Low Threshold", minimum=1, maximum=255, value=100, step=1)
41
+ high_threshold = gr.Slider(label="Canny High Threshold", minimum=1, maximum=255, value=200, step=1)
42
+ submit = gr.Button("Generate")
43
+ with gr.Column():
44
+ output_image = gr.Image()
45
+ submit.click(fn=process, inputs=[input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold], outputs=output_image)
46
+ return demo
47
+
48
+ demo = create_demo(process)
49
+ demo.launch()