"""Flask service that generates images with Hugging Face models.

POST /generate with a JSON object ``{"model": ..., "prompt": ..., "seed": ...}``
returns the generated image as ``image/png``.
"""

import asyncio
import contextlib
import functools
import io
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from threading import RLock

from flask import Flask, request, jsonify, send_file
from flask_cors import CORS

from all_models import models
from externalmod import gr_Interface_load

# Initialize Flask app and enable CORS
app = Flask(__name__)
CORS(app)

# Kept for backward compatibility. Each inference now writes to its own
# temporary file, so there is no shared output path left to guard.
lock = RLock()
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model name -> loaded interface, or None when loading that model failed.
models_load = {}


def load_fn(models):
    """Load every model named in *models* into the global ``models_load``.

    Duplicate names are loaded once. A model that fails to load is logged and
    stored as ``None`` so that ``/generate`` can report it as unavailable.
    """
    global models_load
    models_load = {}
    for model in models:
        if model in models_load:
            continue
        try:
            models_load[model] = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
        except Exception as error:
            print(f"Error loading model {model}: {error}")
            models_load[model] = None  # Handle model loading failures


load_fn(models)

inference_timeout = 600

# Dedicated pool for the blocking model calls. asyncio.run() waits for the
# event loop's *default* executor (the one asyncio.to_thread uses) to shut
# down before returning. A timed-out call on that executor would therefore
# still block the request until the model finished. This pool is not owned by
# the per-request loop, so a timeout really does return early.
_inference_executor = ThreadPoolExecutor(thread_name_prefix="inference")


def _save_png(image):
    """Save *image* to a fresh temporary ``.png`` file and return its path."""
    fd, png_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # only the unique name is needed; image.save reopens it
    image.save(png_path)
    return png_path


async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    """Run model *model_str* on *prompt* and save the result as a PNG.

    Args:
        model_str: key into ``models_load``; the caller must ensure the model
            is loaded (not ``None``).
        prompt: text prompt forwarded to the model.
        seed: forwarded to the model as the ``seed`` keyword.
        timeout: seconds to wait for the model before giving up.

    Returns:
        Path to a newly created, uniquely named PNG file (the caller owns it
        and should delete it), or ``None`` if inference raised, timed out, or
        returned ``None``.
    """
    kwargs = {"seed": seed}
    try:
        loop = asyncio.get_running_loop()
        call = functools.partial(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN)
        result = await asyncio.wait_for(
            loop.run_in_executor(_inference_executor, call), timeout=timeout
        )
        if result is not None:
            # Unique path per call: concurrent requests can no longer overwrite
            # each other's image before it is sent back.
            return _save_png(result)
    except asyncio.TimeoutError:
        print(f"Inference timed out after {timeout}s for model {model_str}")
    except Exception as e:
        print(f"Inference error for model {model_str}: {e}")  # Log the error message
    return None


@app.route('/generate', methods=['POST'])
def generate():
    """Generate an image from the JSON request body and return it as a PNG.

    Expects a JSON object with ``model`` (str), ``prompt`` (str) and an
    optional ``seed`` (default 1).

    Responses:
        400 if the body is not a JSON object or ``prompt`` is not a string.
        404 if the model is unknown or failed to load.
        500 if generation fails.
        Otherwise 200 with the PNG bytes.
    """
    # silent=True: a missing or invalid JSON body yields None instead of
    # raising, so we can answer with a clear 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    model_str = data.get('model')
    prompt = data.get('prompt')
    seed = data.get('seed', 1)
    print(f"Received request for model: '{model_str}', prompt: '{prompt}', seed: {seed}")

    # isinstance guard: an unhashable JSON value (e.g. a list) would otherwise
    # raise TypeError on the dict lookup.
    if not isinstance(model_str, str) or models_load.get(model_str) is None:
        print(f"Model not found in models_load: {model_str}. Available models: {models_load.keys()}")
        return jsonify({"error": "Model not found or not loaded"}), 404

    if not isinstance(prompt, str):
        return jsonify({"error": "'prompt' must be a string"}), 400

    image_path = asyncio.run(infer(model_str, prompt, seed, inference_timeout))
    if image_path is None:
        print("Image generation failed for:", model_str)  # Log failure reason
        return jsonify({"error": "Image generation failed"}), 500

    # Read the per-request temp file into memory and delete it, so temp
    # files do not accumulate on disk.
    try:
        with open(image_path, "rb") as fh:
            png_bytes = fh.read()
    finally:
        with contextlib.suppress(OSError):
            os.remove(image_path)
    return send_file(io.BytesIO(png_bytes), mimetype='image/png')


if __name__ == '__main__':
    # NOTE: debug=True turns on the interactive Werkzeug debugger (never expose
    # it on a public interface) and the reloader. The reloader re-runs this
    # module in a child process, so load_fn() loads every model a second time.
    app.run(debug=True)