zwww committed on
Commit acd5cdb · 1 Parent(s): f845ca1

Upload serve_loras.py with huggingface_hub

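The commit message says the file was uploaded with the huggingface_hub client. A minimal sketch of such an upload, assuming a hypothetical target repo id (the actual repo is not shown on this page):

    from huggingface_hub import HfApi

    api = HfApi()  # picks up the token from `huggingface-cli login` by default
    api.upload_file(
        path_or_fileobj="serve_loras.py",
        path_in_repo="serve_loras.py",
        repo_id="zwww/serve-loras",  # hypothetical repo id, not from this commit
        commit_message="Upload serve_loras.py with huggingface_hub",
    )
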
Files changed (1):
  serve_loras.py +195 -0
serve_loras.py ADDED
@@ -0,0 +1,195 @@
+ from compel import Compel, ReturnedEmbeddingsType
+ import logging
+ from abc import ABC
+
+ import diffusers
+ import torch
+ from diffusers import StableDiffusionXLPipeline
+
+ import numpy as np
+ import threading
+
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ import uuid
+ from tempfile import TemporaryFile
+ from google.cloud import storage
+ import sys
+ import sentry_sdk
+ from flask import Flask, request, jsonify
+ import os
+
+ logger = logging.getLogger(__name__)
+ logger.info("Diffusers version %s", diffusers.__version__)
+
+ sentry_sdk.init(
+     dsn="https://f750d1b039d66541f344ee6151d38166@o4505891057696768.ingest.sentry.io/4506071735205888",
+ )
+
+ LORAS_DIR = './safetensors'
+
+ class DiffusersHandler(ABC):
+     """
+     Diffusers handler class for text-to-image generation.
+     """
+
+     def __init__(self):
+         self.initialized = False
+
+     def initialize(self, properties):
+         """Load and initialize the Stable Diffusion XL pipeline.
+         Args:
+             properties (dict): Configuration for this handler, e.g. the
+                 id of the GPU to place the pipeline on.
+         """
+         logger.info("Loading diffusion model")
+         logger.info("I'm totally new and updated")
+
+         device_str = "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu"
+
+         print("my device is " + device_str)
+         self.device = torch.device(device_str)
+         self.pipe = StableDiffusionXLPipeline.from_pretrained(
+             sys.argv[1],
+             torch_dtype=torch.float16,
+             use_safetensors=True,
+         )
+
+         logger.info("moving model to device: %s", device_str)
+         self.pipe.to(self.device)
+
+         logger.info(self.device)
+         logger.info("Diffusion model from path %s loaded successfully", sys.argv[1])
+
+         self.initialized = True
+
+     def preprocess(self, raw_requests):
+         """Basic preprocessing of the user's prompt.
+         Args:
+             raw_requests (list): The raw request bodies, as parsed JSON.
+         Returns:
+             dict: The parameters for a single generation call.
+         """
+         logger.info("Received requests: '%s'", raw_requests)
+         self.working = True
+
+         processed_request = {
+             "prompt": raw_requests[0]["prompt"],
+             "negative_prompt": raw_requests[0].get("negative_prompt"),
+             "width": raw_requests[0].get("width"),
+             "height": raw_requests[0].get("height"),
+             "num_inference_steps": raw_requests[0].get("num_inference_steps", 30),
+             "guidance_scale": raw_requests[0].get("guidance_scale", 7.5),
+             "lora_weights": raw_requests[0].get("lora_name", None),
+             "cross_attention_kwargs": {"scale": raw_requests[0].get("lora_scale", 0.6)}
+         }
+
+         logger.info("Processed request: '%s'", processed_request)
+         return processed_request
+
+
+     def inference(self, request):
+         """Generates the image for the received prompt.
+         Args:
+             request (dict): The processed request from preprocess().
+         Returns:
+             list: The generated images.
+         """
+         # Build SDXL prompt embeddings with Compel (both text encoders,
+         # penultimate hidden states, pooled output from the second encoder).
+         compel = Compel(
+             tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
+             text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
+             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+             requires_pooled=[False, True],
+         )
+
+         self.prompt = request.pop("prompt")
+         conditioning, pooled = compel(self.prompt)
+
+         lora_weights = request.pop("lora_weights")
+         if lora_weights is not None:
+             lora_path = os.path.join(LORAS_DIR, lora_weights + '.safetensors')
+             logger.info('LOADING LORA FROM: ' + lora_path)
+             self.pipe.load_lora_weights(lora_path)
+
+         inferences = self.pipe(
+             prompt_embeds=conditioning,
+             pooled_prompt_embeds=pooled,
+             **request
+         ).images
+
+         # Unload the LoRA so it does not leak into the next request.
+         if lora_weights is not None:
+             self.pipe.unload_lora_weights()
+
+         logger.info("Generated image: '%s'", inferences)
+         return inferences
+
+     def postprocess(self, inference_outputs):
+         """Uploads the generated images to Google Cloud Storage.
+         Args:
+             inference_outputs (list): The generated images.
+         Returns:
+             list: Public URLs of the uploaded images.
+         """
+         bucket_name = "outputs-storage-prod"
+         client = storage.Client()
+         self.working = False
+         bucket = client.get_bucket(bucket_name)
+         outputs = []
+         for image in inference_outputs:
+             image_name = str(uuid.uuid4())
+
+             blob = bucket.blob(image_name + '.png')
+
+             with TemporaryFile() as tmp:
+                 image.save(tmp, format="png")
+                 tmp.seek(0)
+                 blob.upload_from_file(tmp, content_type='image/png')
+
+             # generate txt file with the image name and the prompt inside
+             # blob = bucket.blob(image_name + '.txt')
+             # blob.upload_from_string(self.prompt)
+
+             outputs.append('https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png')
+         return outputs
+
+
+ app = Flask(__name__)
+
+ # Initialize one handler per GPU on startup
+ gpu_count = torch.cuda.device_count()
+ if gpu_count == 0:
+     raise ValueError("No GPUs available!")
+
+ handlers = [DiffusersHandler() for i in range(gpu_count)]
+ for i in range(gpu_count):
+     handlers[i].initialize({"gpu_id": i})
+
+ handler_lock = threading.Lock()
+ handler_index = 0
+
+ @app.route('/generate', methods=['POST'])
+ def generate_image():
+     global handler_index
+     try:
+         # Extract the raw request from the HTTP POST body
+         raw_requests = request.json
+
+         # Round-robin over the per-GPU handlers
+         with handler_lock:
+             selected_handler = handlers[handler_index]
+             handler_index = (handler_index + 1) % gpu_count  # Rotate to the next handler
+
+         processed_request = selected_handler.preprocess([raw_requests])
+         inferences = selected_handler.inference(processed_request)
+         outputs = selected_handler.postprocess(inferences)
+
+         return jsonify({"image_urls": outputs})
+     except Exception as e:
+         logger.error("Error during image generation: %s", str(e))
+         return jsonify({"error": "Failed to generate image", "details": str(e)}), 500
+
+ if __name__ == '__main__':
+     app.run(host='0.0.0.0', port=3000, threaded=True)
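
For reference, a minimal sketch of how a client might call this server once it is running. The endpoint and JSON fields come from the code above; the host, prompt, and LoRA name are illustrative assumptions:

    import requests

    resp = requests.post(
        "http://localhost:3000/generate",  # assumes the server runs locally on port 3000
        json={
            "prompt": "a watercolor fox, highly detailed",
            "negative_prompt": "blurry, low quality",
            "width": 1024,
            "height": 1024,
            "num_inference_steps": 30,
            "guidance_scale": 7.5,
            "lora_name": "my_style",  # hypothetical file ./safetensors/my_style.safetensors
            "lora_scale": 0.6,
        },
        timeout=600,  # generation can take minutes on a busy GPU
    )
    print(resp.json())  # {"image_urls": ["https://storage.googleapis.com/outputs-storage-prod/<uuid>.png"]}

Note that the model path is read from sys.argv[1], so the server itself would be launched with something like `python serve_loras.py stabilityai/stable-diffusion-xl-base-1.0`; the model id here is an assumption, and any local path or Hub id accepted by StableDiffusionXLPipeline.from_pretrained would do.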