adorabook committed on
Commit
a41b897
·
verified ·
1 Parent(s): 5592f9d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +29 -16
handler.py CHANGED
@@ -1,24 +1,42 @@
1
  import torch
2
  import numpy as np
3
- from PIL import Image
4
- from transformers import AutoTokenizer
5
  from pulid.pipeline_v1_1 import PuLIDPipeline
6
  from pulid.utils import resize_numpy_image_long
7
  from pulid import attention_processor as attention
8
 
 
9
  torch.set_grad_enabled(False)
10
 
11
  class EndpointHandler:
12
- def __init__(self, model_dir=None):
13
- # Initialize the model and tokenizer
 
 
 
 
 
14
  self.pipeline = PuLIDPipeline(sdxl_repo='RunDiffusion/Juggernaut-XL-v9', sampler='dpmpp_sde')
15
  self.default_cfg = 7.0
16
  self.default_steps = 25
17
  self.attention = attention
18
  self.pipeline.debug_img_list = []
19
 
20
- def preprocess(self, inputs):
21
- # Extracts image and parameters from the input data
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  id_image = inputs[0]
23
  supp_images = inputs[1:4]
24
  prompt = inputs[4]
@@ -36,7 +54,7 @@ class EndpointHandler:
36
  if seed == -1:
37
  seed = torch.Generator(device="cpu").seed()
38
 
39
- # Handle the ortho settings
40
  if ortho == 'v2':
41
  self.attention.ORTHO = False
42
  self.attention.ORTHO_v2 = True
@@ -47,7 +65,7 @@ class EndpointHandler:
47
  self.attention.ORTHO = False
48
  self.attention.ORTHO_v2 = False
49
 
50
- # Process the images
51
  if id_image is not None:
52
  id_image = resize_numpy_image_long(id_image, 1024)
53
  supp_id_image_list = [
@@ -59,19 +77,14 @@ class EndpointHandler:
59
  uncond_id_embedding = None
60
  id_embedding = None
61
 
62
- return (prompt, neg_prompt, scale, seed, steps, H, W, id_scale, num_zero, uncond_id_embedding, id_embedding)
63
-
64
- def predict(self, inputs):
65
- # Preprocess the input data
66
- (prompt, neg_prompt, scale, seed, steps, H, W, id_scale, num_zero, uncond_id_embedding, id_embedding) = self.preprocess(inputs)
67
-
68
- # Run the inference pipeline
69
  img = self.pipeline.inference(
70
  prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed
71
  )[0]
72
 
 
73
  return {
74
  "image": np.array(img).tolist(),
75
  "seed": str(seed),
76
  "debug_images": [np.array(debug_img).tolist() for debug_img in self.pipeline.debug_img_list],
77
- }
 
1
  import torch
2
  import numpy as np
3
+ from typing import Dict, Any
 
4
  from pulid.pipeline_v1_1 import PuLIDPipeline
5
  from pulid.utils import resize_numpy_image_long
6
  from pulid import attention_processor as attention
7
 
8
+ # Disable gradients for inference
9
  torch.set_grad_enabled(False)
10
 
11
  class EndpointHandler:
12
+ def __init__(self, model_dir: str = None):
13
+ """
14
+ Initializes the model and necessary components.
15
+
16
+ Args:
17
+ model_dir (str): Directory containing the model weights.
18
+ """
19
  self.pipeline = PuLIDPipeline(sdxl_repo='RunDiffusion/Juggernaut-XL-v9', sampler='dpmpp_sde')
20
  self.default_cfg = 7.0
21
  self.default_steps = 25
22
  self.attention = attention
23
  self.pipeline.debug_img_list = []
24
 
25
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
26
+ """
27
+ Handles inference requests.
28
+
29
+ Args:
30
+ data (Dict[str, Any]): Input data for inference.
31
+
32
+ Returns:
33
+ Dict[str, Any]: Results containing the generated image and debug information.
34
+ """
35
+ # Preprocess inputs
36
+ inputs = data.get("inputs", [])
37
+ if not inputs or len(inputs) < 14:
38
+ raise ValueError("Invalid inputs. Expected 14 elements in the input list.")
39
+
40
  id_image = inputs[0]
41
  supp_images = inputs[1:4]
42
  prompt = inputs[4]
 
54
  if seed == -1:
55
  seed = torch.Generator(device="cpu").seed()
56
 
57
+ # Handle orthogonal settings
58
  if ortho == 'v2':
59
  self.attention.ORTHO = False
60
  self.attention.ORTHO_v2 = True
 
65
  self.attention.ORTHO = False
66
  self.attention.ORTHO_v2 = False
67
 
68
+ # Process images
69
  if id_image is not None:
70
  id_image = resize_numpy_image_long(id_image, 1024)
71
  supp_id_image_list = [
 
77
  uncond_id_embedding = None
78
  id_embedding = None
79
 
80
+ # Generate image
 
 
 
 
 
 
81
  img = self.pipeline.inference(
82
  prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed
83
  )[0]
84
 
85
+ # Prepare response
86
  return {
87
  "image": np.array(img).tolist(),
88
  "seed": str(seed),
89
  "debug_images": [np.array(debug_img).tolist() for debug_img in self.pipeline.debug_img_list],
90
+ }