"""Flask service exposing a /predict endpoint backed by a Gradio client."""

from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from gradio_client import Client

from all_models import models  # Import the models list

app = Flask(__name__)
CORS(app)

# Gradio client bound to a fixed Space id. It is created once at import time
# and reused by every request.
client = Client("Geek7/mdztxi2")

# Keys that every /predict request body must provide.
_REQUIRED_FIELDS = ("model_str", "prompt", "seed")


@app.route('/predict', methods=['POST'])
def predict():
    """Forward a prediction request to the Gradio client and return its file.

    Expects a JSON object with the keys ``model_str``, ``prompt`` and
    ``seed``.

    Returns:
        The file produced by the Gradio call, sent with mimetype
        ``image/png``, or a JSON ``{"error": ...}`` body with status 400
        (invalid payload or unknown model) or 500 (the upstream call or the
        file transfer failed).
    """
    # silent=True yields None for a malformed or non-JSON body instead of
    # aborting with Flask's default error page, so every invalid payload gets
    # the same JSON error response.
    data = request.get_json(silent=True)

    # Validate required fields. The dict check matters: a JSON list such as
    # ["model_str", "prompt", "seed"] would pass the membership test and then
    # fail on key access.
    if not isinstance(data, dict) or any(
        field not in data for field in _REQUIRED_FIELDS
    ):
        return jsonify({"error": "Missing required fields"}), 400

    model_str = data['model_str']
    prompt = data['prompt']
    seed = data['seed']

    # Check if the model_str exists in the models list
    if model_str not in models:
        return jsonify({"error": f"Model '{model_str}' is not available."}), 400

    try:
        # Send a request to the Gradio Client and get the result
        result_path = client.predict(
            model_str=model_str,
            prompt=prompt,
            seed=seed,
            api_name="/predict",
        )
        # Assumes the call returns a local file path to a PNG image.
        # TODO confirm against the Space's API.
        return send_file(result_path, mimetype='image/png')
    except Exception as e:
        # Request boundary: keep the traceback in the server log and report
        # the failure to the caller as JSON.
        app.logger.exception("Prediction failed for model %r", model_str)
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    # NOTE(review): debug=True enables the interactive debugger and the
    # reloader; this entry point is only suitable for local development.
    app.run(debug=True)