from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import asyncio
import tempfile
import os
from threading import RLock
from huggingface_hub import InferenceClient
from all_models import models  # List of available model ids (see sketch below)
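
# The import above assumes an `all_models.py` module next to this file that
# exposes a `models` collection of model id strings. A minimal hypothetical
# sketch (the real list is project-specific):
#
#   # all_models.py
#   models = [
#       "prompthero/openjourney-v4",
#       "stabilityai/stable-diffusion-2-1",
#   ]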

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

lock = RLock()
HF_TOKEN = os.environ.get("HF_TOKEN")  # Hugging Face token

inference_timeout = 600  # Inference timeout, in seconds

# Validate a requested model name against the list of available models
def get_model_from_name(model_name):
    return model_name if model_name in models else None

# Asynchronous wrapper that runs inference in a worker thread with a hard timeout
async def infer(client, prompt, seed=1, timeout=inference_timeout, model="prompthero/openjourney-v4"):
    task = asyncio.create_task(
        asyncio.to_thread(client.text_to_image, prompt=prompt, seed=seed, model=model)
    )
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"Task timed out for model: {model}")
        if not task.done():
            task.cancel()
        result = None
    except Exception as e:
        print(f"Inference failed for model {model}: {e}")
        result = None

    if result is not None:
        # text_to_image returns a PIL.Image.Image, so save it rather than writing raw bytes
        with lock:
            temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
            temp_image.close()  # Close the handle so PIL can write to the path on all platforms
            result.save(temp_image.name, format="PNG")
        return temp_image.name  # Return the path to the saved image
    return None
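
# Standalone usage sketch for infer() (hypothetical; assumes HF_TOKEN is set in
# the environment and the model is available on the Inference API):
#
#   client = InferenceClient(token=HF_TOKEN)
#   path = asyncio.run(infer(client, "a watercolor fox", seed=7))
#   print(path)  # a temp .png path on success, or None on failure/timeout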

# Flask route for the API endpoint
@app.route('/generate_api', methods=['POST'])
def generate_api():
    data = request.get_json(silent=True) or {}  # Tolerate a missing or invalid JSON body

    # Extract required fields from the request
    prompt = data.get('prompt', '')
    seed = data.get('seed', 1)
    model_name = data.get('model', 'prompthero/openjourney-v4')  # Default to "prompthero/openjourney-v4" if not provided

    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400

    # Get the model from all_models
    model = get_model_from_name(model_name)
    if not model:
        return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400

    try:
        # Create a generic InferenceClient for the model
        client = InferenceClient(token=HF_TOKEN)  # Pass Hugging Face token if needed

        # Call the async inference function
        result_path = asyncio.run(infer(client, prompt, seed, model=model))
        if result_path:
            # Send back the generated image; note the temp file is not deleted afterwards
            return send_file(result_path, mimetype='image/png')
        else:
            return jsonify({"error": "Failed to generate image"}), 500
    except Exception as e:
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)  # Run directly if needed for testing