import hashlib
import os

import torch
from diffusers import DiffusionPipeline
from flask import Flask, request, jsonify
from PIL import Image


# WSGI application object; the route decorators in this module register
# their handlers on it.
myapp = Flask(__name__)

# Load the text-to-image pipeline once, when the module is imported, so the
# request handler reuses the same object; ``.to("cpu")`` places it on the CPU.
pipe = DiffusionPipeline.from_pretrained("prompthero/openjourney-v4").to("cpu")


@myapp.route('/')
def index():
    """Serve the single-page UI used to submit a prompt.

    The page's inline script POSTs the prompt as JSON to ``/generate_image``
    and then shows either the returned image or an error message.

    Returns:
        str: The complete HTML document, including its inline script.
    """
    return '''
<html>
<body>
    <h1>Welcome to the Image Generation API!</h1>
    <form id="input-form">
        <label for="prompt">Enter your prompt:</label><br>
        <input type="text" id="prompt" name="prompt"><br><br>
        <button type="submit">Generate Image</button>
    </form>
    <div id="spinner" style="display:none;">Generating image, please wait...</div>
    <div id="result"></div>
    <script>
        document.getElementById('input-form').onsubmit = async (e) => {
            e.preventDefault();
            const spinner = document.getElementById('spinner');
            const result = document.getElementById('result');
            const prompt = document.getElementById('prompt').value;
            spinner.style.display = 'block';

            try {
                const response = await fetch('/generate_image', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ prompt })
                });
                const data = await response.json();

                if (response.ok) {
                    // Build DOM nodes instead of interpolating into innerHTML
                    // so a path containing quotes or markup cannot inject HTML.
                    const heading = document.createElement('h2');
                    heading.textContent = 'Image Generated:';
                    const img = document.createElement('img');
                    img.src = data.image_path;
                    img.alt = 'Generated Image';
                    result.textContent = "";
                    result.appendChild(heading);
                    result.appendChild(img);
                } else {
                    result.innerText = 'Error generating image: ' + data.error;
                }
            } catch (err) {
                // Network failure, or a response body that is not JSON.
                result.innerText = 'Error generating image: ' + err;
            } finally {
                // Always hide the spinner, including on failure.
                spinner.style.display = 'none';
            }
        };
    </script>
</body>
</html>
'''


# Upper bound on the prompt-derived part of a file name, keeping the full name
# well under common filesystem limits.
_MAX_SLUG_LENGTH = 64


def _output_filename(prompt):
    """Return a filesystem-safe PNG file name derived from ``prompt``."""
    # Keep ASCII letters and digits only; everything else (spaces, path
    # separators, dots, ...) becomes a separator, so the name cannot point
    # outside the output directory.
    cleaned = ''.join(ch if ch.isascii() and ch.isalnum() else ' ' for ch in prompt)
    slug = '_'.join(cleaned.split())[:_MAX_SLUG_LENGTH] or 'image'
    # A short digest keeps distinct prompts from colliding after sanitising.
    digest = hashlib.sha256(prompt.encode('utf-8', 'replace')).hexdigest()[:8]
    return f'{slug}_{digest}.png'


@myapp.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate an image for the posted prompt and report where to fetch it.

    Expects a JSON object body with an optional ``prompt`` string; a built-in
    default prompt is used when the key is absent. The image is written into
    the application's static folder so the returned path can be requested
    directly by the browser.

    Returns:
        A JSON response ``{"image_path": <url path>}`` on success, or
        ``{"error": <message>}`` with status 400 for an invalid request body
        and status 500 when generation or saving fails.
    """
    default_prompt = 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k'

    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so the client receives the JSON error shape it displays.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    prompt = data.get('prompt', default_prompt)
    if not isinstance(prompt, str):
        return jsonify({'error': 'Prompt must be a string.'}), 400

    filename = _output_filename(prompt)

    try:
        # With the pipeline's default output type, ``images`` holds PIL images
        # (per the diffusers documentation), so the result is saved directly
        # rather than round-tripped through an array.
        image = pipe(prompt).images[0]
        os.makedirs(myapp.static_folder, exist_ok=True)
        image.save(os.path.join(myapp.static_folder, filename))
    except Exception:
        # Request boundary: log the traceback and return a JSON error rather
        # than an HTML error page the client cannot parse.
        myapp.logger.exception('Image generation failed')
        return jsonify({'error': 'Image generation failed.'}), 500

    return jsonify({'image_path': f'{myapp.static_url_path}/{filename}'})


if __name__ == "__main__":
    # Start Flask's built-in server, listening on all interfaces at port 7860.
    myapp.run(host='0.0.0.0', port=7860)