fyp1 committed on
Commit
02d46b3
1 Parent(s): ea673b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -72
app.py CHANGED
@@ -1,72 +1,60 @@
1
- from flask import Flask, request, jsonify
2
- from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
3
- from PIL import Image
4
- import torch
5
- import base64
6
- from io import BytesIO
7
- from huggingface_hub import login
8
-
9
- # Authenticate with Hugging Face Hub (ensure you replace 'your_token_here')
10
- import os
11
- login(os.environ["HF_TOKEN"])
12
-
13
- # Initialize Flask app
14
- app = Flask(__name__)
15
-
16
- # Load Hugging Face pipeline components
17
- model_id = "fyp1/sketchToImage"
18
- controlnet = ControlNetModel.from_pretrained(f"{model_id}/controlnet", torch_dtype=torch.float16)
19
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
20
- scheduler = EulerAncestralDiscreteScheduler.from_pretrained(f"{model_id}/scheduler")
21
-
22
- # Initialize Stable Diffusion XL ControlNet Pipeline
23
- pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
24
- "stabilityai/stable-diffusion-xl-base-1.0",
25
- controlnet=controlnet,
26
- vae=vae,
27
- scheduler=scheduler,
28
- safety_checker=None,
29
- torch_dtype=torch.float16,
30
- ).to("cuda" if torch.cuda.is_available() else "cpu")
31
-
32
- @app.route("/generate", methods=["POST"])
33
- def generate_image():
34
- data = request.json
35
-
36
- # Extract prompt, sketch image (Base64), and optional parameters
37
- prompt = data.get("prompt", "A default prompt")
38
- negative_prompt = data.get("negative_prompt", "low quality, blurry, bad details")
39
- sketch_base64 = data.get("sketch", None)
40
-
41
- if not sketch_base64:
42
- return jsonify({"error": "Sketch image is required."}), 400
43
-
44
- try:
45
- # Decode and preprocess the sketch image
46
- sketch_bytes = base64.b64decode(sketch_base64)
47
- sketch_image = Image.open(BytesIO(sketch_bytes)).convert("L") # Convert to grayscale
48
- sketch_image = sketch_image.resize((1024, 1024))
49
-
50
- # Generate the image using the pipeline
51
- with torch.no_grad():
52
- images = pipe(
53
- prompt=prompt,
54
- negative_prompt=negative_prompt,
55
- image=sketch_image,
56
- controlnet_conditioning_scale=1.0,
57
- width=1024,
58
- height=1024,
59
- num_inference_steps=30,
60
- ).images
61
-
62
- # Convert output image to Base64
63
- buffered = BytesIO()
64
- images[0].save(buffered, format="PNG")
65
- image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
66
-
67
- return jsonify({"image": image_base64})
68
- except Exception as e:
69
- return jsonify({"error": str(e)}), 500
70
-
71
- if __name__ == "__main__":
72
- app.run(host="0.0.0.0", port=7860)
 
1
import base64
from io import BytesIO

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
8
+
9
# FastAPI application instance.
app = FastAPI()

# --- Model loading: runs once at import time --------------------------------
model_id = "fyp1/sketchToImage"

# ControlNet weights for sketch conditioning, loaded in half precision.
controlnet = ControlNetModel.from_pretrained(
    f"{model_id}/controlnet", torch_dtype=torch.float16
)
# Replacement SDXL VAE published as an fp16-safe fix for half-precision use.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(f"{model_id}/scheduler")

# Assemble the SDXL + ControlNet pipeline.
# NOTE(review): `safety_checker` is not a documented argument of the SDXL
# ControlNet pipeline; diffusers appears to ignore unexpected kwargs with a
# warning — confirm and drop if so.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    scheduler=scheduler,
    safety_checker=None,
    torch_dtype=torch.float16,
)
# Prefer GPU when available; .to() returns the pipeline itself.
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
class GenerateRequest(BaseModel):
    """Request payload for POST /generate.

    `prompt` and `sketch` are required; `negative_prompt` is optional and
    defaults to the same value the earlier Flask version applied when the
    client omitted it, so the FastAPI migration does not silently make it
    mandatory. Clients that already send `negative_prompt` are unaffected.
    """

    prompt: str
    negative_prompt: str = "low quality, blurry, bad details"
    sketch: str  # Base64-encoded sketch image (required)
31
+
32
@app.post("/generate")
async def generate_image(data: GenerateRequest):
    """Generate an image from a Base64-encoded sketch and a text prompt.

    Returns:
        {"image": <base64-encoded PNG>} on success.

    Raises:
        HTTPException 400: the sketch payload is not valid Base64 / not a
            decodable image (a client error, not a server fault).
        HTTPException 500: the diffusion pipeline itself fails.

    Previously all failures were returned as {"error": ...} with HTTP 200,
    which made errors indistinguishable from success for HTTP clients.
    """
    try:
        # Decode and preprocess the sketch: grayscale, resized to the
        # 1024x1024 resolution the pipeline is invoked with below.
        sketch_bytes = base64.b64decode(data.sketch)
        sketch_image = Image.open(BytesIO(sketch_bytes)).convert("L")
        sketch_image = sketch_image.resize((1024, 1024))
    except Exception as e:
        # Bad client payload (invalid base64 / unreadable image) -> 400.
        raise HTTPException(status_code=400, detail=f"Invalid sketch image: {e}") from e

    try:
        # Run the ControlNet-conditioned SDXL pipeline without autograd state.
        with torch.no_grad():
            images = pipe(
                prompt=data.prompt,
                negative_prompt=data.negative_prompt,
                image=sketch_image,
                controlnet_conditioning_scale=1.0,
                width=1024,
                height=1024,
                num_inference_steps=30,
            ).images

        # Serialize the first generated image as Base64 PNG for the response.
        buffered = BytesIO()
        images[0].save(buffered, format="PNG")
        image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": image_base64}
    except Exception as e:
        # Pipeline/server failure -> proper 500 instead of a 200 "error" body.
        raise HTTPException(status_code=500, detail=str(e)) from e