from typing import Any
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DDIMScheduler

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type != 'cuda':
    raise ValueError("This handler requires a CUDA-capable GPU.")

model_id = "stabilityai/stable-diffusion-2-1-base"

class EndpointHandler:
    def __init__(self, path=""):
        # load the text-to-image pipeline in half precision with a DDIM scheduler
        self.textPipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        self.textPipe.scheduler = DDIMScheduler.from_config(self.textPipe.scheduler.config)
        self.textPipe = self.textPipe.to(device)

        # create an img2img pipeline from the same checkpoint
        self.imgPipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        self.imgPipe.scheduler = DDIMScheduler.from_config(self.imgPipe.scheduler.config)
        self.imgPipe = self.imgPipe.to(device)

    def __call__(self, data: Any) -> Image.Image:
        """
        Args:
            data (:obj:`dict`):
                includes the input prompt, an optional init-image ``url``, and
                the ``parameters`` for the inference.
        Return:
            A :obj:`PIL.Image.Image`: the first generated image.
        """
        prompt = data.pop("inputs", data)
        url = data.pop("url", data)
        response = requests.get(url)
        init_image = Image.open(BytesIO(response.content)).convert("RGB")
        init_image.thumbnail((512, 512))

        params = data.pop("parameters", data)

        # hyperparameters
        num_inference_steps = params.pop("num_inference_steps", 25)
        guidance_scale = params.pop("guidance_scale", 7.5)
        negative_prompt = params.pop("negative_prompt", None)
        height = params.pop("height", None)
        width = params.pop("width", None)
        manual_seed = params.pop("manual_seed", -1)

        # build a generator only when a fixed seed is requested; -1 means unseeded
        generator = None
        if manual_seed != -1:
            generator = torch.Generator(device=device).manual_seed(manual_seed)

        if url:
            # fetch the init image and shrink it in place to fit within 512x512
            response = requests.get(url)
            init_image = Image.open(BytesIO(response.content)).convert("RGB")
            init_image.thumbnail((512, 512))

            # run the img2img pipeline; it derives the output size from the
            # init image, so height/width are not passed here
            out = self.imgPipe(prompt,
                        image=init_image,
                        generator=generator,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_images_per_prompt=1,
                        negative_prompt=negative_prompt,
            )
        else:
            # run the text-to-image pipeline
            out = self.textPipe(prompt,
                        generator=generator,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        num_images_per_prompt=1,
                        negative_prompt=negative_prompt,
                        height=height,
                        width=width
            )

        # return the first generated PIL image
        return out.images[0]
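
# A minimal local smoke test (a sketch, not part of the original handler):
# it assumes this file is run directly on a GPU machine; the prompt, seed,
# and output filename below are placeholders.
if __name__ == "__main__":
    handler = EndpointHandler()
    image = handler({
        "inputs": "a photo of an astronaut riding a horse on mars",
        "parameters": {"num_inference_steps": 25, "manual_seed": 42},
    })
    image.save("out.png")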