wsi-generator / app.py
kaveh's picture
added .style(container=False)
3ae3bd3
raw
history blame
1.47 kB
from diffusers import DiffusionPipeline
import gradio as gr
import sys
# Load the pretrained diffusion pipeline once, at import time; generate()
# calls this single shared object for every sample it produces.
generator = DiffusionPipeline.from_pretrained("kaveh/wsi_generator")
class Logger:
    """Stdout replacement that mirrors every write into a log file.

    The instance remembers whatever ``sys.stdout`` was when it was created
    and forwards each write both to that stream and to the log file.
    """

    def __init__(self, filename):
        """Capture the current ``sys.stdout`` and open *filename* for writing.

        Args:
            filename: Path of the log file. It is opened in ``"w"`` mode, so
                any existing content is truncated.
        """
        self.terminal = sys.stdout
        # The handle is intentionally kept open: the object is meant to be
        # used as a long-lived stream, so there is no scope to close it in.
        self.log = open(filename, "w")

    def write(self, message):
        """Write *message* to the original stream and to the log file.

        Args:
            message: Text to emit.

        Returns:
            The number of characters written to the log file, matching the
            return convention of file-like ``write`` methods.
        """
        self.terminal.write(message)
        return self.log.write(message)

    def flush(self):
        """Flush both the original stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        """Always report that this stream is not an interactive terminal."""
        # NOTE(review): presumably so callers that probe isatty() avoid
        # terminal-only output in the log file -- confirm intent.
        return False
# Swap the process-wide stdout for a Logger so that everything printed is
# also written to output.log, the file read_logs() reads back.
sys.stdout = Logger("output.log")
def read_logs():
    """Return the complete current contents of ``output.log``.

    ``sys.stdout`` is flushed first so that, when stdout is being mirrored
    into that file, pending output is on disk before the file is read.
    """
    sys.stdout.flush()
    with open("output.log", "r") as log_file:
        contents = log_file.read()
    return contents
def generate(n_samples=1, progress=gr.Progress()):
    """Generate images by calling the module-level ``generator`` pipeline.

    Args:
        n_samples: Number of images to produce. The value is coerced with
            ``int()`` because Gradio documents slider values as floats and
            ``range()`` raises ``TypeError`` when given a float.
        progress: Gradio progress tracker. It is not referenced in the body.
            NOTE(review): presumably kept in the signature so Gradio attaches
            progress reporting to this handler -- confirm.

    Returns:
        A list holding the first image of each pipeline call, in the order
        the calls were made.
    """
    sample_count = int(n_samples)
    # One pipeline invocation per requested sample; only the first image of
    # each result is kept.
    return [generator().images[0] for _ in range(sample_count)]
# Build the UI: a slider and button, a gallery for the generated images,
# and a textbox that shows the captured log output.
# NOTE(review): the original indentation was lost, so the Row/Column nesting
# below is reconstructed (slider + button in the row, gallery under it in the
# column, log textbox at the Blocks level) -- confirm against the real layout.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            n_s = gr.Slider(1, 4, label='Number of Samples', value=1, step=1.0, show_label=True).style(container=False)
            btn = gr.Button("Generate image").style(full_width=False)
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery").style(columns=[2], rows=[2], object_fit="contain", height="auto", preview=True)
    # Clicking the button runs generate() with the slider value and shows the
    # returned images in the gallery.
    btn.click(generate, n_s, gallery)
    logs = gr.Textbox().style(container=False)
    # Re-run read_logs() every second while the page is open and put its
    # result in the textbox.
    demo.load(read_logs, None, logs, every=1)

# Both the periodic `every=` event and the gr.Progress parameter of
# generate() need the Gradio queue, so enable it explicitly instead of
# relying on the hosting environment to switch it on.
demo.queue()
demo.launch()