# File size: 5,674 Bytes
# fed8daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import os
from huggingface_hub import InferenceClient
from io import BytesIO
from PIL import Image

# Flask application object; CORS is switched on for every route.
app = Flask(__name__)
CORS(app)

# Hugging Face API token, read from the HF_TOKEN environment variable
# (None when the variable is unset), and the inference client that the
# image helpers below share.
HF_TOKEN = os.getenv("HF_TOKEN")
client = InferenceClient(token=HF_TOKEN)

@app.route('/')
def home():
    """Serve the plain-text greeting shown at the site root."""
    greeting = "Welcome to the Image Background Remover!"
    return greeting

# Keyword-based prompt screening
def is_prompt_explicit(prompt):
    """Report whether ``prompt`` contains any blocked keyword.

    The comparison is case-insensitive and tests each keyword as a plain
    substring of the prompt, so a keyword embedded in a longer word also
    counts as a hit (for example "butt" inside "butterfly").

    Returns:
        True as soon as one keyword is found, otherwise False.
    """
    blocked_terms = (
        "sexual", "nudity", "erotic", "explicit", "porn", "pornographic",
        "xxx", "hentai", "fetish", "sex", "sensual", "nude", "strip",
        "stripping", "adult", "lewd", "provocative", "obscene", "vulgar",
        "intimacy", "intimate", "lust", "arouse", "seductive", "seduction",
        "kinky", "bdsm", "dominatrix", "bondage", "hardcore", "softcore",
        "topless", "bottomless", "threesome", "orgy", "incest", "taboo",
        "masturbation", "genital", "penis", "vagina", "breast", "boob",
        "nipple", "butt", "anal", "oral", "ejaculation", "climax", "moan",
        "foreplay", "intercourse", "naked", "exposed", "suicide",
        "self-harm", "overdose", "poison", "hang", "end life",
        "kill myself", "noose", "depression", "hopeless", "worthless",
        "die", "death", "harm myself",
    )  # Add more keywords as needed

    # Every keyword above is already lowercase, so only the prompt needs
    # normalising, and only once.
    lowered_prompt = prompt.lower()
    return any(term in lowered_prompt for term in blocked_terms)

# Text-to-image generation through the shared Hugging Face client
def generate_image(prompt, negative_prompt=None, height=512, width=512, model="stabilityai/sd-3.5", num_inference_steps=50, guidance_scale=7.5, seed=None):
    """Ask the inference client to render ``prompt`` as an image.

    All arguments are forwarded unchanged, by keyword, to
    ``client.text_to_image``.

    Returns:
        Whatever ``client.text_to_image`` returns, or ``None`` when the call
        raises; in that case the error text is printed and not re-raised.
    """
    request_options = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "model": model,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "seed": seed,
    }
    try:
        return client.text_to_image(**request_options)
    except Exception as err:
        # Best-effort contract: callers test the result for None.
        print(f"Error generating image: {err!s}")
        return None

# Image-to-image refinement through the shared Hugging Face client
def refine_image(image, prompt, negative_prompt=None, model="stabilityai/stable-diffusion-xl-refiner-1.0", num_inference_steps=50, guidance_scale=7.5):
    """Run ``image`` through the refiner model, guided by ``prompt``.

    All arguments are forwarded unchanged, by keyword, to
    ``client.image_to_image``.

    Returns:
        Whatever ``client.image_to_image`` returns, or ``None`` when the call
        raises; in that case the error text is printed and not re-raised.
    """
    request_options = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image": image,
        "model": model,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    try:
        return client.image_to_image(**request_options)
    except Exception as err:
        # Best-effort contract: callers test the result for None.
        print(f"Error refining image: {err!s}")
        return None

@app.route('/generate_image', methods=['POST'])
def generate_api():
    """Generate an image from a JSON request and return it as a PNG.

    Expected JSON object fields:
        prompt (str, required): text to render.
        negative_prompt (optional): forwarded to both model calls.
        height (default 1024), width (default 720): output size.
        num_inference_steps (default 50), guidance_scale (default 7.5).
        model (default 'stabilityai/sd-3.5'): base text-to-image model.
        seed (default None): forwarded to the base generation call.

    Responses:
        200 with the refined image as ``image/png``; or, when the prompt
        trips the keyword filter, the placeholder file ``thinkgood.jpeg``.
        400 with a JSON error when the body is not a JSON object or the
        prompt is missing or not a string.
        500 with a JSON error when generation/refinement fails or any other
        exception is raised while handling the request.
    """
    # silent=True returns None for a missing or malformed JSON body, so the
    # request can be rejected here instead of failing on ``data.get``.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    # Extract required fields from the request
    prompt = data.get('prompt', '')
    negative_prompt = data.get('negative_prompt', None)
    height = data.get('height', 1024)  # Default height
    width = data.get('width', 720)  # Default width
    num_inference_steps = data.get('num_inference_steps', 50)  # Default number of inference steps
    guidance_scale = data.get('guidance_scale', 7.5)  # Default guidance scale
    # NOTE(review): 'stabilityai/sd-3.5' does not look like a published Hub
    # repo id - confirm the intended base model before relying on the default.
    model_name = data.get('model', 'stabilityai/sd-3.5')  # Base model
    # Same id as refine_image's default, so both code paths use one refiner.
    refiner_model_name = 'stabilityai/stable-diffusion-xl-refiner-1.0'  # Refiner model
    seed = data.get('seed', None)  # Seed for reproducibility, default is None

    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400
    if not isinstance(prompt, str):
        # The keyword filter calls str methods on the prompt; a non-string is
        # a client error and should not surface as a 500.
        return jsonify({"error": "Prompt must be a string"}), 400

    try:
        # Check for explicit content
        if is_prompt_explicit(prompt):
            # Return the pre-defined placeholder image. The file is a JPEG,
            # so declare it as one (mimetype and download name).
            return send_file(
                "thinkgood.jpeg",
                mimetype='image/jpeg',
                as_attachment=False,
                download_name='thinkgood.jpeg'
            )

        # Step 1: Generate the base image
        base_image = generate_image(prompt, negative_prompt, height, width, model_name, num_inference_steps, guidance_scale, seed)

        # The helper signals failure with None; test for exactly that.
        if base_image is None:
            return jsonify({"error": "Failed to generate base image"}), 500

        # Step 2: Refine the image with the refiner model
        refined_image = refine_image(base_image, prompt, negative_prompt, refiner_model_name, num_inference_steps, guidance_scale)

        if refined_image is None:
            return jsonify({"error": "Failed to refine image"}), 500

        # Save the refined image to a BytesIO object
        img_byte_arr = BytesIO()
        refined_image.save(img_byte_arr, format='PNG')  # Convert the image to PNG
        img_byte_arr.seek(0)  # Move to the start of the byte stream

        # Send the refined image as a response
        return send_file(
            img_byte_arr,
            mimetype='image/png',
            as_attachment=False,  # Send the file inline
            download_name='refined_image.png'  # File name for download
        )
    except Exception as e:
        # Top-level boundary for this endpoint: log and report as a 500.
        print(f"Error in generate_api: {str(e)}")  # Log the error
        return jsonify({"error": str(e)}), 500

# Script entry point: when this file is executed directly (not imported),
# start the Flask app listening on all interfaces (0.0.0.0) at port 7860.
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)  # Run directly if needed for testing