# Spaces: Runtime error  (page-status residue from the scraped Hugging Face Space page; not part of the program)
import gradio as gr | |
import asyncio | |
import os | |
from random import randint | |
from threading import RLock | |
from all_models import models | |
from externalmod import gr_Interface_load, randomize_seed | |
# Lock to prevent concurrent access issues.
# NOTE(review): nothing in the visible code acquires this lock — confirm
# whether it is still needed or used from another module.
lock = RLock()

# Hugging Face token for private models; None when the HF_TOKEN environment
# variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN", None)
# Load models
def load_models(models):
    """Populate the module-level ``models_load`` dict from ``models``.

    Each name is loaded through ``gr_Interface_load`` as ``models/<name>``
    with ``HF_TOKEN``. A name that fails to load is reported on stdout and
    stored as ``None`` so callers can tell "known but unavailable" apart from
    "never requested". Every call rebuilds ``models_load`` from scratch.

    Args:
        models: Iterable of model identifiers; duplicates are loaded once.
    """
    global models_load
    models_load = {}
    for model in models:
        if model in models_load:
            # Already handled earlier in this pass; skip the duplicate.
            continue
        try:
            models_load[model] = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
        except Exception as error:
            print(f"Error loading model {model}: {error}")
            models_load[model] = None  # Handle failed model load
# Load every model listed in all_models at import time.
load_models(models)

# Constants
NUM_MODELS = 9  # number of models taken from the head of `models`
DEFAULT_MODELS = models[:NUM_MODELS]
INFERENCE_TIMEOUT = 600  # default `timeout` handed to asyncio.wait_for (seconds)
MAX_SEED = 666666666  # upper bound for the seed slider and the random start seed
# Initial slider value: a random seed in [666666000, MAX_SEED], both inclusive.
starting_seed = randint(666666000, MAX_SEED)
# Async inference function
async def infer(model_str, prompt, seed=1, timeout=INFERENCE_TIMEOUT):
    """Run one model on ``prompt`` and return its result or an error image URL.

    The model's ``fn`` is executed in a worker thread via ``asyncio.to_thread``
    and awaited with ``asyncio.wait_for``. The timeout only stops the wait;
    a thread started by ``to_thread`` is not interrupted when it expires.

    Args:
        model_str: Key into ``models_load``.
        prompt: Text prompt forwarded to the model as ``prompt=``.
        seed: Seed forwarded to the model as ``seed=``.
        timeout: Maximum time to wait for the model, in seconds.

    Returns:
        The model's result when it is truthy; otherwise the fallback error
        image URL (unknown model, failed load, timeout, exception, or an
        empty result).
    """
    error_image = "https://huggingface.co/spaces/Yntec/ToyWorld/resolve/main/error.png"

    # A missing key and a model whose load failed (stored as None) are both
    # treated as "unavailable".
    model = models_load.get(model_str)
    if model is None:
        return error_image

    try:
        result = await asyncio.wait_for(
            asyncio.to_thread(model.fn, prompt=prompt, seed=seed, token=HF_TOKEN),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        print(f"Timeout error: {model_str}")
    except Exception as e:
        print(f"Error in inference: {e}")
    else:
        if result:
            return result
    return error_image
# Synchronous wrapper
def generate_image(model_str, prompt, seed=1):
    """Blocking entry point that runs ``infer`` to completion.

    Args:
        model_str: Model identifier; the sentinel ``'NA'`` means "no model".
        prompt: Text prompt passed through to ``infer``.
        seed: Seed passed through to ``infer``.

    Returns:
        ``None`` for the ``'NA'`` sentinel, otherwise whatever ``infer``
        returns, or the fallback error image URL if running it raises.
    """
    if model_str == 'NA':
        return None

    # Use a private loop for this call only. It is deliberately NOT installed
    # with asyncio.set_event_loop(): run_until_complete() does not need that,
    # and installing it would leave a closed loop registered as the calling
    # thread's current loop once we are done.
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(infer(model_str, prompt, seed))
    except Exception as e:
        print(f"Error generating image: {e}")
        return "https://huggingface.co/spaces/Yntec/ToyWorld/resolve/main/error.png"
    finally:
        loop.close()
# Gradio UI
with gr.Blocks(theme='Yntec/HaleyCH_Theme_craiyon') as demo:
    gr.HTML("""<center><img src='https://huggingface.co/spaces/Yntec/open-craiyon/resolve/main/open_craiyon.png' height='79'></center>""")
    with gr.Tab('🖍️ AI Image Generator 🖍️'):
        txt_input = gr.Textbox(label='Enter your prompt:', lines=4)
        gen_button = gr.Button('Generate Image 🖍️')
        seed_slider = gr.Slider(label="Seed (for reproducibility)", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed)
        random_seed_btn = gr.Button("Randomize Seed 🎲")
        random_seed_btn.click(randomize_seed, None, [seed_slider])

        # One image slot and one hidden textbox (carrying the model name)
        # per default model; each pair gets its own click handler so every
        # model is queried independently with the same prompt and seed.
        output_images = [gr.Image(label=m) for m in DEFAULT_MODELS]
        model_inputs = [gr.Textbox(m, visible=False) for m in DEFAULT_MODELS]
        for model, img_output in zip(model_inputs, output_images):
            gen_button.click(generate_image, inputs=[model, txt_input, seed_slider], outputs=[img_output])

        # `with` is required here: a bare `gr.Accordion(...):` is a SyntaxError.
        with gr.Accordion("Model Selection", open=False):
            model_choice = gr.CheckboxGroup(models, label="Select Models", value=DEFAULT_MODELS)
            # The handler must return exactly one update per output component.
            # Image slots exist only for DEFAULT_MODELS, so iterate over those
            # (not all of `models`) and pass the flat component list as outputs.
            model_choice.change(
                lambda selected: [gr.Image(visible=m in selected) for m in DEFAULT_MODELS],
                inputs=[model_choice],
                outputs=output_images,
            )
    gr.HTML("""<p>Check out more models at <a href='https://huggingface.co/spaces/Yntec/ToyWorld'>Toy World</a>!</p>""")

demo.queue(default_concurrency_limit=200, max_size=200)
demo.launch(show_api=False, max_threads=400)