Commit · b55a663
1 Parent(s): 7cd8c59
added model along with final app
- .gitmodules +5 -0
- __pycache__/model.cpython-310.pyc +0 -0
- __pycache__/story_generator.cpython-310.pyc +0 -0
- app.py +116 -0
- model.py +80 -0
- network-snapshot-000120.pkl +3 -0
- requirements.txt +0 -0
- story_generator.py +20 -0
- style.css +12 -0
.gitmodules
ADDED
@@ -0,0 +1,5 @@
+[submodule "stylegan3"]
+	path = stylegan3
+	url = https://github.com/NVlabs/stylegan3
+
+
__pycache__/model.cpython-310.pyc
ADDED
Binary file (3.73 kB)
__pycache__/story_generator.cpython-310.pyc
ADDED
Binary file (871 Bytes)
app.py
ADDED
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import gradio as gr
+import numpy as np
+import torch
+from diffusers import StableDiffusionPipeline
+
+from model import Model
+from story_generator import StoryGenerator
+
+TITLE = ''
+DESCRIPTION = '''# StyleGAN3
+This is an unofficial demo for [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
+'''
+
+# The StyleGAN3 character generator and the Stable Diffusion pipeline are loaded once at startup.
+model = Model()
+model_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+pipe = pipe.to("cuda")  # Remove or comment out this line if using CPU
+
+with gr.Blocks(css='style.css') as image_gen_block:
+    gr.Markdown(DESCRIPTION)
+
+    with gr.Tabs():
+        # Character generation tab (StyleGAN3)
+        with gr.TabItem('Character'):
+            with gr.Row():
+                with gr.Column():
+                    model_name = gr.Dropdown(list(model.MODEL_NAME_DICT.keys()),
+                                             value='FemaleHero-256-T',
+                                             label='Model')
+                    seed = gr.Slider(0,
+                                     np.iinfo(np.uint32).max,
+                                     step=1,
+                                     value=0,
+                                     label='Seed')
+                    psi = gr.Slider(0,
+                                    2,
+                                    step=0.05,
+                                    value=0.7,
+                                    label='Truncation psi')
+                    tx = gr.Slider(-1,
+                                   1,
+                                   step=0.05,
+                                   value=0,
+                                   label='Translate X')
+                    ty = gr.Slider(-1,
+                                   1,
+                                   step=0.05,
+                                   value=0,
+                                   label='Translate Y')
+                    angle = gr.Slider(-180,
+                                      180,
+                                      step=5,
+                                      value=0,
+                                      label='Angle')
+                    run_button = gr.Button('Run')
+                with gr.Column():
+                    result = gr.Image(label='Result', elem_id='result')
+
+        # City generation tab (Stable Diffusion)
+        with gr.TabItem('City'):
+            with gr.Row():
+                generate_city_button = gr.Button("Generate City")
+            with gr.Row():
+                city_output = gr.Image(label="Generated City", elem_id="city_output")
+
+        # Story generation tab (OpenAI completion API)
+        with gr.TabItem('Story'):
+            with gr.Row():
+                api_key_input_sg = gr.Textbox(label="OpenAI API Key")
+                prompt_input = gr.Textbox(label='Prompt')
+                generate_button = gr.Button('Generate Story')
+
+            with gr.Row():
+                story_output = gr.Textbox(label='Generated Story',
+                                          placeholder='Click "Generate Story" to see the story',
+                                          interactive=False)
+
+    def generate_story(prompt, api_key):
+        sg = StoryGenerator(api_key)
+        return sg.generate_story(prompt)
+
+    def generate_city():
+        # Reuse the pipeline loaded at startup instead of reloading it on every click.
+        city_prompt = "A metropolis city, HD"
+        city_image = pipe(city_prompt).images[0]
+        return city_image
+
+    generate_city_button.click(fn=generate_city, inputs=None, outputs=city_output)
+    model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
+    run_button.click(fn=model.generate_image,
+                     inputs=[
+                         model_name,
+                         seed,
+                         psi,
+                         tx,
+                         ty,
+                         angle,
+                     ],
+                     outputs=result)
+    generate_button.click(fn=generate_story, inputs=[prompt_input, api_key_input_sg], outputs=story_output)
+
+image_gen_block.queue().launch(show_api=False)
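The pipeline above is moved to CUDA unconditionally, so the City tab only works on GPU hardware (the inline comment in app.py already notes this). A minimal device-aware loading sketch, assuming the same runwayml/stable-diffusion-v1-5 checkpoint and the diffusers API used in app.py:

import torch
from diffusers import StableDiffusionPipeline

# Pick the device at runtime; fp16 weights are only useful on GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype)
pipe = pipe.to(device)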
model.py
ADDED
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import pathlib
+import pickle
+import sys
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+# Make the stylegan3 submodule importable; the pickled generator depends on its custom ops.
+current_dir = pathlib.Path(__file__).parent
+submodule_dir = current_dir / 'stylegan3'
+sys.path.insert(0, submodule_dir.as_posix())
+
+
+class Model:
+    MODEL_NAME_DICT = {'FemaleHero-256-T': 'network-snapshot-000120.pkl'}
+
+    def __init__(self):
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        self._download_all_models()
+        self.model_name = 'FemaleHero-256-T'
+        self.model = self._load_model(self.model_name)
+
+    def _load_model(self, model_name: str) -> nn.Module:
+        # Load the exponential-moving-average generator from the StyleGAN3 pickle.
+        file_name = self.MODEL_NAME_DICT[model_name]
+        with open(file_name, 'rb') as f:
+            model = pickle.load(f)['G_ema']
+        model.eval()
+        model.to(self.device)
+        return model
+
+    def set_model(self, model_name: str) -> None:
+        if model_name == self.model_name:
+            return
+        self.model_name = model_name
+        self.model = self._load_model(model_name)
+
+    def _download_all_models(self):
+        # The checkpoint ships with the repository, so nothing needs to be downloaded.
+        pass
+
+    @staticmethod
+    def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
+        # 3x3 affine matrix: rotation by `angle` degrees plus translation by `translate`.
+        mat = np.eye(3)
+        sin = np.sin(angle / 360 * np.pi * 2)
+        cos = np.cos(angle / 360 * np.pi * 2)
+        mat[0][0] = cos
+        mat[0][1] = sin
+        mat[0][2] = translate[0]
+        mat[1][0] = -sin
+        mat[1][1] = cos
+        mat[1][2] = translate[1]
+        return mat
+
+    def generate_z(self, seed: int) -> torch.Tensor:
+        seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
+        z = np.random.RandomState(seed).randn(1, self.model.z_dim)
+        return torch.from_numpy(z).float().to(self.device)
+
+    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
+        # Convert NCHW float output in [-1, 1] to NHWC uint8 images.
+        tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+        return tensor.cpu().numpy()
+
+    def set_transform(self, tx: float, ty: float, angle: float) -> None:
+        mat = self.make_transform((tx, ty), angle)
+        mat = np.linalg.inv(mat)
+        self.model.synthesis.input.transform.copy_(torch.from_numpy(mat))
+
+    @torch.inference_mode()
+    def generate(self, z: torch.Tensor, label: torch.Tensor, truncation_psi: float) -> torch.Tensor:
+        return self.model(z, label, truncation_psi=truncation_psi)
+
+    def generate_image(self, model_name: str, seed: int, truncation_psi: float, tx: float, ty: float, angle: float) -> np.ndarray:
+        self.set_model(model_name)
+        self.set_transform(tx, ty, angle)
+        z = self.generate_z(seed)
+        label = torch.zeros([1, self.model.c_dim], device=self.device)
+        out = self.generate(z, label, truncation_psi=truncation_psi)
+        out = self.postprocess(out)
+        return out[0]
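For reference, a minimal standalone usage sketch of the Model class defined above; the argument values are illustrative, not taken from the app:

from model import Model

m = Model()  # loads network-snapshot-000120.pkl onto GPU if available, else CPU
# Returns an HxWx3 uint8 numpy array from the 'FemaleHero-256-T' generator.
img = m.generate_image('FemaleHero-256-T', seed=42, truncation_psi=0.7, tx=0.0, ty=0.0, angle=0.0)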
network-snapshot-000120.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:494b89ca839757d1c76a9a1b5eb0be5ce01e0d66eee8a55e1563deba9734c7a6
+size 343475301
requirements.txt
ADDED
Binary file (7.59 kB). View file
story_generator.py
ADDED
@@ -0,0 +1,20 @@
+import openai
+
+
+class StoryGenerator:
+    def __init__(self, api_key):
+        openai.api_key = api_key
+        # Legacy OpenAI completions model used for story generation.
+        self.completion_model = "text-davinci-002"
+
+    def generate_story(self, prompt):
+        prompt = f"Once upon a time, {prompt}"
+        response = openai.Completion.create(
+            engine=self.completion_model,
+            prompt=prompt,
+            max_tokens=1024,
+            n=1,
+            stop=None,
+            temperature=0.5,
+        )
+        return response.choices[0].text
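A minimal usage sketch of StoryGenerator; the API key is a placeholder, and since the class calls the legacy openai.Completion endpoint it assumes an openai package version that still exposes it (e.g. openai<1.0):

from story_generator import StoryGenerator

sg = StoryGenerator(api_key="sk-...")  # placeholder key
print(sg.generate_story("a wandering knight found a city in the clouds"))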
style.css
ADDED
@@ -0,0 +1,12 @@
+h1 {
+  text-align: center;
+}
+div#result {
+  max-width: 600px;
+  max-height: 600px;
+}
+img#visitor-badge {
+  display: block;
+  margin: auto;
+}
+