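"""Gradio app for comparing text-to-image models.

Each model listed in all_models.models is loaded through
externalmod.gr_Interface_load and exposed in a Blocks UI where a shared
prompt and seed drive per-model image generation.
"""
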
import gradio as gr
import torch
import asyncio
import os
import uuid
from random import randint
from threading import RLock
from pathlib import Path
from all_models import models
from externalmod import gr_Interface_load, randomize_seed

# Create a lock for thread safety
lock = RLock()

# Load Hugging Face token from environment variable
HF_TOKEN = os.getenv("HF_TOKEN")
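# If HF_TOKEN is unset, requests fall back to anonymous access, which may be rate-limited.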

# Function to load models with optimized settings
def load_fn(models):
    global models_load
    models_load = {}

    for model in models:
        if model not in models_load:
            try:
                print(f"Loading model: {model}")
                m = gr_Interface_load(
                    f'models/{model}',
                    hf_token=HF_TOKEN,
                    torch_dtype=torch.float16  # Reduce memory usage (may be ignored for remote inference)
                )
                # gr_Interface_load can return a remote-inference wrapper rather
                # than a local pipeline, so only offload when the method exists.
                if hasattr(m, "enable_model_cpu_offload"):
                    m.enable_model_cpu_offload()  # Offload to CPU when not in use
                models_load[model] = m
            except Exception as e:
                print(f"Error loading model {model}: {e}")
                models_load[model] = None

print("Loading models...")
load_fn(models)
print("Models loaded successfully.")

# Constants
num_models = 1  # Number of model slots shown in the UI
starting_seed = randint(1941, 2024)
MAX_SEED = 3999999999
inference_timeout = 600  # Seconds; enforced per call in infer()

# Update UI components
def extend_choices(choices):
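    """Pad (with 'NA') or truncate the selection to exactly num_models entries."""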
    return choices[:num_models] + ['NA'] * (num_models - len(choices))

def update_imgbox(choices):
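    """Build one gr.Image per slot; 'NA' placeholder slots are hidden."""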
    choices_extended = extend_choices(choices)
    return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_extended]

# Async inference function
async def infer(model_str, prompt, seed=1):
    if model_str not in models_load or models_load[model_str] is None:
        print(f"Model {model_str} is unavailable.")
        return None

    kwargs = {"seed": seed}
    try:
        print(f"Running inference for model: {model_str} with prompt: '{prompt}'")
        result = await asyncio.wait_for(
            asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN),
            timeout=inference_timeout,
        )

        if result:
            with lock:
                # Unique filename so concurrent requests don't overwrite each other.
                png_path = f"image_{uuid.uuid4().hex}.png"
                result.save(png_path)
                return str(Path(png_path).resolve())
    except asyncio.TimeoutError:
        print(f"Inference for {model_str} timed out after {inference_timeout}s.")
    except torch.cuda.OutOfMemoryError:
        print(f"CUDA memory error for {model_str}. Try reducing image size.")
    except Exception as e:
        print(f"Error during inference for {model_str}: {e}")

    return None

# Synchronous wrapper so Gradio worker threads can drive the async infer()
def gen_fnseed(model_str, prompt, seed=1):
    if model_str == 'NA':
        return None

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        result = loop.run_until_complete(infer(model_str, prompt, seed))
    except Exception as e:
        print(f"Error generating image for {model_str}: {e}")
        result = None
    finally:
        loop.close()

    return result

# Gradio UI
print("Creating Gradio interface...")
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    gr.HTML("<center><h1>Compare-6</h1></center>")

    with gr.Tab('Compare-6'):
        txt_input = gr.Textbox(label='Your prompt:', lines=4)
        gen_button = gr.Button(f'Generate up to {num_models} images')
        seed = gr.Slider(label="Seed (0 to MAX)", minimum=0, maximum=MAX_SEED, value=starting_seed)
        seed_rand = gr.Button("Randomize Seed 🎲")

        seed_rand.click(randomize_seed, None, [seed], queue=False)

        output = [gr.Image(label=m) for m in models[:num_models]]
        current_models = [gr.Textbox(m, visible=False) for m in models[:num_models]]

        # One click handler per (model, image) pair so each result appears as
        # soon as its own inference finishes.
        for m, o in zip(current_models, output):
            gen_button.click(gen_fnseed, inputs=[m, txt_input, seed], outputs=[o], queue=False)

        with gr.Accordion('Model selection'):
            model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models')
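            # Changing the selection updates both the visible image slots and
            # the hidden textboxes that feed model names into gen_fnseed.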
            model_choice.change(update_imgbox, model_choice, output)
            model_choice.change(extend_choices, model_choice, current_models)

demo.queue(default_concurrency_limit=20, max_size=50)  # Up to 20 concurrent jobs, at most 50 queued
demo.launch(show_api=False)