Diffusers
Safetensors
StableDiffusionUpscalePipeline
stable-diffusion
yanis9351 commited on
Commit
9b535c8
·
verified ·
1 Parent(s): 572c992

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +44 -0
handler.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from diffusers import StableDiffusionUpscalePipeline
3
+ import torch
4
+ from PIL import Image
5
+ import io
6
+
7
+ class EndpointHandler:
8
+ def __init__(self, path: str):
9
+ # Load the Stable Diffusion x4 upscaler model
10
+ self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
11
+ path,
12
+ torch_dtype=torch.float16
13
+ )
14
+ self.pipeline.to("cuda")
15
+
16
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
17
+ """
18
+ data args:
19
+ inputs: str - The text prompt for the upscaling.
20
+ image: bytes - The low-resolution image as byte data.
21
+
22
+ Return:
23
+ A list of dictionaries with the upscaled image.
24
+ """
25
+ # Extract inputs and image from the payload
26
+ prompt = data.get("inputs", "")
27
+ image_bytes = data.get("image", None)
28
+
29
+ if image_bytes is None:
30
+ return [{"error": "No image provided"}]
31
+
32
+ # Convert the byte data to an image
33
+ low_res_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
34
+
35
+ # Perform upscaling
36
+ upscaled_image = self.pipeline(prompt=prompt, image=low_res_img).images[0]
37
+
38
+ # Save the upscaled image to a byte stream
39
+ byte_io = io.BytesIO()
40
+ upscaled_image.save(byte_io, format="PNG")
41
+ byte_io.seek(0)
42
+
43
+ # Return the upscaled image as byte data
44
+ return [{"upscaled_image": byte_io.getvalue()}]