sachin commited on
Commit
cb5e4aa
·
1 Parent(s): c43570d

add-genrate

Browse files
Files changed (1) hide show
  1. intruct.py +117 -0
intruct.py CHANGED
@@ -5,6 +5,18 @@ import math
5
  from PIL import Image, ImageOps
6
  import torch
7
  from diffusers import StableDiffusionInstructPix2PixPipeline
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Initialize FastAPI app
10
  app = FastAPI()
@@ -21,6 +33,111 @@ DEFAULT_TEXT_CFG = 7.5
21
  DEFAULT_IMAGE_CFG = 1.5
22
  DEFAULT_SEED = 1371
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def process_image(input_image: Image.Image, instruction: str, steps: int, text_cfg_scale: float, image_cfg_scale: float, seed: int):
25
  """
26
  Process the input image with the given instruction using InstructPix2Pix.
 
5
  from PIL import Image, ImageOps
6
  import torch
7
  from diffusers import StableDiffusionInstructPix2PixPipeline
8
+ from fastapi import FastAPI, Response
9
+ from fastapi.responses import FileResponse
10
+ import torch
11
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
12
+ from huggingface_hub import hf_hub_download, login
13
+ from safetensors.torch import load_file
14
+ from io import BytesIO
15
+ import os
16
+ import base64 # Added for encoding images as base64
17
+ from typing import List # Added for type hinting the list of prompts
18
+
19
+
20
 
21
  # Initialize FastAPI app
22
  app = FastAPI()
 
33
  DEFAULT_IMAGE_CFG = 1.5
34
  DEFAULT_SEED = 1371
35
 
36
+
37
+ HF_TOKEN = os.getenv("HF_TOKEN")
38
+
39
+ def load_model():
40
+ try:
41
+ # Login to Hugging Face if token is provided
42
+ if HF_TOKEN:
43
+ login(token=HF_TOKEN)
44
+
45
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
46
+ repo = "ByteDance/SDXL-Lightning"
47
+ ckpt = "sdxl_lightning_4step_unet.safetensors"
48
+
49
+ # Load model with explicit error handling
50
+ unet = UNet2DConditionModel.from_config(
51
+ base,
52
+ subfolder="unet"
53
+ ).to("cuda", torch.float16)
54
+
55
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
56
+ pipe = StableDiffusionXLPipeline.from_pretrained(
57
+ base,
58
+ unet=unet,
59
+ torch_dtype=torch.float16,
60
+ variant="fp16"
61
+ ).to("cuda")
62
+
63
+ # Configure scheduler
64
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
65
+ pipe.scheduler.config,
66
+ timestep_spacing="trailing"
67
+ )
68
+
69
+ return pipe
70
+
71
+ except Exception as e:
72
+ raise Exception(f"Failed to load model: {str(e)}")
73
+
74
+ # Load model at startup with error handling
75
+ try:
76
+ pipe = load_model()
77
+ except Exception as e:
78
+ print(f"Model initialization failed: {str(e)}")
79
+ raise
80
+
81
+
82
+
83
+ @app.get("/generate")
84
+ async def generate_image(prompt: str):
85
+ try:
86
+ # Generate image
87
+ image = pipe(
88
+ prompt,
89
+ num_inference_steps=4,
90
+ guidance_scale=0
91
+ ).images[0]
92
+
93
+ # Save image to buffer
94
+ buffer = BytesIO()
95
+ image.save(buffer, format="PNG")
96
+ buffer.seek(0)
97
+
98
+ return Response(content=buffer.getvalue(), media_type="image/png")
99
+
100
+ except Exception as e:
101
+ return {"error": str(e)}
102
+
103
+ # New endpoint to handle a list of prompts
104
+ @app.get("/generate_multiple")
105
+ async def generate_multiple_images(prompts: List[str]):
106
+ try:
107
+ # List to store base64-encoded images
108
+ generated_images = []
109
+
110
+ # Generate an image for each prompt
111
+ for prompt in prompts:
112
+ image = pipe(
113
+ prompt,
114
+ num_inference_steps=4,
115
+ guidance_scale=0
116
+ ).images[0]
117
+
118
+ # Save image to buffer
119
+ buffer = BytesIO()
120
+ image.save(buffer, format="PNG")
121
+ buffer.seek(0)
122
+
123
+ # Encode the image as base64
124
+ image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
125
+ generated_images.append({
126
+ "prompt": prompt,
127
+ "image_base64": image_base64
128
+ })
129
+
130
+ return {"images": generated_images}
131
+
132
+ except Exception as e:
133
+ return {"error": str(e)}
134
+
135
+ @app.get("/health")
136
+ async def health_check():
137
+ return {"status": "healthy"}
138
+
139
+
140
+
141
  def process_image(input_image: Image.Image, instruction: str, steps: int, text_cfg_scale: float, image_cfg_scale: float, seed: int):
142
  """
143
  Process the input image with the given instruction using InstructPix2Pix.