atharvapawar commited on
Commit
1adc858
·
verified ·
1 Parent(s): af788e3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from diffusers import StableDiffusionUpscalePipeline
4
+ from diffusers.utils import load_image
5
+ import torch
6
+ from PIL import Image
7
+ import base64
8
+ import requests
9
+ from io import BytesIO
10
+
11
+ # Load model and scheduler
12
+ model_id = "stabilityai/stable-diffusion-x4-upscaler"
13
+ pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
14
+ pipeline = pipeline.to("cuda")
15
+
16
+ def upscale_image(image, prompt):
17
+ image = image.resize((128, 128)) # Resize to the expected input size
18
+ upscaled_image = pipeline(prompt=prompt, image=image).images[0]
19
+ return upscaled_image
20
+
21
+ def image_to_base64(image):
22
+ buffered = BytesIO()
23
+ image.save(buffered, format="JPEG")
24
+ return base64.b64encode(buffered.getvalue()).decode()
25
+
26
+ def base64_to_image(base64_str):
27
+ image_data = base64.b64decode(base64_str)
28
+ return Image.open(BytesIO(image_data))
29
+
30
+ def process_image(image):
31
+ prompt = "a white cat"
32
+ upscaled_image = upscale_image(image, prompt)
33
+ return image_to_base64(upscaled_image)
34
+
35
+ def main():
36
+ with gr.Blocks() as demo:
37
+ gr.Markdown("# Stable Diffusion Upscaler")
38
+
39
+ with gr.Row():
40
+ with gr.Column(scale=1):
41
+ image_input = gr.Image(type="pil", label="Low-Resolution Image", tool="editor")
42
+ prompt_input = gr.Textbox(label="Prompt", value="a white cat")
43
+
44
+ upload_btn = gr.Button("Upload and Upscale")
45
+ base64_output = gr.Textbox(label="Base64 Encoded Image")
46
+ upscaled_image_output = gr.Image(type="pil", label="Upscaled Image")
47
+
48
+ def handle_upload(image, prompt):
49
+ upscaled_image = upscale_image(image, prompt)
50
+ base64_str = image_to_base64(upscaled_image)
51
+ return base64_str, upscaled_image
52
+
53
+ upload_btn.click(fn=handle_upload, inputs=[image_input, prompt_input], outputs=[base64_output, upscaled_image_output])
54
+
55
+ demo.launch()
56
+
57
+ if __name__ == "__main__":
58
+ main()