# Spaces: Runtime error
from __future__ import annotations | |
from os import pipe | |
import gradio as gr | |
import numpy as np | |
from model import Model | |
from diffusers import StableDiffusionPipeline | |
import gradio as gr | |
from story_generator import StoryGenerator | |
import torch | |
# Page text rendered at the top of the demo via gr.Markdown.
TITLE = ''
DESCRIPTION = '''# StyleGAN3
This is an unofficial demo for [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
'''

# Generator used by the "Character" tab (model selection and image synthesis).
model = Model()

# Stable Diffusion pipeline, loaded once at import time.
# Half precision and the CUDA device are only used when CUDA is actually
# available; otherwise fall back to float32 on the CPU so that start-up does
# not fail on a host without a GPU.
model_id = "runwayml/stable-diffusion-v1-5"
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=_dtype)
pipe = pipe.to(_device)

# Module-level placeholder; not read by any code visible in this file.
sg = None
# Layout of the demo: one Blocks app with three tabs (Character, City, Story).
# Only components are created here; handlers are attached separately.
with gr.Blocks(css='style.css') as image_gen_block:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        # Character tab: generator controls on the left, image on the right.
        with gr.TabItem('Character'):
            with gr.Row():
                with gr.Column():
                    model_name = gr.Dropdown(
                        list(model.MODEL_NAME_DICT.keys()),
                        value='FemaleHero-256-T',
                        label='Model')
                    # Seed spans the full unsigned 32-bit range.
                    seed = gr.Slider(0,
                                     np.iinfo(np.uint32).max,
                                     step=1,
                                     value=0,
                                     label='Seed')
                    psi = gr.Slider(0,
                                    2,
                                    step=0.05,
                                    value=0.7,
                                    label='Truncation psi')
                    tx = gr.Slider(-1,
                                   1,
                                   step=0.05,
                                   value=0,
                                   label='Translate X')
                    ty = gr.Slider(-1,
                                   1,
                                   step=0.05,
                                   value=0,
                                   label='Translate Y')
                    angle = gr.Slider(-180,
                                      180,
                                      step=5,
                                      value=0,
                                      label='Angle')
                    run_button = gr.Button('Run')
                with gr.Column():
                    result = gr.Image(label='Result', elem_id='result')
        # City generation tab
        with gr.TabItem('City'):
            with gr.Row():
                generate_city_button = gr.Button("Generate City")
            with gr.Row():
                city_output = gr.Image(label="Generated City",
                                       elem_id="city_output")
        # Story tab: API key and prompt inputs, plus an output-only text box.
        with gr.TabItem('Story'):
            with gr.Row():
                api_key_input_sg = gr.Textbox(label="OpenAI API Key")
                prompt_input = gr.Textbox(label='Prompt')
                generate_button = gr.Button('Generate Story')
            with gr.Row():
                # `readonly` is not a Textbox parameter; `interactive=False`
                # is the supported way to make the box non-editable.
                story_output = gr.Textbox(
                    label='Generated Story',
                    placeholder='Click "Generate Story" to see the story',
                    interactive=False)
def generate_story(prompt, api_key):
    """Produce a story for ``prompt`` with a generator built from ``api_key``.

    A new ``StoryGenerator`` is constructed from the supplied key on every
    call; whatever its ``generate_story`` method returns for ``prompt`` is
    passed back unchanged.
    """
    generator = StoryGenerator(api_key)
    story = generator.generate_story(prompt)
    return story
def generate_city():
    """Render one image for the fixed city prompt and return it.

    The module-level ``pipe`` that is loaded once at start-up is reused here;
    building a fresh pipeline and moving it to the device on every click only
    repeats that work for an identical model.

    Returns:
        The first entry of ``images`` from the pipeline result.
    """
    city_prompt = "A metropolis city, HD"
    return pipe(city_prompt).images[0]
# Event listeners must be registered while a Blocks context is active, so the
# app's context is entered again for the wiring below.
with image_gen_block:
    # City tab: no inputs, the generated image goes to the city output.
    generate_city_button.click(fn=generate_city,
                               inputs=None,
                               outputs=city_output)

    # Character tab: switch the model when the dropdown changes, and
    # generate an image from the current control values on "Run".
    model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
    run_button.click(fn=model.generate_image,
                     inputs=[
                         model_name,
                         seed,
                         psi,
                         tx,
                         ty,
                         angle,
                     ],
                     outputs=result)

    # Story tab: argument order matches generate_story(prompt, api_key).
    generate_button.click(fn=generate_story,
                          inputs=[prompt_input, api_key_input_sg],
                          outputs=story_output)

# Enable request queuing and start the server.
image_gen_block.queue().launch(show_api=False)