# Source: tadmztxi / myapp.py (Hugging Face file view, author: Geek7)
# Last commit: d3ba12f "Update myapp.py" (verified), 1.04 kB
import os
import re

import torch
from diffusers import DiffusionPipeline
from flask import Flask, jsonify, request
from PIL import Image
# Flask application object; the route decorators that follow register on it.
myapp = Flask(__name__)

# Load the "prompthero/openjourney-v4" diffusion pipeline once, at import time,
# and move it to the CPU. Request handlers reuse this single module-level instance.
pipe = DiffusionPipeline.from_pretrained("prompthero/openjourney-v4")
pipe = pipe.to("cpu")
@myapp.route('/')
def index():
    """Return a plain-text welcome message for the API root."""
    greeting = "Welcome to the Image Generation API!"
    return greeting
def _prompt_to_filename(prompt, max_length=100):
    """Build a filesystem-safe ``.png`` file name from a text prompt.

    Spaces become underscores. Every character other than ASCII letters,
    digits, ``_``, ``,``, ``.`` and ``-`` is then dropped, so a prompt cannot
    inject path separators (``/``, ``\\``) or characters the filesystem
    rejects. The stem is truncated to ``max_length`` characters to stay under
    common file-name length limits. Leading and trailing dots are stripped so
    the result is never a hidden file or a ``..`` name.

    Returns ``"image.png"`` when nothing usable remains.
    """
    stem = re.sub(r'[^\w,.-]', '', prompt.replace(' ', '_'), flags=re.ASCII)
    stem = stem[:max_length].strip('.')
    return f"{stem or 'image'}.png"


@myapp.route('/generate_image', methods=['POST'])
def generate_image():
    """Generate an image from a prompt and save it as a PNG in the working directory.

    The request body may be a JSON object with an optional ``"prompt"`` string.
    A missing body, a non-JSON body, or a missing ``"prompt"`` key falls back
    to the default prompt.

    Returns:
        JSON ``{"image_path": <file name>}`` on success. Returns a JSON
        ``{"error": ...}`` with HTTP 400 when the body is JSON but not an
        object, or when ``"prompt"`` is not a string.
    """
    # silent=True returns None instead of raising on a missing or invalid JSON
    # body, so a bare POST uses the default prompt.
    data = request.get_json(silent=True)
    if data is None:
        data = {}
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    prompt = data.get('prompt', 'Astronaut in a jungle, cold color palette, muted colors, detailed, 8k')
    if not isinstance(prompt, str):
        return jsonify({'error': "'prompt' must be a string"}), 400

    # Ask for PIL output explicitly. The result is already a PIL.Image and is
    # saved directly (PIL images have no .numpy() method).
    image = pipe(prompt, output_type="pil").images[0]

    # The prompt is untrusted input, so sanitize it before using it as a path.
    output_path = _prompt_to_filename(prompt)
    image.save(output_path)

    # Return the path (relative to the server's working directory) of the saved image.
    return jsonify({'image_path': output_path})
if __name__ == "__main__":
    # When run as a script, serve on all network interfaces at port 7860.
    myapp.run(port=7860, host='0.0.0.0')