File size: 11,642 Bytes
e547b24
 
 
 
 
 
7abf720
e547b24
 
2f294b2
 
7abf720
e547b24
 
 
2f294b2
7abf720
c84bbeb
2f294b2
 
 
 
e547b24
2f294b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7abf720
8cf6f97
7abf720
 
2f294b2
 
 
 
92cbf13
c84bbeb
2f294b2
 
 
 
 
 
 
 
 
 
7abf720
2f294b2
7abf720
e547b24
2f294b2
7abf720
f869ea1
 
 
 
2f294b2
 
119e558
e547b24
 
2f294b2
e547b24
2f294b2
 
 
 
 
7abf720
2f294b2
e547b24
7abf720
2f294b2
 
 
 
 
 
 
 
 
 
 
 
7abf720
2f294b2
 
 
 
 
 
7abf720
2f294b2
 
 
 
 
7abf720
 
 
 
 
 
 
 
2f294b2
 
 
 
7abf720
2f294b2
7abf720
2f294b2
7abf720
2f294b2
 
 
 
 
 
7abf720
2f294b2
 
 
 
7abf720
2f294b2
 
 
 
 
 
 
 
 
 
 
 
7abf720
2f294b2
 
 
 
7abf720
2f294b2
7abf720
2f294b2
 
 
7abf720
2f294b2
 
 
 
 
7abf720
2f294b2
7abf720
2f294b2
7abf720
2f294b2
7abf720
2f294b2
 
 
c84bbeb
2f294b2
7abf720
2f294b2
 
e547b24
7abf720
e547b24
02f8cfa
c84bbeb
02f8cfa
 
73f7edc
c84bbeb
 
 
7abf720
 
 
2f294b2
e547b24
 
7abf720
c84bbeb
7abf720
2f294b2
02f8cfa
 
c84bbeb
 
7abf720
 
c84bbeb
 
7abf720
c84bbeb
 
 
8cf6f97
 
7abf720
2f294b2
e547b24
02f8cfa
c84bbeb
2f294b2
 
02f8cfa
c84bbeb
2f294b2
 
 
 
7abf720
2f294b2
 
7abf720
2f294b2
7abf720
2f294b2
 
e547b24
7abf720
2f294b2
 
 
7abf720
 
2f294b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import html
import io
import json
import os
import random
import time
import traceback
import uuid
from urllib.parse import quote

import gradio as gr
import requests
from deep_translator import GoogleTranslator
from PIL import Image, UnidentifiedImageError

# Project by Nymbo

# --- Configuration ---
# Inference endpoint that receives the generation request.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"

# The bearer token is read from the environment; without it `headers` stays
# empty and the app only warns here instead of refusing to start.
API_TOKEN = os.environ.get("HF_READ_TOKEN")
if API_TOKEN:
    headers = {"Authorization": f"Bearer {API_TOKEN}"}
else:
    print("WARNING: HF_READ_TOKEN environment variable not set. API calls may fail.")
    headers = {}

# Timeout, in seconds, passed to the API request.
timeout = 100

# Directory (relative to the working directory) that receives generated images.
IMAGE_DIR = "temp_generated_images"
# Base URL of the redirector that the download link points at.
ARINTELLI_REDIRECT_BASE = "https://arintelli.com/app/"

# --- Create the image directory up front ---
# Failing here is fatal on purpose: the error is re-raised so the app does
# not start without a place to store images.
try:
    os.makedirs(IMAGE_DIR, exist_ok=True)
    print(f"Confirmed temporary image directory exists: {IMAGE_DIR}")
except OSError as dir_err:
    print(f"ERROR: Could not create directory {IMAGE_DIR}: {dir_err}")
    raise gr.Error(f"Fatal Error: Cannot create temporary image directory: {dir_err}")

# --- Absolute form of the image directory ---
# Computed at import time because it is handed to launch() as allowed_paths.
absolute_image_dir = os.path.abspath(IMAGE_DIR)
print(f"Absolute path for allowed_paths: {absolute_image_dir}")

# --- Image generation: private helpers followed by the `query` entry point ---

def _status_html(message, color="red"):
    """Return `message` wrapped in the centered, colored <p> shown in the status area.

    `message` is inserted verbatim, so callers must HTML-escape any text that
    does not originate in this file (API responses, exception messages).
    """
    return f"<p style='color: {color}; text-align: center;'>{message}</p>"


def _parse_api_error(response):
    """Extract `(error_text, estimated_time)` from a failed API response.

    Falls back to the raw response text when the body is not JSON or is JSON
    of an unexpected shape (e.g. a list), so the caller never has to guard
    against a missing or non-dict payload.
    """
    error_text = response.text
    estimated_time = None
    try:
        error_data = response.json()
    except ValueError:
        # Both json.JSONDecodeError and simplejson's decode error subclass
        # ValueError, so this holds whichever JSON backend requests uses.
        return error_text, estimated_time

    if isinstance(error_data, dict):
        detail = error_data.get("error", error_text)
        if isinstance(detail, dict) and "message" in detail:
            detail = detail["message"]
        error_text = detail
        estimated_time = error_data.get("estimated_time")
    return error_text, estimated_time


def _describe_http_error(status_code, error_text, estimated_time=None):
    """Build the user-facing message for an HTTP error status.

    `error_text` comes from the remote API and is HTML-escaped here because
    the result is rendered inside an HTML component.
    """
    safe_text = html.escape(str(error_text))
    if status_code == 503:
        # Only format the estimate when it is a usable number; a string or
        # other type would make the ".1f" format spec raise.
        if isinstance(estimated_time, (int, float)) and estimated_time:
            return f"Model is loading (503), please wait. Est. time: {estimated_time:.1f}s. Try again."
        return "Service unavailable (503). Model might be loading or down. Try again later."
    if status_code == 400:
        return f"Bad Request (400): Check parameters. API Error: {safe_text}"
    if status_code == 422:
        return f"Validation Error (422): Input invalid. API Error: {safe_text}"
    if status_code in (401, 403):
        return f"Authorization Error ({status_code}): Check your API Token (HF_READ_TOKEN)."
    return f"API Error: {status_code}. Details: {safe_text}"


def _build_download_html(save_path):
    """Return the "Download Image" button markup for the file at `save_path`.

    The link targets ARINTELLI_REDIRECT_BASE and carries the URL-quoted
    `/gradio_api/file=` path of the image plus the Space name as query
    parameters.
    """
    space_name = "greendra-flux-1-schnell-serverless"
    relative_file_url = f"/gradio_api/file={save_path}"
    encoded_file_url = quote(relative_file_url)
    arintelli_url = f"{ARINTELLI_REDIRECT_BASE}?download_url={encoded_file_url}&space_name={space_name}"
    print(f"{arintelli_url}")

    # Reuses Gradio's primary button classes so the link looks like a button.
    return (
        '<div style="text-align: center; margin-top: 15px;">'
        f'<a href="{arintelli_url}" target="_blank" class="gr-button gr-button-lg gr-button-primary">'
        'Download Image'
        '</a>'
        '</div>'
    )


def query(prompt, negative_prompt="", steps=4, cfg_scale=0, seed=-1, width=1024, height=1024):
    """Generate an image for `prompt` and return it with a status/download snippet.

    The prompt is translated to English (falling back to the original text if
    translation fails or yields nothing), a fixed quality suffix is appended,
    and the request is posted to API_URL. A successful image is saved as a PNG
    under IMAGE_DIR and a download link pointing at that file is produced.

    Args:
        prompt: Text describing the image. Empty or whitespace-only input is
            rejected without contacting any service.
        negative_prompt: Sent as the ``negative_prompt`` parameter.
        steps: Sent as ``num_inference_steps``.
        cfg_scale: Sent as ``guidance_scale``.
        seed: Sent as ``seed``; ``-1`` is replaced by a random value in
            [1, 1000000000].
        width: Sent as ``width``.
        height: Sent as ``height``.

    Returns:
        A ``(image, html)`` tuple. ``image`` is the PIL image, or ``None``
        when generation failed. ``html`` is either the download-button markup
        or a colored ``<p>`` describing what went wrong. If only saving or
        link creation fails, the image is still returned next to the error.

    Never raises for API, decoding or file-system failures; those are
    reported through the returned HTML.
    """
    if not prompt or not prompt.strip():
        print("Empty prompt received.")
        return None, _status_html("Please enter a prompt.", color="orange")

    # Check configuration before any network work: without a token the API
    # call is never attempted, so translating the prompt first would be wasted.
    if not headers:
        print("WARNING: Authorization header is missing (HF_READ_TOKEN not set?)")
        return None, _status_html("Configuration Error: API Token missing.")

    # Short random tag used only to correlate the log lines of one request.
    key = random.randint(0, 999)
    print(f"\n--- Generation {key} Started ---")

    # Translation is best-effort: any failure falls back to the raw prompt.
    try:
        translated_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
    except Exception as e:
        print(f"Generation {key}: translation failed, using the original prompt: {e}")
        translated_prompt = prompt
    if not translated_prompt:
        # Guards against an empty/None translation ending up in the prompt.
        translated_prompt = prompt

    final_prompt = f"{translated_prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
    print(f'Generation {key} prompt: {final_prompt}')

    payload = {
        "inputs": final_prompt,
        "parameters": {
            "negative_prompt": negative_prompt,
            "num_inference_steps": steps,
            "guidance_scale": cfg_scale,
            "seed": seed if seed != -1 else random.randint(1, 1000000000),
            "width": width,
            "height": height,
        }
    }

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()

        image_bytes = response.content
        if not image_bytes or len(image_bytes) < 100:
            print(f"Error: Received empty or very small response content (length: {len(image_bytes)}). Potential API issue.")
            return None, _status_html("API returned invalid image data.")

        try:
            image = Image.open(io.BytesIO(image_bytes))
        except UnidentifiedImageError as img_err:
            print(f"Error: Could not identify or open image from API response bytes: {img_err}")
            return None, _status_html("Failed to process image data from API.")

        # --- Save image and create download link ---
        # Timestamp plus a random suffix keeps concurrent requests from colliding.
        filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
        save_path = os.path.join(IMAGE_DIR, filename)
        absolute_save_path = os.path.abspath(save_path)

        try:
            image.save(save_path, "PNG")

            if not os.path.exists(save_path):
                print(f"CRITICAL ERROR: File NOT found at {save_path} (Absolute: {absolute_save_path}) immediately after saving!")
                return image, _status_html("Internal Error: Failed to confirm image file save.")

            file_size = os.path.getsize(save_path)
            if file_size < 100:
                print(f"WARNING: Saved file {save_path} is very small ({file_size} bytes). May indicate an issue.")

            download_html = _build_download_html(save_path)

            print(f"--- Generation {key} Done ---")
            return image, download_html

        except OSError as save_err:
            print(f"CRITICAL ERROR: Failed to save image to {save_path} (Absolute: {absolute_save_path}): {save_err}")
            traceback.print_exc()
            details = html.escape(str(save_err))
            return image, _status_html(f"Internal Error: Failed to save image file. Details: {details}")
        except Exception as e:
            print(f"Error during link creation or unexpected save issue: {e}")
            traceback.print_exc()
            return image, _status_html("Internal Error creating download link.")

    # --- Exception Handling for API Call ---
    except requests.exceptions.Timeout:
        print(f"Error: Request timed out after {timeout} seconds.")
        return None, _status_html("Request timed out. The model is taking too long.")
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        error_text, estimated_time = _parse_api_error(e.response)
        print(f"Error: Failed API call. Status: {status_code}, Response: {error_text}")
        return None, _status_html(_describe_http_error(status_code, error_text, estimated_time))
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        traceback.print_exc()
        return None, _status_html(f"An unexpected error occurred: {html.escape(str(e))}")


# CSS to style the app; passed to gr.Blocks(css=...) below.
# Selectors:
#   #app-container              - elem_id of the main gr.Column; caps its width
#                                 at 800px and centers it horizontally.
#   textarea:focus              - background color applied to a focused textarea.
#   #download-link-container p  - paragraphs inside the gr.HTML component with
#                                 that elem_id (the status messages returned by
#                                 query are <p> elements).
css = """
#app-container {
    max-width: 800px;
    margin-left: auto;
    margin-right: auto;
}
textarea:focus {
    background: #0d1117 !important;
}
#download-link-container p {
    margin-top: 10px;
    font-size: 0.9em;
}
"""

# Build the Gradio UI with Blocks.
# Layout, top to bottom: title, prompt box, collapsible advanced settings,
# "Run" button, image output, then an HTML area for status / download link.
# NOTE(review): 'Nymbo/Nymbo_Theme' is presumably a theme resolved by name at
# startup - confirm it is reachable where the app is deployed.
with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as app:
    gr.HTML("<center><h1>FLUX.1-Schnell</h1></center>")

    # elem_id matches the "#app-container" rule in `css`.
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                with gr.Row():
                    # Main prompt; passed to query() as `prompt`.
                    text_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here", lines=2, elem_id="prompt-text-input")

                with gr.Row():
                    # Collapsed by default (open=False).
                    with gr.Accordion("Advanced Settings", open=False):
                        # Pre-filled with a default list of unwanted traits; passed to query() as `negative_prompt`.
                        negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What should not be in the image", value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos", lines=3, elem_id="negative-prompt-text-input")
                        with gr.Row():
                            # Output size; both sliders run from 64 to 1216 in steps of 32.
                            width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
                            height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
                        steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=30, step=1) # Default mirrors query()'s `steps=4`
                        cfg = gr.Slider(label="CFG Scale (guidance_scale)", value=0, minimum=0, maximum=10, step=1) # Sent by query() as "guidance_scale"
                        # There are no strength or sampler sliders because query() takes no such arguments.
                        # -1 is the sentinel query() replaces with a random seed.
                        seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1, info="Set to -1 for random seed")

        with gr.Row():
            text_button = gr.Button("Run", variant='primary', elem_id="gen-button")

        # --- Output Components ---
        with gr.Row():
            # type="pil" matches the PIL image that query() returns.
            image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
        with gr.Row():
             # HTML component to display status messages or the download link;
             # elem_id matches the "#download-link-container p" rule in `css`.
             download_link_display = gr.HTML(elem_id="download-link-container")

        # Bind the button to the query function
        text_button.click(
            query,
            # Order must follow query's positional parameters:
            # (prompt, negative_prompt, steps, cfg_scale, seed, width, height)
            inputs=[text_prompt, negative_prompt, steps, cfg, seed, width, height],
            # query returns (image, html): first to the image, second to the HTML component
            outputs=[image_output, download_link_display]
        )

# Start the Gradio server. The absolute image directory is passed as
# allowed_paths; the download link built in query() points at files stored
# there. NOTE(review): assumes allowed_paths is what lets Gradio serve those
# files - confirm against the Gradio launch() documentation.
print("Starting Gradio app...")
launch_options = {
    "show_api": False,
    "share": False,
    "allowed_paths": [absolute_image_dir],
}
app.launch(**launch_options)