Geek7 commited on
Commit
46db6b6
·
verified ·
1 Parent(s): b2ec85b

Create myapp.py

Browse files
Files changed (1) hide show
  1. myapp.py +68 -0
myapp.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, jsonify, request, send_file
2
+ from flask_cors import CORS
3
+ import torch
4
+ from diffusers import StableDiffusion3Pipeline
5
+ import numpy as np
6
+ import random
7
+ import io
8
+ from PIL import Image
9
+
10
+ # Initialize the Flask app
11
+ myapp = Flask(__name__)
12
+ CORS(myapp) # Enable CORS if needed
13
+
14
+ # Load the model
15
+ device = "cpu"
16
+ dtype = torch.float16
17
+
18
+ repo = "stabilityai/stable-diffusion-3-medium-diffusers"
19
+ pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=dtype).to(device)
20
+
21
+ MAX_SEED = np.iinfo(np.int32).max
22
+ MAX_IMAGE_SIZE = 1344
23
+
24
+ @app.route('/')
25
+ def home():
26
+ return "Welcome to the Stable Diffusion 3 Image Generation API!" # Basic home response
27
+
28
+ @app.route('/generate_image', methods=['POST'])
29
+ def generate_image():
30
+ data = request.json
31
+
32
+ # Get inputs from request JSON
33
+ prompt = data.get('prompt', '')
34
+ negative_prompt = data.get('negative_prompt', None)
35
+ seed = data.get('seed', 0)
36
+ randomize_seed = data.get('randomize_seed', True)
37
+ width = data.get('width', 1024)
38
+ height = data.get('height', 1024)
39
+ guidance_scale = data.get('guidance_scale', 5.0)
40
+ num_inference_steps = data.get('num_inference_steps', 28)
41
+
42
+ # Randomize seed if requested
43
+ if randomize_seed:
44
+ seed = random.randint(0, MAX_SEED)
45
+
46
+ # Generate the image
47
+ generator = torch.Generator().manual_seed(seed)
48
+ image = pipe(
49
+ prompt=prompt,
50
+ negative_prompt=negative_prompt,
51
+ guidance_scale=guidance_scale,
52
+ num_inference_steps=num_inference_steps,
53
+ width=width,
54
+ height=height,
55
+ generator=generator
56
+ ).images[0]
57
+
58
+ # Save the image to a byte array
59
+ img_byte_arr = io.BytesIO()
60
+ image.save(img_byte_arr, format='PNG')
61
+ img_byte_arr.seek(0) # Move the pointer to the start of the byte array
62
+
63
+ # Return the image as a response
64
+ return send_file(img_byte_arr, mimetype='image/png')
65
+
66
+ # Add this block to make sure your app runs when called
67
+ if __name__ == "__main__":
68
+ myapp.run(host='0.0.0.0', port=7860) # Run the Flask app on port 7860