|
from flask import Flask, request, jsonify |
|
from diffusers import DiffusionPipeline |
|
import torch |
|
from PIL import Image |
|
import os |
|
|
|
|
|
# Flask application object; the route decorators in this module register on it.
myapp = Flask(import_name=__name__)

# The text-to-image pipeline is loaded once at import time, so every request
# reuses the same instance, and it is then moved to the CPU.
pipe = DiffusionPipeline.from_pretrained(
    "prompthero/openjourney-v4",
)
pipe = pipe.to("cpu")
|
|
|
@myapp.route('/')
def index():
    """Serve a plain-text greeting at the API root."""
    greeting = "Welcome to the Image Generation API!"
    return greeting
|
|
|
@myapp.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate an image from a text prompt and save it as a PNG file.

    The request body must be a JSON object. Its optional ``"prompt"`` string
    selects what to draw; when absent, a built-in default prompt is used. The
    PNG is written to the current working directory under a filename derived
    from the prompt, restricted to a safe character set.

    Returns:
        ``{"image_path": <filename>}`` on success, or
        ``({"error": <message>}, 400)`` when the body is not a JSON object or
        the prompt is not a string.
    """
    # silent=True returns None instead of raising on a missing or malformed
    # JSON body, so every bad-body case gets the same clean 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    prompt = data.get('prompt', 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k')
    if not isinstance(prompt, str):
        return jsonify({'error': "'prompt' must be a string."}), 400

    # The prompt is untrusted, so the filename must not be able to contain
    # path separators or other special characters. Spaces (and anything else
    # outside ASCII letters/digits and "_,.-") become "_". For ordinary
    # prompts this matches the old "spaces to underscores" naming.
    max_stem_len = 100  # keeps the name well under common 255-byte limits
    stem = "".join(
        ch if ch.isascii() and (ch.isalnum() or ch in "_,.-") else "_"
        for ch in prompt
    )
    # Stripping leading dots avoids hidden files; "image" covers empty prompts.
    stem = stem.lstrip('.')[:max_stem_len] or 'image'
    output_path = f"{stem}.png"

    # With the pipeline's default output type, .images holds PIL images, which
    # can be saved directly (they have no .numpy() method to convert from).
    image = pipe(prompt).images[0]
    image.save(output_path)

    return jsonify({'image_path': output_path})
|
|
|
if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces (0.0.0.0) so it is
    # reachable from other hosts, listening on port 7860.
    listen_host = '0.0.0.0'
    listen_port = 7860
    myapp.run(host=listen_host, port=listen_port)