Spanicin committed
Commit c8f5641 · verified · 1 Parent(s): bd26dca

Update app.py

Files changed (1)
  1. app.py +303 -195
app.py CHANGED
@@ -1,3 +1,266 @@
+ # import logging
+ # import random
+ # import warnings
+ # import gradio as gr
+ # import os
+ # import shutil
+ # import subprocess
+ # import spaces
+ # import torch
+ # import numpy as np
+ # from diffusers import FluxControlNetModel
+ # from diffusers.pipelines import FluxControlNetPipeline
+ # from PIL import Image
+ # from huggingface_hub import snapshot_download, login
+ # import io
+ # import base64
+ # from flask import Flask, request, jsonify
+ # from concurrent.futures import ThreadPoolExecutor
+ # from flask_cors import CORS
+ # import threading
+
+ # # Configure logging
+ # logging.basicConfig(level=logging.INFO)
+ # logger = logging.getLogger(__name__)
+
+ # app = Flask(__name__)
+ # CORS(app)
+
+ # # Function to check disk usage
+ # def check_disk_space():
+ #     result = subprocess.run(['df', '-h'], capture_output=True, text=True)
+ #     logger.info("Disk space usage:\n%s", result.stdout)
+
+ # # Function to clear Hugging Face cache
+ # def clear_huggingface_cache():
+ #     cache_dir = os.path.expanduser('~/.cache/huggingface')
+ #     if os.path.exists(cache_dir):
+ #         shutil.rmtree(cache_dir)  # Removes the entire cache directory
+ #         logger.info("Cleared Hugging Face cache at: %s", cache_dir)
+ #     else:
+ #         logger.info("No Hugging Face cache found.")
+
+ # # Check disk space
+ # check_disk_space()
+
+ # # Clear Hugging Face cache
+ # clear_huggingface_cache()
+
+ # # Add config to store base64 images
+ # app.config['image_outputs'] = {}
+
+ # # ThreadPoolExecutor for managing image processing threads
+ # executor = ThreadPoolExecutor()
+
+ # # Determine the device (GPU or CPU)
+ # if torch.cuda.is_available():
+ #     device = "cuda"
+ #     logger.info("CUDA is available. Using GPU.")
+ # else:
+ #     device = "cpu"
+ #     logger.info("CUDA is not available. Using CPU.")
+
+ # # Load model from Huggingface Hub
+ # huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+ # if huggingface_token:
+ #     login(token=huggingface_token)
+ #     logger.info("Hugging Face token found and logged in.")
+ # else:
+ #     logger.warning("Hugging Face token not found in environment variables.")
+
+ # logger.info("Hugging Face token: %s", huggingface_token)
+
+ # # Download model using snapshot_download
+
+ # model_path = snapshot_download(
+ #     repo_id="black-forest-labs/FLUX.1-dev",
+ #     repo_type="model",
+ #     ignore_patterns=["*.md", "*..gitattributes"],
+ #     local_dir="FLUX.1-dev",
+ #     token=huggingface_token)
+ # logger.info("Model downloaded to: %s", model_path)
+
+ # # Load pipeline
+ # logger.info('Loading ControlNet model.')
+
+ # controlnet = FluxControlNetModel.from_pretrained(
+ #     "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
+ # ).to(device)
+ # logger.info("ControlNet model loaded successfully.")
+
+ # logger.info('Loading pipeline.')
+
+ # pipe = FluxControlNetPipeline.from_pretrained(
+ #     model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
+ # ).to(device)
+ # logger.info("Pipeline loaded successfully.")
+
+ # MAX_SEED = 1000000
+ # MAX_PIXEL_BUDGET = 1024 * 1024
+
+
+ # @spaces.GPU
+ # def process_input(input_image, upscale_factor):
+ #     w, h = input_image.size
+ #     aspect_ratio = w / h
+ #     was_resized = False
+
+ #     # Resize if input size exceeds the maximum pixel budget
+ #     if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
+ #         warnings.warn(f"Requested output image is too large. Resizing to fit within pixel budget.")
+ #         input_image = input_image.resize(
+ #             (
+ #                 int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor),
+ #                 int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor),
+ #             )
+ #         )
+ #         was_resized = True
+
+ #     # Adjust dimensions to be a multiple of 8
+ #     w, h = input_image.size
+ #     w = w - w % 8
+ #     h = h - h % 8
+
+ #     return input_image.resize((w, h)), was_resized
+
+ # @spaces.GPU
+ # def run_inference(process_id, input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale):
+ #     logger.info("Processing inference for process_id: %s", process_id)
+ #     input_image, was_resized = process_input(input_image, upscale_factor)
+
+ #     # Rescale image for ControlNet processing
+ #     w, h = input_image.size
+ #     control_image = input_image.resize((w * upscale_factor, h * upscale_factor))
+
+ #     # Set the random generator for inference
+ #     generator = torch.Generator().manual_seed(seed)
+
+ #     # Perform inference using the pipeline
+ #     logger.info("Running pipeline for process_id: %s", process_id)
+ #     image = pipe(
+ #         prompt="",
+ #         control_image=control_image,
+ #         controlnet_conditioning_scale=controlnet_conditioning_scale,
+ #         num_inference_steps=num_inference_steps,
+ #         guidance_scale=3.5,
+ #         height=control_image.size[1],
+ #         width=control_image.size[0],
+ #         generator=generator,
+ #     ).images[0]
+
+ #     # Resize output image back to the original dimensions if needed
+ #     if was_resized:
+ #         original_size = (input_image.width * upscale_factor, input_image.height * upscale_factor)
+ #         image = image.resize(original_size)
+
+ #     # Convert the output image to base64
+ #     buffered = io.BytesIO()
+ #     image.save(buffered, format="JPEG")
+ #     image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+ #     # Store the result in the shared dictionary
+ #     app.config['image_outputs'][process_id] = image_base64
+ #     logger.info("Inference completed for process_id: %s", process_id)
+
+ # @app.route('/infer', methods=['POST'])
+ # def infer():
+ #     # Check if the file was provided in the form-data
+ #     if 'input_image' not in request.files:
+ #         logger.error("No image file provided in request.")
+ #         return jsonify({
+ #             "status": "error",
+ #             "message": "No input_image file provided"
+ #         }), 400
+
+ #     # Get the uploaded image file from the request
+ #     file = request.files['input_image']
+
+ #     # Check if a file was uploaded
+ #     if file.filename == '':
+ #         logger.error("No selected file in form-data.")
+ #         return jsonify({
+ #             "status": "error",
+ #             "message": "No selected file"
+ #         }), 400
+
+ #     # Convert the image to Base64 for internal processing
+ #     input_image = Image.open(file)
+ #     buffered = io.BytesIO()
+ #     input_image.save(buffered, format="JPEG")
+
+ #     # Retrieve additional parameters from the request (if any)
+ #     seed = request.form.get("seed", 42, type=int)
+ #     randomize_seed = request.form.get("randomize_seed", 'true').lower() == 'true'
+ #     num_inference_steps = request.form.get("num_inference_steps", 28, type=int)
+ #     upscale_factor = request.form.get("upscale_factor", 4, type=int)
+ #     controlnet_conditioning_scale = request.form.get("controlnet_conditioning_scale", 0.6, type=float)
+
+ #     # Randomize seed if specified
+ #     if randomize_seed:
+ #         seed = random.randint(0, MAX_SEED)
+ #         logger.info("Seed randomized to: %d", seed)
+
+ #     # Create a unique process ID for this request
+ #     process_id = str(random.randint(1000, 9999))
+ #     logger.info("Process started with process_id: %s", process_id)
+
+ #     # Set the status to 'in_progress'
+ #     app.config['image_outputs'][process_id] = None
+
+ #     # Run the inference in a separate thread
+ #     executor.submit(run_inference, process_id, input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale)
+
+ #     # Return the process ID
+ #     return jsonify({
+ #         "process_id": process_id,
+ #         "message": "Processing started"
+ #     })
+
+
+ # # Modify status endpoint to receive process_id in request body
+ # @app.route('/status', methods=['GET'])
+ # def status():
+ #     # Get the process_id from the query parameters
+ #     process_id = request.args.get('process_id')
+
+ #     # Check if process_id was provided
+ #     if not process_id:
+ #         logger.error("Process ID not provided in request.")
+ #         return jsonify({
+ #             "status": "error",
+ #             "message": "Process ID is required"
+ #         }), 400
+
+ #     # Check if the process_id exists in the dictionary
+ #     if process_id not in app.config['image_outputs']:
+ #         logger.error("Invalid process ID: %s", process_id)
+ #         return jsonify({
+ #             "status": "error",
+ #             "message": "Invalid process ID"
+ #         }), 404
+
+ #     # Check the status of the image processing
+ #     image_base64 = app.config['image_outputs'][process_id]
+ #     if image_base64 is None:
+ #         logger.info("Process ID %s is still in progress.", process_id)
+ #         return jsonify({
+ #             "status": "in_progress"
+ #         })
+ #     else:
+ #         logger.info("Process ID %s completed successfully.", process_id)
+ #         return jsonify({
+ #             "status": "completed",
+ #             "output_image": image_base64
+ #         })
+
+
+ # if __name__ == '__main__':
+ #     app.run(debug=True,host='0.0.0.0')
+
+
+
+
+
  import logging
  import random
  import warnings
@@ -5,7 +268,6 @@ import gradio as gr
  import os
  import shutil
  import subprocess
- import spaces
  import torch
  import numpy as np
  from diffusers import FluxControlNetModel
@@ -14,18 +276,12 @@ from PIL import Image
  from huggingface_hub import snapshot_download, login
  import io
  import base64
- from flask import Flask, request, jsonify
- from concurrent.futures import ThreadPoolExecutor
- from flask_cors import CORS
  import threading

  # Configure logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- app = Flask(__name__)
- CORS(app)
-
  # Function to check disk usage
  def check_disk_space():
      result = subprocess.run(['df', '-h'], capture_output=True, text=True)
@@ -47,10 +303,7 @@ check_disk_space()
  clear_huggingface_cache()

  # Add config to store base64 images
- app.config['image_outputs'] = {}
-
- # ThreadPoolExecutor for managing image processing threads
- executor = ThreadPoolExecutor()
+ image_outputs = {}

  # Determine the device (GPU or CPU)
  if torch.cuda.is_available():
@@ -68,38 +321,32 @@ if huggingface_token:
  else:
      logger.warning("Hugging Face token not found in environment variables.")

- logger.info("Hugging Face token: %s", huggingface_token)
-
  # Download model using snapshot_download
-
  model_path = snapshot_download(
-     repo_id="black-forest-labs/FLUX.1-dev",
-     repo_type="model",
-     ignore_patterns=["*.md", "*..gitattributes"],
-     local_dir="FLUX.1-dev",
-     token=huggingface_token)
+     repo_id="black-forest-labs/FLUX.1-dev",
+     repo_type="model",
+     ignore_patterns=["*.md", "*..gitattributes"],
+     local_dir="FLUX.1-dev",
+     token=huggingface_token
+ )
  logger.info("Model downloaded to: %s", model_path)

  # Load pipeline
  logger.info('Loading ControlNet model.')
-
  controlnet = FluxControlNetModel.from_pretrained(
-     "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
- ).to(device)
+     "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
+ ).to(device)
  logger.info("ControlNet model loaded successfully.")

  logger.info('Loading pipeline.')
-
  pipe = FluxControlNetPipeline.from_pretrained(
-     model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
- ).to(device)
+     model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
+ ).to(device)
  logger.info("Pipeline loaded successfully.")

  MAX_SEED = 1000000
  MAX_PIXEL_BUDGET = 1024 * 1024

-
- @spaces.GPU
  def process_input(input_image, upscale_factor):
      w, h = input_image.size
      aspect_ratio = w / h
@@ -123,9 +370,8 @@ def process_input(input_image, upscale_factor):

      return input_image.resize((w, h)), was_resized

- @spaces.GPU
- def run_inference(process_id, input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale):
-     logger.info("Processing inference for process_id: %s", process_id)
+ def run_inference(input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale):
+     logger.info("Running inference")
      input_image, was_resized = process_input(input_image, upscale_factor)

      # Rescale image for ControlNet processing
@@ -136,7 +382,7 @@ def run_inference(process_id, input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale):
      generator = torch.Generator().manual_seed(seed)

      # Perform inference using the pipeline
-     logger.info("Running pipeline for process_id: %s", process_id)
+     logger.info("Running pipeline")
      image = pipe(
          prompt="",
          control_image=control_image,
@@ -158,176 +404,38 @@ def run_inference(process_id, input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale):
      image.save(buffered, format="JPEG")
      image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

-     # Store the result in the shared dictionary
-     app.config['image_outputs'][process_id] = image_base64
-     logger.info("Inference completed for process_id: %s", process_id)
-
- @app.route('/infer', methods=['POST'])
- def infer():
-     # Check if the file was provided in the form-data
-     if 'input_image' not in request.files:
-         logger.error("No image file provided in request.")
-         return jsonify({
-             "status": "error",
-             "message": "No input_image file provided"
-         }), 400
-
-     # Get the uploaded image file from the request
-     file = request.files['input_image']
-
-     # Check if a file was uploaded
-     if file.filename == '':
-         logger.error("No selected file in form-data.")
-         return jsonify({
-             "status": "error",
-             "message": "No selected file"
-         }), 400
-
-     # Convert the image to Base64 for internal processing
-     input_image = Image.open(file)
-     buffered = io.BytesIO()
-     input_image.save(buffered, format="JPEG")
-
-     # Retrieve additional parameters from the request (if any)
-     seed = request.form.get("seed", 42, type=int)
-     randomize_seed = request.form.get("randomize_seed", 'true').lower() == 'true'
-     num_inference_steps = request.form.get("num_inference_steps", 28, type=int)
-     upscale_factor = request.form.get("upscale_factor", 4, type=int)
-     controlnet_conditioning_scale = request.form.get("controlnet_conditioning_scale", 0.6, type=float)
+     logger.info("Inference completed")
+     return image_base64

-     # Randomize seed if specified
+ # Define Gradio interface
+ def gradio_interface(input_image, upscale_factor=4, seed=42, num_inference_steps=28, controlnet_conditioning_scale=0.6, randomize_seed=True):
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)
          logger.info("Seed randomized to: %d", seed)

-     # Create a unique process ID for this request
-     process_id = str(random.randint(1000, 9999))
-     logger.info("Process started with process_id: %s", process_id)
-
-     # Set the status to 'in_progress'
-     app.config['image_outputs'][process_id] = None
-
-     # Run the inference in a separate thread
-     executor.submit(run_inference, process_id, input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale)
-
-     # Return the process ID
-     return jsonify({
-         "process_id": process_id,
-         "message": "Processing started"
-     })
-
-
- # Modify status endpoint to receive process_id in request body
- @app.route('/status', methods=['GET'])
- def status():
-     # Get the process_id from the query parameters
-     process_id = request.args.get('process_id')
-
-     # Check if process_id was provided
-     if not process_id:
-         logger.error("Process ID not provided in request.")
-         return jsonify({
-             "status": "error",
-             "message": "Process ID is required"
-         }), 400
-
-     # Check if the process_id exists in the dictionary
-     if process_id not in app.config['image_outputs']:
-         logger.error("Invalid process ID: %s", process_id)
-         return jsonify({
-             "status": "error",
-             "message": "Invalid process ID"
-         }), 404
-
-     # Check the status of the image processing
-     image_base64 = app.config['image_outputs'][process_id]
-     if image_base64 is None:
-         logger.info("Process ID %s is still in progress.", process_id)
-         return jsonify({
-             "status": "in_progress"
-         })
-     else:
-         logger.info("Process ID %s completed successfully.", process_id)
-         return jsonify({
-             "status": "completed",
-             "output_image": image_base64
-         })
-
-
- # Gradio Blocks setup
- css = """
- #col-container {
-     max-width: 800px;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     gr.Markdown(
-         """
- # ⚡ Flux.1-dev Upscaler ControlNet ⚡
- This is an interactive demo of [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler) that takes a low-resolution image as input and generates a high-resolution image.
- Currently running on {device}.
- *Note*: Although the model can handle higher-resolution images, GPU memory constraints limit this demo to a generated output within a pixel budget of 1024x1024. If the requested size exceeds that limit, the input is first resized, preserving the aspect ratio, so that the ControlNet output stays within the allocated pixel budget; the output is then resized to the target shape with a simple resize. This may explain some artifacts for high-resolution inputs. To address this, run the demo locally or consider implementing a tiling strategy. Happy upscaling! 🚀
-         """.format(device=device)
-     )
-
-     with gr.Row():
-         run_button = gr.Button(value="Run")
-
-     with gr.Row():
-         with gr.Column(scale=4):
-             input_im = gr.Image(label="Input Image", type="pil")
-         with gr.Column(scale=1):
-             num_inference_steps = gr.Slider(
-                 label="Number of Inference Steps",
-                 minimum=8,
-                 maximum=50,
-                 step=1,
-                 value=28,
-             )
-             upscale_factor = gr.Slider(
-                 label="Upscale Factor",
-                 minimum=1,
-                 maximum=4,
-                 step=1,
-                 value=4,
-             )
-             controlnet_conditioning_scale = gr.Slider(
-                 label="Controlnet Conditioning Scale",
-                 minimum=0.1,
-                 maximum=1.5,
-                 step=0.1,
-                 value=0.6,
-             )
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=42,
-             )
-
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-     with gr.Row():
-         result = gr.Image(label="Output Image", type="pil", interactive=False)
-
-     gr.Button("Run").click(
-         fn=gr_infer,
-         inputs=[seed, randomize_seed, input_im, num_inference_steps, upscale_factor, controlnet_conditioning_scale],
-         outputs=result
-     )
-
-     gr.Markdown("**Disclaimer:**")
-     gr.Markdown(
-         "This demo is only for research purposes. Jasper cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards."
-     )
-
+     # Run inference
+     output_image_base64 = run_inference(input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale)
+
+     return Image.open(io.BytesIO(base64.b64decode(output_image_base64)))
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[
+         gr.Image(type="pil", label="Input Image"),
+         gr.Slider(minimum=1, maximum=8, step=1, value=4, label="Upscale Factor"),
+         gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed"),
+         gr.Slider(minimum=1, maximum=100, step=1, value=28, label="Inference Steps"),
+         gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.6, label="ControlNet Conditioning Scale")
+     ],
+     outputs=gr.Image(label="Output Image"),
+     title="ControlNet Image Upscaling",
+     description="Upload an image to upscale using the ControlNet model."
+ )
+
+ # Launch Gradio app
  if __name__ == '__main__':
-     demo.queue().launch(share=False, show_api=False)
-     app.run(host='0.0.0.0', port=5000)
+     iface.launch()

- # if __name__ == '__main__':
- #     app.run(debug=True,host='0.0.0.0')

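The net effect of this commit: the Flask API (a POST /infer that queued work on a ThreadPoolExecutor, plus a GET /status polling endpoint) is removed, and the Space now exposes a single synchronous gr.Interface. A minimal client sketch, assuming a recent gradio_client that provides handle_file, and using a placeholder Space id, since the real id is not part of this diff:

    # Sketch only: calling the gr.Interface added in this commit remotely.
    # "Spanicin/flux-upscaler" is a hypothetical Space id, not from the diff.
    from gradio_client import Client, handle_file

    client = Client("Spanicin/flux-upscaler")  # hypothetical Space id
    result = client.predict(
        handle_file("low_res.jpg"),  # Input Image
        4,    # Upscale Factor
        42,   # Seed (unused while randomize_seed stays True in gradio_interface)
        28,   # Inference Steps
        0.6,  # ControlNet Conditioning Scale
        api_name="/predict",
    )
    print(result)  # local path to the downloaded output image

Unlike the old process_id polling loop, gr.Interface blocks until the image is ready, so the status endpoint has no replacement.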
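Both the old and new versions keep the MAX_PIXEL_BUDGET guard in process_input unchanged. A standalone sketch of that arithmetic (the formulas mirror the diff; the function name budget_resize is ours, and no pipeline is needed to run it):

    # Mirrors the pre-resize rule in process_input from app.py.
    MAX_PIXEL_BUDGET = 1024 * 1024

    def budget_resize(w: int, h: int, upscale_factor: int) -> tuple[int, int]:
        """Working size fed to the pipeline so w * h * factor**2 stays in budget."""
        aspect_ratio = w / h
        if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
            # Shrink so the upscaled output fits the pixel budget.
            w = int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor)
            h = int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor)
        # Snap both sides to multiples of 8, as the pipeline expects.
        return w - w % 8, h - h % 8

    # A 1024x768 input at 4x would need ~12.6M output pixels, over the 1,048,576
    # budget, so it is shrunk to 341x192 and snapped to 336x192; the 4x output
    # is then 1344x768, about 1.03M pixels, just under the budget.
    print(budget_resize(1024, 768, 4))  # (336, 192)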