File size: 2,594 Bytes
e88b277
5a0eb94
e88b277
 
 
46db6b6
747ac60
a180961
46db6b6
 
 
e88b277
d3ba12f
46db6b6
ded3931
e88b277
45af9e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46db6b6
cccc096
46db6b6
 
e88b277
ded3931
46db6b6
e88b277
 
 
 
a180961
e88b277
 
 
 
46db6b6
 
e88b277
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from flask import Flask, request, jsonify
from diffusers import DiffusionPipeline
import torch
from PIL import Image
import os



# Flask application object; the route decorators below register on it.
myapp = Flask(__name__)

# Fetch the pretrained pipeline first, then place it on the CPU device.
_pretrained = DiffusionPipeline.from_pretrained("prompthero/openjourney-v4")
pipe = _pretrained.to("cpu")

@myapp.route('/')
def index():
    """Serve the single-page HTML front end.

    The page contains a prompt form whose submit handler POSTs
    ``{"prompt": ...}`` as JSON to ``/generate_image`` and then either shows
    the image referenced by the ``image_path`` field of the response or an
    error message.

    The embedded script is written defensively:

    * the spinner is hidden in a ``finally`` clause, so it disappears even
      when the request fails or the response is not JSON;
    * the result nodes are created through the DOM API instead of
      ``innerHTML``, so the server-supplied path (which derives from the
      user's prompt) is never parsed as HTML.

    Returns:
        The HTML document as a string.
    """
    return '''
    <html>
        <body>
            <h1>Welcome to the Image Generation API!</h1>
            <form id="input-form">
                <label for="prompt">Enter your prompt:</label><br>
                <input type="text" id="prompt" name="prompt"><br><br>
                <button type="submit">Generate Image</button>
            </form>
            <div id="spinner" style="display:none;">Generating image, please wait...</div>
            <div id="result"></div>
            <script>
                document.getElementById('input-form').onsubmit = async (e) => {
                    e.preventDefault();
                    const spinner = document.getElementById('spinner');
                    const result = document.getElementById('result');
                    const prompt = document.getElementById('prompt').value;

                    spinner.style.display = 'block';
                    result.textContent = '';

                    try {
                        const response = await fetch('/generate_image', {
                            method: 'POST',
                            headers: { 'Content-Type': 'application/json' },
                            body: JSON.stringify({ prompt })
                        });

                        // An error response may carry an HTML page instead of JSON.
                        let data = {};
                        try {
                            data = await response.json();
                        } catch (parseError) {
                            data = {};
                        }

                        if (response.ok && data.image_path) {
                            // Build nodes via the DOM API so the path is never parsed as HTML.
                            const heading = document.createElement('h2');
                            heading.textContent = 'Image Generated:';
                            const img = document.createElement('img');
                            img.src = data.image_path;
                            img.alt = 'Generated Image';
                            result.appendChild(heading);
                            result.appendChild(img);
                        } else {
                            result.innerText = 'Error generating image: ' + (data.error || response.statusText);
                        }
                    } catch (err) {
                        result.innerText = 'Error generating image: ' + err;
                    } finally {
                        spinner.style.display = 'none';
                    }
                };
            </script>
        </body>
    </html>
    '''

def _prompt_to_filename(prompt, max_length=100):
    """Return a filesystem- and URL-safe PNG file name derived from *prompt*."""
    # Keep ASCII letters, digits, '-' and '_'; everything else (spaces, dots,
    # path separators, ...) becomes '_' so the prompt can neither escape the
    # output directory nor produce a name that needs URL quoting.
    cleaned = ''.join(
        ch if (ch.isascii() and ch.isalnum()) or ch in '-_' else '_'
        for ch in prompt
    )
    stem = cleaned[:max_length].strip('_') or 'image'
    return f"{stem}.png"


@myapp.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate an image for the posted prompt and return where to fetch it.

    Expects a JSON object body with an optional string field ``prompt``; a
    missing or blank prompt falls back to a built-in default. The image is
    written into the Flask app's static folder so that the returned path can
    be used directly as an ``<img src>`` by the front end.

    Returns:
        On success, JSON ``{"image_path": "<static url>/<name>.png"}``.
        On failure, JSON ``{"error": "<message>"}`` with status 400 (bad
        request body) or 500 (generation or save failed).
    """
    default_prompt = 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k'

    # silent=True yields None instead of raising on a missing/invalid JSON body.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    prompt = payload.get('prompt', default_prompt)
    if not isinstance(prompt, str):
        return jsonify({'error': "'prompt' must be a string."}), 400
    prompt = prompt.strip() or default_prompt

    filename = _prompt_to_filename(prompt)
    try:
        # NOTE(review): with the pipeline's default output type the entries of
        # `.images` are PIL images, so the result is saved directly rather than
        # round-tripped through an array — confirm if output_type is changed.
        image = pipe(prompt).images[0]
        os.makedirs(myapp.static_folder, exist_ok=True)
        image.save(os.path.join(myapp.static_folder, filename))
    except Exception:
        # Request boundary: log the traceback and report a JSON error so the
        # client can display it instead of failing to parse an HTML 500 page.
        myapp.logger.exception("Image generation failed for prompt %r", prompt)
        return jsonify({'error': 'Image generation failed.'}), 500

    # Return a URL under the static route so the browser can load the file.
    return jsonify({'image_path': f"{myapp.static_url_path}/{filename}"})

if __name__ == "__main__":
    # Start Flask's built-in server with an explicit bind address and port.
    _server_options = {'host': '0.0.0.0', 'port': 7860}
    myapp.run(**_server_options)