DrmemoryFish commited on
Commit
817e05b
·
1 Parent(s): 2318eee

Upload webui.py

Browse files
Files changed (1) hide show
  1. webui.py +137 -0
webui.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import threading
3
+ import time
4
+ import importlib
5
+ import signal
6
+ import threading
7
+
8
+ from fastapi.middleware.gzip import GZipMiddleware
9
+
10
+ from modules.paths import script_path
11
+
12
+ from modules import devices, sd_samplers
13
+ import modules.codeformer_model as codeformer
14
+ import modules.extras
15
+ import modules.face_restoration
16
+ import modules.gfpgan_model as gfpgan
17
+ import modules.img2img
18
+
19
+ import modules.lowvram
20
+ import modules.paths
21
+ import modules.scripts
22
+ import modules.sd_hijack
23
+ import modules.sd_models
24
+ import modules.shared as shared
25
+ import modules.txt2img
26
+
27
+ import modules.ui
28
+ from modules import devices
29
+ from modules import modelloader
30
+ from modules.paths import script_path
31
+ from modules.shared import cmd_opts
32
+ import modules.hypernetworks.hypernetwork
33
+
34
+
35
+ queue_lock = threading.Lock()
36
+
37
+
38
+ def wrap_queued_call(func):
39
+ def f(*args, **kwargs):
40
+ with queue_lock:
41
+ res = func(*args, **kwargs)
42
+
43
+ return res
44
+
45
+ return f
46
+
47
+
48
+ def wrap_gradio_gpu_call(func, extra_outputs=None):
49
+ def f(*args, **kwargs):
50
+ devices.torch_gc()
51
+
52
+ shared.state.sampling_step = 0
53
+ shared.state.job_count = -1
54
+ shared.state.job_no = 0
55
+ shared.state.job_timestamp = shared.state.get_job_timestamp()
56
+ shared.state.current_latent = None
57
+ shared.state.current_image = None
58
+ shared.state.current_image_sampling_step = 0
59
+ shared.state.skipped = False
60
+ shared.state.interrupted = False
61
+ shared.state.textinfo = None
62
+
63
+ with queue_lock:
64
+ res = func(*args, **kwargs)
65
+
66
+ shared.state.job = ""
67
+ shared.state.job_count = 0
68
+
69
+ devices.torch_gc()
70
+
71
+ return res
72
+
73
+ return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
74
+
75
+ def initialize():
76
+ modelloader.cleanup_models()
77
+ modules.sd_models.setup_model()
78
+ codeformer.setup_model(cmd_opts.codeformer_models_path)
79
+ gfpgan.setup_model(cmd_opts.gfpgan_models_path)
80
+ shared.face_restorers.append(modules.face_restoration.FaceRestoration())
81
+ modelloader.load_upscalers()
82
+
83
+ modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
84
+
85
+ shared.sd_model = modules.sd_models.load_model()
86
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
87
+ shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
88
+ shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
89
+
90
+
91
+ def webui():
92
+ initialize()
93
+
94
+ # make the program just exit at ctrl+c without waiting for anything
95
+ def sigint_handler(sig, frame):
96
+ print(f'Interrupted with signal {sig} in {frame}')
97
+ os._exit(0)
98
+
99
+ signal.signal(signal.SIGINT, sigint_handler)
100
+
101
+ while 1:
102
+
103
+ demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
104
+
105
+ app, local_url, share_url = demo.launch(share=True)(
106
+ share=cmd_opts.share,
107
+ server_name="0.0.0.0" if cmd_opts.listen else None,
108
+ server_port=cmd_opts.port,
109
+ debug=cmd_opts.gradio_debug,
110
+ auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
111
+ inbrowser=cmd_opts.autolaunch,
112
+ prevent_thread_lock=True
113
+ )
114
+
115
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
116
+
117
+ while 1:
118
+ time.sleep(0.5)
119
+ if getattr(demo, 'do_restart', False):
120
+ time.sleep(0.5)
121
+ demo.close()
122
+ time.sleep(0.5)
123
+ break
124
+
125
+ sd_samplers.set_samplers()
126
+
127
+ print('Reloading Custom Scripts')
128
+ modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
129
+ print('Reloading modules: modules.ui')
130
+ importlib.reload(modules.ui)
131
+ print('Refreshing Model List')
132
+ modules.sd_models.list_models()
133
+ print('Restarting Gradio')
134
+
135
+
136
+ if __name__ == "__main__":
137
+ webui()