# (removed non-Python scrape residue that preceded the source: a byte-count
# header, repeated git commit hashes, and a literal line-number gutter 1-93)
import gradio as gr
import asyncio
import os
from random import randint
from threading import RLock
from all_models import models
from externalmod import gr_Interface_load, randomize_seed

# Lock to prevent concurrent access issues
# NOTE(review): `lock` is not used anywhere in this chunk — presumably it
# guards shared state elsewhere in the project; confirm before removing.
lock = RLock()
HF_TOKEN = os.getenv("HF_TOKEN", None)  # Hugging Face token for private models

# Load models
# Load models
def load_models(models):
    """Populate the global ``models_load`` registry from *models*.

    Each model name maps to its loaded Gradio interface, or to ``None``
    when loading fails, so later lookups can distinguish "failed" from
    "never attempted".
    """
    global models_load
    models_load = {}

    for model in models:
        if model in models_load:
            continue  # duplicate entry in the list — already handled
        try:
            models_load[model] = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
        except Exception as error:
            # Record the failure instead of raising so one bad model
            # does not prevent the rest of the app from starting.
            print(f"Error loading model {model}: {error}")
            models_load[model] = None

load_models(models)  # populate the global models_load registry at import time

# Constants
NUM_MODELS = 9  # number of image slots shown in the UI
DEFAULT_MODELS = models[:NUM_MODELS]  # first N models are rendered by default
INFERENCE_TIMEOUT = 600  # seconds before a single inference is abandoned
MAX_SEED = 666666666
starting_seed = randint(666666000, MAX_SEED)  # fresh seed each app start

# Async inference function
async def infer(model_str, prompt, seed=1, timeout=INFERENCE_TIMEOUT):
    """Run one model's inference off-thread with a timeout.

    Returns the model's result, or a fallback error-image URL when the
    model is unavailable, the call times out, fails, or yields nothing.
    """
    fallback = "https://huggingface.co/spaces/Yntec/ToyWorld/resolve/main/error.png"

    # Missing key and failed-to-load (None) are treated the same way.
    model = models_load.get(model_str)
    if model is None:
        return fallback

    try:
        result = await asyncio.wait_for(
            asyncio.to_thread(model.fn, prompt=prompt, seed=seed, token=HF_TOKEN),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        print(f"Timeout error: {model_str}")
    except Exception as e:
        print(f"Error in inference: {e}")
    else:
        if result:
            return result
    return fallback

# Synchronous wrapper
def generate_image(model_str, prompt, seed=1):
    """Blocking wrapper around :func:`infer` for use as a Gradio callback.

    Returns ``None`` for the ``'NA'`` placeholder model; otherwise the
    image produced by ``infer`` or a fallback error-image URL on failure.
    """
    if model_str == 'NA':
        return None

    try:
        # asyncio.run creates, runs, and properly closes a fresh event loop.
        # The previous new_event_loop()/set_event_loop() pair installed the
        # loop as the thread's current loop and never restored it, leaving a
        # *closed* loop registered for the thread after loop.close().
        return asyncio.run(infer(model_str, prompt, seed))
    except Exception as e:
        print(f"Error generating image: {e}")
        return "https://huggingface.co/spaces/Yntec/ToyWorld/resolve/main/error.png"

# Gradio UI
with gr.Blocks(theme='Yntec/HaleyCH_Theme_craiyon') as demo:
    gr.HTML("""<center><img src='https://huggingface.co/spaces/Yntec/open-craiyon/resolve/main/open_craiyon.png' height='79'></center>""")
    
    with gr.Tab('🖍️ AI Image Generator 🖍️'):
        txt_input = gr.Textbox(label='Enter your prompt:', lines=4)
        gen_button = gr.Button('Generate Image 🖍️')
        
        seed_slider = gr.Slider(label="Seed (for reproducibility)", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed)
        random_seed_btn = gr.Button("Randomize Seed 🎲")
        
        random_seed_btn.click(randomize_seed, None, [seed_slider])
        
        output_images = [gr.Image(label=m) for m in DEFAULT_MODELS]
        model_inputs = [gr.Textbox(m, visible=False) for m in DEFAULT_MODELS]
        
        # One click handler per model so every image generates independently.
        for model, img_output in zip(model_inputs, output_images):
            gen_button.click(generate_image, inputs=[model, txt_input, seed_slider], outputs=[img_output])
        
        # BUG FIX: the original line was missing the `with` keyword, which is
        # a SyntaxError (`gr.Accordion(...):` is not a valid statement).
        with gr.Accordion("Model Selection", open=False):
            model_choice = gr.CheckboxGroup(models, label="Select Models", value=DEFAULT_MODELS)
            # BUG FIX: `outputs` must be the list of components, not a list
            # nested inside another list, and the number of returned updates
            # must match the number of outputs — so iterate DEFAULT_MODELS
            # (the models that actually have an Image slot), not all models.
            model_choice.change(
                lambda selected: [gr.Image(visible=(m in selected)) for m in DEFAULT_MODELS],
                inputs=[model_choice],
                outputs=output_images,
            )
    
    gr.HTML("""<p>Check out more models at <a href='https://huggingface.co/spaces/Yntec/ToyWorld'>Toy World</a>!</p>""")

# Allow many queued/concurrent generations before rejecting new requests.
demo.queue(default_concurrency_limit=200, max_size=200)
# Hide the auto-generated API docs; large thread pool for blocking handlers.
demo.launch(show_api=False, max_threads=400)