File size: 2,846 Bytes
65f3fc3 3b41fd8 65f3fc3 3b41fd8 eb0bba6 65f3fc3 3b41fd8 eb0bba6 3b41fd8 eb0bba6 3b41fd8 eb0bba6 65f3fc3 3b41fd8 eb0bba6 3b41fd8 eb0bba6 3b41fd8 eb0bba6 65f3fc3 3b41fd8 65f3fc3 eb0bba6 44aff97 65f3fc3 7eec8a0 65f3fc3 3b41fd8 7eec8a0 65f3fc3 3b41fd8 65f3fc3 3b41fd8 65f3fc3 3b41fd8 65f3fc3 7eec8a0 3039308 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import asyncio
import io
import os
import tempfile
from threading import RLock

from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from huggingface_hub import InferenceClient

from all_models import models  # Importing models from all_models
# Configuration read once at import time.
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face API token; None when unset
inference_timeout = 600  # seconds allowed for a single inference call

# Flask application; CORS is enabled for every route.
app = Flask(__name__)
CORS(app)

# Process-wide re-entrant lock, held by the inference path while it writes
# the generated image to a temporary file.
lock = RLock()
def get_model_from_name(model_name):
    """Validate *model_name* against the list of available models.

    Returns *model_name* unchanged when it is present in ``models``
    (imported from ``all_models``); otherwise returns ``None``.
    """
    if model_name not in models:
        return None
    return model_name
async def infer(client, prompt, seed=1, timeout=inference_timeout, model="prompthero/openjourney-v4"):
    """Generate an image for *prompt* with *model* and save it as a PNG temp file.

    The blocking ``client.text_to_image`` call runs in a worker thread so the
    coroutine can enforce *timeout* with ``asyncio.wait_for``.

    Args:
        client: A ``huggingface_hub.InferenceClient`` (or an object exposing a
            compatible ``text_to_image`` method).
        prompt: Text prompt to render.
        seed: Seed forwarded to ``text_to_image``.
        timeout: Seconds to wait for the inference before giving up.
        model: Model id forwarded to ``text_to_image``.

    Returns:
        Path of a ``.png`` file created with ``delete=False``. The caller owns
        the file and is responsible for removing it. Returns ``None`` if the
        inference timed out, raised, or produced no result.

    Raises:
        Any error raised while writing the image to disk. The partially
        written file is removed before the error propagates.
    """
    try:
        result = await asyncio.wait_for(
            asyncio.to_thread(client.text_to_image, prompt=prompt, seed=seed, model=model),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        # wait_for already cancels the awaiting task; the worker thread itself
        # cannot be interrupted and runs until the underlying call returns.
        print(f"Task timed out for model: {model}")
        return None
    except Exception as e:
        print(e)
        print(f"Inference failed for model: {model}")
        return None

    if result is None:
        return None

    with lock:
        temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        try:
            with temp_image:  # guarantees the handle is closed
                if hasattr(result, "save"):
                    # InferenceClient.text_to_image returns a PIL image, not
                    # bytes; a file object requires an explicit format.
                    result.save(temp_image, format="PNG")
                else:
                    temp_image.write(result)
        except Exception:
            # Do not leave a half-written file behind; let the caller see the error.
            os.remove(temp_image.name)
            raise
    return temp_image.name
# Flask route for the API endpoint
@app.route('/generate_api', methods=['POST'])
def generate_api():
    """Generate an image from a JSON request and return it as ``image/png``.

    Expected JSON object fields:
        prompt: required text prompt. A missing or empty prompt yields 400.
        seed: optional, default 1. Forwarded to the inference call.
        model: optional, default "prompthero/openjourney-v4". Must be one of
            the names in ``all_models.models``, otherwise 400.

    Responses:
        200: the PNG bytes on success.
        400: the body is not a JSON object, the prompt is missing, or the
            model is unknown.
        500: ``{"error": ...}`` when generation fails.
    """
    # silent=True: an absent or unparsable JSON body becomes None instead of
    # raising, so it is reported as a 400 like the other validation errors.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # Extract required fields from the request
    prompt = data.get('prompt', '')
    seed = data.get('seed', 1)
    model_name = data.get('model', 'prompthero/openjourney-v4')  # Default model if not provided
    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400

    # Get the model from all_models
    model = get_model_from_name(model_name)
    if not model:
        return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400

    try:
        # Bound the HTTP call itself. asyncio.run() shuts down its default
        # executor, which waits for the worker thread started by infer().
        # The wait_for timeout alone therefore does not limit how long this
        # request can block.
        client = InferenceClient(token=HF_TOKEN, timeout=inference_timeout)
        result_path = asyncio.run(infer(client, prompt, seed, model=model))
        if not result_path:
            return jsonify({"error": "Failed to generate image"}), 500
        # Load the image into memory and delete the temp file, so generated
        # images do not accumulate on disk across requests.
        try:
            with open(result_path, 'rb') as f:
                image_bytes = f.read()
        finally:
            os.remove(result_path)
        return send_file(io.BytesIO(image_bytes), mimetype='image/png')
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 7860.
    app.run(port=7860, host='0.0.0.0')
|