rawc0der committed on
Commit 4625abf · 1 Parent(s): 7d16304

update interface with model gateway

Files changed (2)
  1. app.py +396 -4
  2. requirements.txt +18 -0
app.py CHANGED
@@ -1,7 +1,399 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ from fastapi import FastAPI, HTTPException, UploadFile, File
+ from typing import Optional, Dict, Any
+ import torch
+ from diffusers import (
+     StableDiffusionPipeline,
+     StableDiffusionXLPipeline,
+     AutoPipelineForText2Image
+ )
  import gradio as gr
+ from PIL import Image
+ import numpy as np
+ import gc
+ from io import BytesIO
+ import base64
+ import functools
+
+ app = FastAPI()
+
+ # Comprehensive model registry
+ MODELS = {
+     "SDXL-Base": {
+         "model_id": "stabilityai/stable-diffusion-xl-base-1.0",
+         "pipeline": StableDiffusionXLPipeline,
+         "supports_img2img": True,
+         "parameters": {
+             "num_inference_steps": {"min": 1, "max": 100, "default": 50},
+             "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
+             "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
+             "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
+         }
+     },
+     "SDXL-Turbo": {
+         "model_id": "stabilityai/sdxl-turbo",
+         "pipeline": AutoPipelineForText2Image,
+         "supports_img2img": True,
+         "parameters": {
+             "num_inference_steps": {"min": 1, "max": 50, "default": 1},
+             "guidance_scale": {"min": 0.0, "max": 20.0, "default": 7.5},
+             "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
+             "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
+         }
+     },
+     "SD-1.5": {
+         "model_id": "runwayml/stable-diffusion-v1-5",
+         "pipeline": StableDiffusionPipeline,
+         "supports_img2img": True,
+         "parameters": {
+             "num_inference_steps": {"min": 1, "max": 50, "default": 30},
+             "guidance_scale": {"min": 1, "max": 20, "default": 7.5},
+             "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
+             "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
+         }
+     },
+     "Waifu-Diffusion": {
+         "model_id": "hakurei/waifu-diffusion",
+         "pipeline": StableDiffusionPipeline,
+         "supports_img2img": True,
+         "parameters": {
+             "num_inference_steps": {"min": 1, "max": 100, "default": 50},
+             "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
+             "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
+             "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
+         }
+     },
+     "Flux": {
+         "model_id": "black-forest-labs/flux-1-1-dev",
+         "pipeline": AutoPipelineForText2Image,
+         "supports_img2img": True,
+         "parameters": {
+             "num_inference_steps": {"min": 1, "max": 50, "default": 25},
+             "guidance_scale": {"min": 1, "max": 15, "default": 7.5},
+             "width": {"min": 256, "max": 1024, "default": 512, "step": 64},
+             "height": {"min": 256, "max": 1024, "default": 512, "step": 64}
+         }
+     }
+ }
+
+ class ModelManager:
+     def __init__(self):
+         self.current_model = None
+         self.current_pipeline = None
+         self.model_cache: Dict[str, Any] = {}
+         self._device = "cuda" if torch.cuda.is_available() else "cpu"
+         self._dtype = torch.float16 if self._device == "cuda" else torch.float32
+
+     def _clear_memory(self):
+         """Clear CUDA memory and garbage collect"""
+         if self.current_pipeline is not None:
+             del self.current_pipeline
+             self.current_pipeline = None
+
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.ipc_collect()
+
+         gc.collect()
+
+     @functools.lru_cache(maxsize=1)
+     def get_model_config(self, model_id: str, pipeline_class):
+         """Load and cache model configuration"""
+         return pipeline_class.from_pretrained(
+             model_id,
+             torch_dtype=self._dtype,
+             variant="fp16" if self._device == "cuda" else None,
+             device_map="auto"
+         )
+
+     def load_model(self, model_name: str):
+         """Load model with memory optimization"""
+         if self.current_model != model_name:
+             self._clear_memory()
+
+             try:
+                 model_info = MODELS[model_name]
+                 self.current_pipeline = self.get_model_config(
+                     model_info["model_id"],
+                     model_info["pipeline"]
+                 )
+
+                 if hasattr(self.current_pipeline, 'enable_xformers_memory_efficient_attention'):
+                     self.current_pipeline.enable_xformers_memory_efficient_attention()
+
+                 if self._device == "cuda":
+                     self.current_pipeline.enable_model_cpu_offload()
+
+                 self.current_model = model_name
+
+             except Exception as e:
+                 self._clear_memory()
+                 raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
+
+         return self.current_pipeline
+
+     def unload_current_model(self):
+         """Explicitly unload current model"""
+         self._clear_memory()
+         self.current_model = None
+
+     def get_memory_status(self):
+         """Get current memory usage status"""
+         if not torch.cuda.is_available():
+             return {"status": "CPU Mode"}
+
+         return {
+             "total": torch.cuda.get_device_properties(0).total_memory / 1e9,
+             "allocated": torch.cuda.memory_allocated() / 1e9,
+             "cached": torch.cuda.memory_reserved() / 1e9,
+             "free": (torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / 1e9
+         }
+
+ class ModelContext:
+     def __init__(self, model_name: str):
+         self.model_name = model_name
+
+     def __enter__(self):
+         return model_manager.load_model(self.model_name)
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         if exc_type is not None:
+             model_manager.unload_current_model()
+
+ model_manager = ModelManager()
+
+ async def generate_image(
+     model_name: str,
+     prompt: str,
+     height: int = 512,
+     width: int = 512,
+     num_inference_steps: Optional[int] = None,
+     guidance_scale: Optional[float] = None,
+     reference_image: Optional[Image.Image] = None
+ ) -> dict:
+     try:
+         with ModelContext(model_name) as pipeline:
+             pre_mem = model_manager.get_memory_status()
+
+             # Process reference image if provided
+             if reference_image and MODELS[model_name]["supports_img2img"]:
+                 reference_image = reference_image.resize((width, height))
+
+             # Generate image
+             generation_params = {
+                 "prompt": prompt,
+                 "height": height,
+                 "width": width,
+                 "num_inference_steps": num_inference_steps or MODELS[model_name]["parameters"]["num_inference_steps"]["default"],
+                 "guidance_scale": guidance_scale or MODELS[model_name]["parameters"]["guidance_scale"]["default"]
+             }
+
+             if reference_image:
+                 generation_params["image"] = reference_image
+
+             image = pipeline(**generation_params).images[0]
+
+             # Convert to base64
+             buffered = BytesIO()
+             image.save(buffered, format="PNG")
+             img_str = base64.b64encode(buffered.getvalue()).decode()
+
+             post_mem = model_manager.get_memory_status()
+
+             return {
+                 "status": "success",
+                 "image_base64": img_str,
+                 "memory": {
+                     "before": pre_mem,
+                     "after": post_mem
+                 }
+             }
+     except Exception as e:
+         model_manager.unload_current_model()
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/generate")
+ async def generate_image_endpoint(
+     model_name: str,
+     prompt: str,
+     height: int = 512,
+     width: int = 512,
+     num_inference_steps: Optional[int] = None,
+     guidance_scale: Optional[float] = None,
+     reference_image: UploadFile = File(None)
+ ):
+     ref_img = None
+     if reference_image:
+         content = await reference_image.read()
+         ref_img = Image.open(BytesIO(content))
+
+     return await generate_image(
+         model_name=model_name,
+         prompt=prompt,
+         height=height,
+         width=width,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         reference_image=ref_img
+     )
+
+ @app.get("/memory")
+ async def get_memory_status():
+     return model_manager.get_memory_status()
+
+ @app.post("/unload")
+ async def unload_model():
+     model_manager.unload_current_model()
+     return {"status": "success", "message": "Model unloaded"}
+
+ def create_gradio_interface():
+     with gr.Blocks() as interface:
+         gr.Markdown("# Text-to-Image Generation Interface")
+
+         with gr.Row():
+             with gr.Column(scale=2):
+                 model_dropdown = gr.Dropdown(
+                     choices=list(MODELS.keys()),
+                     value=list(MODELS.keys())[0],
+                     label="Select Model",
+                     info="Choose the model for image generation"
+                 )
+
+                 prompt = gr.Textbox(
+                     lines=3,
+                     label="Prompt",
+                     placeholder="Enter your image description here..."
+                 )
+
+                 with gr.Row():
+                     height = gr.Slider(
+                         minimum=256,
+                         maximum=1024,
+                         value=512,
+                         step=64,
+                         label="Height"
+                     )
+                     width = gr.Slider(
+                         minimum=256,
+                         maximum=1024,
+                         value=512,
+                         step=64,
+                         label="Width"
+                     )
+
+                 with gr.Row():
+                     num_steps = gr.Slider(
+                         minimum=1,
+                         maximum=100,
+                         value=50,
+                         step=1,
+                         label="Number of Inference Steps"
+                     )
+                     guidance = gr.Slider(
+                         minimum=1,
+                         maximum=15,
+                         value=7.5,
+                         step=0.1,
+                         label="Guidance Scale"
+                     )
+
+                 reference_image = gr.Image(
+                     type="pil",
+                     label="Reference Image (optional)",
+                     info="Upload an image for img2img generation"
+                 )
+
+                 with gr.Row():
+                     generate_btn = gr.Button("Generate", variant="primary")
+                     unload_btn = gr.Button("Unload Model", variant="secondary")
+
+             with gr.Column(scale=2):
+                 output_image = gr.Image(label="Generated Image")
+                 memory_status = gr.JSON(
+                     label="Memory Status",
+                     value=model_manager.get_memory_status()
+                 )
+
+         def update_params(model_name):
+             model_config = MODELS[model_name]["parameters"]
+             return [
+                 gr.Slider.update(
+                     minimum=model_config["height"]["min"],
+                     maximum=model_config["height"]["max"],
+                     value=model_config["height"]["default"],
+                     step=model_config["height"]["step"]
+                 ),
+                 gr.Slider.update(
+                     minimum=model_config["width"]["min"],
+                     maximum=model_config["width"]["max"],
+                     value=model_config["width"]["default"],
+                     step=model_config["width"]["step"]
+                 ),
+                 gr.Slider.update(
+                     minimum=model_config["num_inference_steps"]["min"],
+                     maximum=model_config["num_inference_steps"]["max"],
+                     value=model_config["num_inference_steps"]["default"]
+                 ),
+                 gr.Slider.update(
+                     minimum=model_config["guidance_scale"]["min"],
+                     maximum=model_config["guidance_scale"]["max"],
+                     value=model_config["guidance_scale"]["default"]
+                 )
+             ]
+
+         def generate(model_name, prompt_text, h, w, steps, guide_scale, ref_img):
+             response = generate_image(
+                 model_name=model_name,
+                 prompt=prompt_text,
+                 height=h,
+                 width=w,
+                 num_inference_steps=steps,
+                 guidance_scale=guide_scale,
+                 reference_image=ref_img
+             )
+             return Image.open(BytesIO(base64.b64decode(response["image_base64"])))
+
+         model_dropdown.change(
+             update_params,
+             inputs=[model_dropdown],
+             outputs=[height, width, num_steps, guidance]
+         )
+
+         generate_btn.click(
+             generate,
+             inputs=[
+                 model_dropdown,
+                 prompt,
+                 height,
+                 width,
+                 num_steps,
+                 guidance,
+                 reference_image
+             ],
+             outputs=[output_image]
+         )
+
+         unload_btn.click(
+             lambda: [model_manager.unload_current_model(), model_manager.get_memory_status()],
+             outputs=[memory_status]
+         )
+
+     return interface
+
+ if __name__ == "__main__":
+     import uvicorn
+     from threading import Thread
+
+     # Launch Gradio interface
+     interface = create_gradio_interface()
+     gradio_thread = Thread(
+         target=interface.launch,
+         kwargs={
+             "server_name": "0.0.0.0",
+             "server_port": 7860,
+             "share": True
+         }
+     )
+     gradio_thread.start()
+
+     # Launch FastAPI
+     uvicorn.run(app, host="0.0.0.0", port=8000)
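
Note: the new app.py serves the Gradio UI on port 7860 and the FastAPI model gateway on port 8000, with /generate, /memory, and /unload routes. The snippet below is a minimal client sketch, not part of the commit; it assumes the service is running locally with those default ports, and the model name and prompt are illustrative values taken from the MODELS registry.

```python
# Hypothetical client for the gateway in app.py (localhost and port 8000 are
# the defaults from this commit; model/prompt values are examples only).
import base64
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    params={
        "model_name": "SD-1.5",        # any key from the MODELS registry
        "prompt": "a watercolor fox",  # illustrative prompt
        "height": 512,
        "width": 512,
    },
)
resp.raise_for_status()
data = resp.json()

# The endpoint returns the image as base64 plus before/after memory stats.
with open("output.png", "wb") as f:
    f.write(base64.b64decode(data["image_base64"]))
print(data["memory"])
```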
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ fastapi>=0.104.0
+ uvicorn>=0.24.0
+ python-multipart>=0.0.6
+ gradio>=4.11.0
+ torch>=2.1.0
+ torchvision>=0.16.0
+ diffusers>=0.24.0
+ transformers>=4.36.0
+ accelerate>=0.25.0
+ safetensors>=0.4.0
+ xformers>=0.0.22.post7  # Optional but recommended for memory efficiency
+ pillow>=10.0.0
+ numpy>=1.24.0
+ packaging>=23.2
+ pydantic>=2.5.0
+ tqdm>=4.66.0
+ typing-extensions>=4.8.0
+ python-dotenv>=1.0.0
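
Since the gateway also exposes the /memory and /unload routes added in app.py, VRAM usage can be inspected and the cached pipeline released between runs. This is a sketch under the same local-host assumption as above; the URLs simply mirror the route names in this commit.

```python
# Sketch only: assumes the service from app.py is running on localhost:8000.
import requests

# Report VRAM usage (or {"status": "CPU Mode"} when CUDA is unavailable).
print(requests.get("http://localhost:8000/memory").json())

# Explicitly drop the current pipeline to free memory before switching models.
print(requests.post("http://localhost:8000/unload").json())
```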