from diffusers import AutoPipelineForImage2Image
import torch
import os
import numpy as np
from PIL import Image
from diffusers.utils import load_image, make_image_grid
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import io

# Set environment variable to avoid CUDA memory fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Clear any unused GPU memory
torch.cuda.empty_cache()

app = Flask(__name__)
CORS(app)

print('loading models...')

# Load the image-to-image pipeline from Hugging Face
pipe = AutoPipelineForImage2Image.from_pretrained(
    "RunDiffusion/Juggernaut-X-v10", torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_vae_tiling()   # Improve performance on large images
pipe.enable_vae_slicing()  # Improve performance on large batches

print('loaded models...')


@app.route('/')
def hello():
    return {"Goes Wrong": "Keeping it real"}


@app.route('/run_inference', methods=['POST'])
def run_inference():
    data = request.get_json()
    if 'url' not in data:
        return jsonify({"error": "No image URL provided"}), 400

    # base64_image = data['base64_image']  # unused; the sketch is fetched by URL instead
    prompt = data.get('prompt', 'fleece hoodie, front zip, abstract pattern, GAP logo, high quality, photo')
    negative_prompt = data.get('negative_prompt', 'low quality, bad quality, sketches, hanger')
    guidance_scale = float(data.get('guidance_scale', 7))
    num_images = int(data.get('num_images', 2))
    url = data.get('url', 'https://storage.googleapis.com/sketch-bucket/dresstest2.PNG')

    sketch = load_image(url)
    print(f'Loaded image URL: {url}')

    # Testing overrides, kept for local debugging:
    # prompt = "long waist dress, puffed sleeves, fringes on sleeve and hem, high quality, photo"
    # negative_prompt = "low quality, bad quality, sketches, hanger"
    # guidance_scale = 7

    with torch.inference_mode():
        images = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=sketch,
            num_inference_steps=35,
            guidance_scale=guidance_scale,
            strength=0.5,  # keep the sketch's structure while restyling it
            generator=torch.manual_seed(69),
            num_images_per_prompt=num_images,
        ).images

    # Arrange the generated images side by side in a single grid
    grid = make_image_grid(images, rows=1, cols=num_images)
    # images[0].save('output.png')

    # Save the generated grid to a BytesIO object and return it as a PNG
    img_byte_arr = io.BytesIO()
    grid.save(img_byte_arr, format='PNG')
    img_byte_arr.seek(0)

    return send_file(img_byte_arr, mimetype='image/png')


if __name__ == '__main__':
    app.run(debug=True)
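
# --- Example client (illustrative sketch, not part of the server) ---
# Assumes the server above is running locally on port 5000 and that the
# sketch URL is reachable; adjust both for your own deployment.
#
#   import requests
#
#   resp = requests.post(
#       'http://localhost:5000/run_inference',
#       json={
#           'url': 'https://storage.googleapis.com/sketch-bucket/dresstest2.PNG',
#           'prompt': 'fleece hoodie, front zip, abstract pattern, high quality, photo',
#           'num_images': 2,
#       },
#   )
#   with open('grid.png', 'wb') as f:
#       f.write(resp.content)  # the endpoint streams back the PNG grid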