Fictionista / app.py
teamnassim's picture
attempt to use stylegan3
689dfd2
raw
history blame
4.21 kB
from __future__ import annotations
from os import pipe
import gradio as gr
import numpy as np
from model import Model
from diffusers import StableDiffusionPipeline
import gradio as gr
from story_generator import StoryGenerator
import torch
TITLE = ''
DESCRIPTION = '''# StyleGAN3
This is an unofficial demo for [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
'''

# StyleGAN3 wrapper used by the "Character" tab (model dropdown + Run button).
model = Model()

# Shared text-to-image Stable Diffusion pipeline, loaded once at import time.
model_id = "runwayml/stable-diffusion-v1-5"
# Choose the device automatically rather than hard-coding CUDA, so the app can
# also start on a CPU-only machine without editing the source.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 halves GPU memory use; PyTorch CPU inference generally expects
# float32, so only request half precision when a GPU is available.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# NOTE(review): this global is not read anywhere in this file; generate_story
# builds its own StoryGenerator per request. Kept in case other code imports it.
sg = None
with gr.Blocks(css='style.css') as image_gen_block:
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        # StyleGAN3 character generation tab
        with gr.TabItem('Character'):
            with gr.Row():
                with gr.Column():
                    model_name = gr.Dropdown(list(model.MODEL_NAME_DICT.keys()),
                                             value='FemaleHero-256-T',
                                             label='Model')
                    seed = gr.Slider(0,
                                     np.iinfo(np.uint32).max,
                                     step=1,
                                     value=0,
                                     label='Seed')
                    psi = gr.Slider(0,
                                    2,
                                    step=0.05,
                                    value=0.7,
                                    label='Truncation psi')
                    tx = gr.Slider(-1,
                                   1,
                                   step=0.05,
                                   value=0,
                                   label='Translate X')
                    ty = gr.Slider(-1,
                                   1,
                                   step=0.05,
                                   value=0,
                                   label='Translate Y')
                    angle = gr.Slider(-180,
                                      180,
                                      step=5,
                                      value=0,
                                      label='Angle')
                    run_button = gr.Button('Run')
                with gr.Column():
                    result = gr.Image(label='Result', elem_id='result')

        # City generation tab
        with gr.TabItem('City'):
            with gr.Row():
                generate_city_button = gr.Button("Generate City")
            with gr.Row():
                city_output = gr.Image(label="Generated City", elem_id="city_output")

        # Story generation tab
        with gr.TabItem('Story'):
            with gr.Row():
                # type="password" masks the secret key in the browser instead
                # of echoing it in clear text.
                api_key_input_sg = gr.Textbox(label="OpenAI API Key", type="password")
                prompt_input = gr.Textbox(label='Prompt')
                generate_button = gr.Button('Generate Story')
            with gr.Row():
                # `readonly` is not a Textbox option; `interactive=False` is
                # what actually makes the output box non-editable.
                story_output = gr.Textbox(label='Generated Story',
                                          placeholder='Click "Generate Story" to see the story',
                                          interactive=False)

    def generate_story(prompt, api_key):
        """Generate a story for ``prompt`` using the visitor's OpenAI key.

        A new StoryGenerator is created per request because each visitor
        supplies their own API key through the UI.

        Raises:
            gr.Error: if no API key was entered, so the user sees a clear
                message in the UI instead of a downstream failure.
        """
        if not api_key or not api_key.strip():
            raise gr.Error("Please enter your OpenAI API key.")
        story_generator = StoryGenerator(api_key)
        return story_generator.generate_story(prompt)

    def generate_city(prompt: str = "A metropolis city, HD"):
        """Render a city image with the shared module-level ``pipe``.

        Reusing the pipeline loaded at startup avoids re-reading all model
        weights and allocating another copy on the GPU for every click.

        Args:
            prompt: text prompt for Stable Diffusion; the button wiring passes
                no inputs, so the default prompt is used from the UI.

        Returns:
            The first generated image.
        """
        return pipe(prompt).images[0]

    generate_city_button.click(fn=generate_city, inputs=None, outputs=city_output)
    # Switching the dropdown reloads the StyleGAN3 model on the server side.
    model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
    run_button.click(fn=model.generate_image,
                     inputs=[
                         model_name,
                         seed,
                         psi,
                         tx,
                         ty,
                         angle,
                     ],
                     outputs=result)
    generate_button.click(fn=generate_story,
                          inputs=[prompt_input, api_key_input_sg],
                          outputs=story_output)
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module for reuse or testing does not launch the UI.
if __name__ == "__main__":
    image_gen_block.queue().launch(show_api=False)