File size: 2,066 Bytes
46db6b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import io
import math
import random
import threading

import numpy as np
import torch
from diffusers import StableDiffusion3Pipeline
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS
from PIL import Image

# Initialize the Flask app
myapp = Flask(__name__)
# The route decorators in this module register handlers through `app`;
# alias it to the same Flask instance so those decorators resolve.
app = myapp
CORS(myapp)  # Enable CORS if needed

# Load the model.
# Prefer a CUDA GPU when one is available. float16 is used only on the GPU:
# PyTorch lacks float16 support for many CPU ops, and diffusers warns that a
# float16 pipeline moved to CPU will fail at inference time. On CPU we fall
# back to float32, which roughly doubles the weights' memory footprint.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

repo = "stabilityai/stable-diffusion-3-medium-diffusers"
pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=dtype).to(device)

# Upper bound for randomly drawn seeds (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Intended maximum width/height, in pixels, for generated images.
MAX_IMAGE_SIZE = 1344

@app.route('/')
def home():
    """Serve a plain-text greeting at the API root."""
    greeting = "Welcome to the Stable Diffusion 3 Image Generation API!"
    return greeting

# diffusers pipelines are not safe to call concurrently: each call mutates
# shared state such as the scheduler's timesteps. Flask's built-in server
# handles requests on multiple threads by default, so serialize inference.
_PIPE_LOCK = threading.Lock()


def _parse_int(data, key, default, *, minimum=None, maximum=None):
    """Return ``data[key]`` (or *default*) as an int within optional bounds.

    Raises:
        ValueError: with a client-facing message if the value is a bool,
            cannot be converted to int, or lies outside [minimum, maximum].
    """
    raw = data.get(key, default)
    # bool is a subclass of int; reject it so JSON `true` is not read as 1.
    if isinstance(raw, bool):
        raise ValueError(f"'{key}' must be an integer")
    try:
        value = int(raw)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float('inf')); Python's json module
        # accepts the non-standard `Infinity` literal.
        raise ValueError(f"'{key}' must be an integer") from None
    if minimum is not None and value < minimum:
        raise ValueError(f"'{key}' must be at least {minimum}")
    if maximum is not None and value > maximum:
        raise ValueError(f"'{key}' must be at most {maximum}")
    return value


def _parse_float(data, key, default):
    """Return ``data[key]`` (or *default*) as a finite float.

    Raises:
        ValueError: if the value is a bool, not numeric, NaN or infinite.
    """
    raw = data.get(key, default)
    if isinstance(raw, bool):
        raise ValueError(f"'{key}' must be a number")
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(f"'{key}' must be a number") from None
    if not math.isfinite(value):
        raise ValueError(f"'{key}' must be a finite number")
    return value


def _parse_generation_params(data):
    """Validate a request payload for image generation.

    Returns ``(seed, pipe_kwargs)``. *pipe_kwargs* holds every keyword
    argument for ``pipe(...)`` except the generator.

    Raises:
        ValueError: with a client-facing message on invalid input.
    """
    pipe_kwargs = {
        'prompt': data.get('prompt', ''),
        'negative_prompt': data.get('negative_prompt', None),
        'guidance_scale': _parse_float(data, 'guidance_scale', 5.0),
        'num_inference_steps': _parse_int(
            data, 'num_inference_steps', 28, minimum=1),
        'width': _parse_int(
            data, 'width', 1024, minimum=1, maximum=MAX_IMAGE_SIZE),
        'height': _parse_int(
            data, 'height', 1024, minimum=1, maximum=MAX_IMAGE_SIZE),
    }
    if data.get('randomize_seed', True):
        # Any client-supplied seed is ignored (and not validated) here.
        seed = random.randint(0, MAX_SEED)
    else:
        seed = _parse_int(data, 'seed', 0)
    return seed, pipe_kwargs


def _encode_png(image):
    """Serialize *image* to an in-memory PNG buffer rewound to the start."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)  # send_file reads from the current position
    return buffer


@app.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate one image with the SD3 pipeline from a JSON request body.

    Accepted JSON keys (all optional):

    - ``prompt``: default "".
    - ``negative_prompt``: default None.
    - ``seed``: default 0.
    - ``randomize_seed``: default True. When truthy, ``seed`` is ignored and
      a random seed in [0, MAX_SEED] is used.
    - ``width`` / ``height``: default 1024, each in [1, MAX_IMAGE_SIZE].
    - ``guidance_scale``: default 5.0.
    - ``num_inference_steps``: default 28, at least 1.

    Returns:
        The image as ``image/png``, with the seed actually used in the
        ``X-Seed`` response header. For a body that is not a JSON object,
        or for an invalid field, returns a 400 JSON ``{"error": ...}``.
    """
    # silent=True yields None for a missing/malformed body instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    try:
        seed, pipe_kwargs = _parse_generation_params(data)
    except ValueError as exc:
        return jsonify({"error": str(exc)}), 400

    # Seeded CPU generator so a given seed reproduces the same image.
    generator = torch.Generator().manual_seed(seed)
    with _PIPE_LOCK:
        image = pipe(generator=generator, **pipe_kwargs).images[0]

    response = send_file(_encode_png(image), mimetype='image/png')
    # Report the seed so randomized results can be reproduced by the client.
    response.headers['X-Seed'] = str(seed)
    return response

# Start Flask's built-in server only when this file is executed as a script;
# importing the module builds the app and loads the model without serving.
if __name__ == "__main__":
    server_options = {"host": "0.0.0.0", "port": 7860}
    myapp.run(**server_options)