Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
import prediction | |
import model | |
import diffusion_loss | |
# Select the GPU when PyTorch reports one is available, otherwise the CPU.
# NOTE(review): ``device`` is not referenced anywhere in the code visible here;
# presumably it is kept for other modules or future use — confirm.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Build the diffusion pipeline once, at import time, so every call to
# ``generate`` reuses the same pipeline object.
pipe = model.initialize_diffusion_model()
def generate(prompt, loss_function=None):
    """Generate an output for *prompt* with the shared module-level pipeline.

    Args:
        prompt: Text prompt forwarded to ``prediction.predict``.
        loss_function: Optional loss callable forwarded unchanged; ``None``
            means no loss is applied.

    Returns:
        Whatever ``prediction.predict`` returns.
    """
    predict_kwargs = dict(prompt=prompt, pipe=pipe, loss_function=loss_function)
    return prediction.predict(**predict_kwargs)
# Dropdown label -> attribute name of the matching loss in ``diffusion_loss``.
# Any label not listed here (including "No Loss" and None) means "no loss".
LOSS_FUNCTION_NAMES = {
    "Blue Channel": "blue_channel",
    "Saturation": "saturation",
    "Elastic Deformation": "elastic_transform",
}


def resolve_loss_function(choice):
    """Return the ``diffusion_loss`` callable selected by dropdown *choice*.

    Args:
        choice: The dropdown label, or None when nothing is selected.

    Returns:
        The loss callable, or None for ``None``, ``"No Loss"``, or any
        unrecognised label (generation then runs without a loss).
    """
    attr_name = LOSS_FUNCTION_NAMES.get(choice)
    if attr_name is None:
        return None
    # Resolved at call time rather than import time, like the original
    # if/elif chain, so a missing attribute only fails when it is selected.
    return getattr(diffusion_loss, attr_name)


def process_input(prompt, loss_function, button):
    """Gradio callback: map the UI inputs onto a call to ``generate``.

    Args:
        prompt: Text prompt from the textbox.
        loss_function: Label chosen in the loss dropdown, or None.
        button: Value delivered by the button input. Any falsy value skips
            generation. NOTE(review): a Gradio Button's value is presumably
            its label, so in practice this is always truthy — confirm.

    Returns:
        The result of ``generate``, or None when *button* is falsy.
    """
    if not button:
        return None
    return generate(prompt, loss_function=resolve_loss_function(loss_function))
# Gradio UI: prompt text, loss choice and button are passed positionally to
# ``process_input``; its return value is shown as a PIL image.
iface = gr.Interface(
    fn=process_input,
    inputs=[
        # The first positional argument of Textbox is ``value`` (initial
        # content), which pre-filled the box with the literal word "prompt".
        # As a placeholder, the hint is shown without ever being submitted.
        gr.Textbox(label="Enter Prompt", placeholder="prompt"),
        # Default to "No Loss" so the dropdown starts with an explicit choice
        # instead of submitting None.
        gr.Dropdown(
            ["No Loss", "Blue Channel", "Saturation", "Elastic Deformation"],
            value="No Loss",
            label="Choose Augmentation",
        ),
        # Delivered to process_input as its ``button`` argument.
        gr.Button("Loss Function"),
    ],
    outputs=gr.Image(type="pil"),
)
if __name__ == "__main__":
    # Start the Gradio server when run as a script. The API page is hidden,
    # and a public share link is requested.
    iface.launch(share=True, show_api=False)