sachin committed on
Commit
6031d9b
·
1 Parent(s): 7c6107b

add -hf token

Browse files
Files changed (1) hide show
  1. server.py +43 -19
server.py CHANGED
@@ -2,31 +2,57 @@ from fastapi import FastAPI, Response
2
  from fastapi.responses import FileResponse
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
5
- from huggingface_hub import hf_hub_download
6
  from safetensors.torch import load_file
7
  from io import BytesIO
8
  import os
9
 
10
  app = FastAPI()
11
 
12
- # Initialize the model once when the server starts
 
 
13
  def load_model():
14
- base = "stabilityai/stable-diffusion-xl-base-1.0"
15
- repo = "ByteDance/SDXL-Lightning"
16
- ckpt = "sdxl_lightning_4step_unet.safetensors"
17
-
18
- # Load model
19
- unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
20
- unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
21
- pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
22
-
23
- # Configure scheduler
24
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- return pipe
 
27
 
28
- # Load model at startup
29
- pipe = load_model()
 
 
 
 
30
 
31
  @app.get("/generate")
32
  async def generate_image(prompt: str):
@@ -43,7 +69,6 @@ async def generate_image(prompt: str):
43
  image.save(buffer, format="PNG")
44
  buffer.seek(0)
45
 
46
- # Return image as response
47
  return Response(content=buffer.getvalue(), media_type="image/png")
48
 
49
  except Exception as e:
@@ -55,5 +80,4 @@ async def health_check():
55
 
56
  if __name__ == "__main__":
57
  import uvicorn
58
- uvicorn.run(app, host="0.0.0.0", port=8000)
59
-
 
2
  from fastapi.responses import FileResponse
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
5
+ from huggingface_hub import hf_hub_download, login
6
  from safetensors.torch import load_file
7
  from io import BytesIO
8
  import os
9
 
10
  app = FastAPI()
11
 
12
+ # Get Hugging Face token from environment variable
13
+ HF_TOKEN = os.getenv("HF_TOKEN")
14
+
15
  def load_model():
16
+ try:
17
+ # Login to Hugging Face if token is provided
18
+ if HF_TOKEN:
19
+ login(token=HF_TOKEN)
20
+
21
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
22
+ repo = "ByteDance/SDXL-Lightning"
23
+ ckpt = "sdxl_lightning_4step_unet.safetensors"
24
+
25
+ # Load model with explicit error handling
26
+ unet = UNet2DConditionModel.from_config(
27
+ base,
28
+ subfolder="unet"
29
+ ).to("cuda", torch.float16)
30
+
31
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
32
+ pipe = StableDiffusionXLPipeline.from_pretrained(
33
+ base,
34
+ unet=unet,
35
+ torch_dtype=torch.float16,
36
+ variant="fp16"
37
+ ).to("cuda")
38
+
39
+ # Configure scheduler
40
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
41
+ pipe.scheduler.config,
42
+ timestep_spacing="trailing"
43
+ )
44
+
45
+ return pipe
46
 
47
+ except Exception as e:
48
+ raise Exception(f"Failed to load model: {str(e)}")
49
 
50
+ # Load model at startup with error handling
51
+ try:
52
+ pipe = load_model()
53
+ except Exception as e:
54
+ print(f"Model initialization failed: {str(e)}")
55
+ raise
56
 
57
  @app.get("/generate")
58
  async def generate_image(prompt: str):
 
69
  image.save(buffer, format="PNG")
70
  buffer.seek(0)
71
 
 
72
  return Response(content=buffer.getvalue(), media_type="image/png")
73
 
74
  except Exception as e:
 
80
 
81
  if __name__ == "__main__":
82
  import uvicorn
83
+ uvicorn.run(app, host="0.0.0.0", port=7860)