rgres committed on
Commit
16b3d38
·
1 Parent(s): 914a5ac

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. README.md +10 -0
  2. handler.py +62 -0
  3. requirements.txt +3 -0
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: openrail
3
+ tags:
4
+ - stable-diffusion
5
+ - stable-diffusion-diffusers
6
+ - controlnet
7
+ inference: true
8
+ ---
9
+
10
+ # Inference Endpoint for [ControlNet](https://huggingface.co/lllyasviel/ControlNet) using [stabilityai/stable-diffusion-2-1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base)
handler.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
2
+ from typing import Dict, List, Any
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ import base64
6
+ import torch
7
+
8
# Select the compute device; this handler refuses to start without a CUDA GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
# Mixed-precision dtype: bfloat16 on GPUs with compute capability major
# version 8 or newer, float16 on older GPUs. Using `>=` rather than `==`
# keeps architectures newer than 8.x from falling back to float16.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
14
+
15
+
16
class EndpointHandler():
    """Inference handler running a ControlNet-conditioned Stable Diffusion pipeline.

    The ControlNet weights and the base Stable Diffusion pipeline are loaded
    once in ``__init__`` and moved to the module-level ``device``.
    """

    def __init__(self, path=""):
        """Load the ControlNet model and build the diffusion pipeline.

        :param path: Accepted for interface compatibility; not used here
            because the model ids are hard-coded below.
        """
        self.stable_diffusion_id = "stabilityai/stable-diffusion-2-1-base"

        # Load the ControlNet in the same dtype as the rest of the pipeline so
        # that the two models never disagree (float16 vs. bfloat16).
        controlnet = ControlNetModel.from_pretrained("rgres/sd-controlnet-aerialdreams", torch_dtype=dtype)

        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.stable_diffusion_id, controlnet=controlnet, torch_dtype=dtype, safety_checker=None
        ).to(device)

    def __call__(self, data: Any) -> Any:
        """Generate an image from a prompt and a base64-encoded conditioning image.

        :param data: A dictionary with the fields ``prompt`` (text prompt),
            ``image`` (base64-encoded conditioning image), and the optional
            fields ``steps`` (number of inference steps, default 30) and
            ``seed`` (integer seed for the random generator, default 3).
        :return: The first image produced by the pipeline, or a dictionary
            with an ``error`` field when ``prompt`` or ``image`` is missing.
        """
        # Read with .get() so the caller's payload is not mutated.
        prompt = data.get("prompt")
        image = data.get("image")
        steps = data.get("steps", 30)
        seed = data.get("seed")

        # Both fields are required: the pipeline needs a prompt, and the
        # conditioning image must be decodable.
        if prompt is None or image is None:
            return {"error": "Please provide a prompt and base64 encoded image."}

        # decode image
        image = self.decode_base64_image(image)

        # Honour the caller-supplied seed; fall back to the historical fixed
        # seed (3) so requests without a seed stay reproducible as before.
        generator = torch.Generator(device="cpu").manual_seed(3 if seed is None else int(seed))

        # run inference pipeline
        image_out = self.pipe(
            prompt=prompt,
            image=image,
            num_inference_steps=steps,
            generator=generator
        ).images[0]

        # return first generated image
        return image_out

    # helper to decode input image
    def decode_base64_image(self, image_string):
        """Decode a base64 string and open the resulting bytes as an image.

        :param image_string: Base64-encoded image data.
        :return: The image opened with ``PIL.Image.open`` from an in-memory buffer.
        """
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ git+https://github.com/huggingface/diffusers.git
2
+ safetensors
3
+ opencv-python