Texttra committed
Commit 00bb2a2 (verified) · Parent(s): 677c8e2

Create handler.py

Files changed (1)
  1. handler.py +48 -0
handler.py ADDED
@@ -0,0 +1,48 @@
+ from typing import Dict
+ import torch
+ from diffusers import StableDiffusionXLPipeline
+ from io import BytesIO
+ import base64
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         print(f"Initializing SDXL model from: {path}")
+
+         # Base SDXL Model
+         self.pipe = StableDiffusionXLPipeline.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0",
+             torch_dtype=torch.float16,
+             variant="fp16"
+         )
+
+         print("Loading LoRA weights from: Texttra/Bh0r")
+         self.pipe.load_lora_weights("Texttra/Bh0r", weight_name="Bh0r-10.safetensors")
+         self.pipe.fuse_lora()
+
+         self.pipe.to("cuda" if torch.cuda.is_available() else "cpu")
+         print("Model ready.")
+
+     def __call__(self, data: Dict) -> Dict:
+         print("Received data:", data)
+
+         inputs = data.get("inputs", {})
+         prompt = inputs.get("prompt", "")
+         print("Extracted prompt:", prompt)
+
+         if not prompt:
+             return {"error": "No prompt provided."}
+
+         image = self.pipe(
+             prompt,
+             num_inference_steps=35,
+             guidance_scale=7.0,
+         ).images[0]
+
+         print("Image generated.")
+
+         buffer = BytesIO()
+         image.save(buffer, format="PNG")
+         base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
+         print("Returning image.")
+
+         return {"image": base64_image}
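For reference, a minimal local smoke test for this handler, not part of the commit: it assumes a GPU machine with torch, diffusers, and peft installed and handler.py importable from the working directory; the prompt and output filename are illustrative placeholders.

import base64
from handler import EndpointHandler

# Instantiate the handler (loads the SDXL base pipeline plus the Bh0r LoRA weights).
handler = EndpointHandler(path="")

# Call it with the same payload shape the endpoint receives.
result = handler({"inputs": {"prompt": "a portrait photo in the Bh0r style"}})

if "error" in result:
    print(result["error"])
else:
    # __call__ returns {"image": <base64-encoded PNG>}; decode it and write to disk.
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(result["image"]))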