teamnassim commited on
Commit
b55a663
Β·
1 Parent(s): 7cd8c59

added model along with final app

Browse files
.gitmodules ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [submodule "stylegan3"]
2
+ path = stylegan3
3
+ url = https://github.com/NVlabs/stylegan3
4
+
5
+
__pycache__/model.cpython-310.pyc ADDED
Binary file (3.73 kB). View file
 
__pycache__/story_generator.cpython-310.pyc ADDED
Binary file (871 Bytes). View file
 
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from os import pipe
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ from model import Model
7
+ from typing import Tuple
8
+ from diffusers import StableDiffusionPipeline
9
+ import gradio as gr
10
+ from story_generator import StoryGenerator
11
+ import torch
12
+
13
+ TITLE = ''
14
+ DESCRIPTION = '''# StyleGAN3
15
+ This is an unofficial demo for [https://github.com/NVlabs/stylegan3](https://github.com/NVlabs/stylegan3).
16
+ '''
17
+
18
+
19
+ model = Model()
20
+ model_id = "runwayml/stable-diffusion-v1-5"
21
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
22
+ pipe = pipe.to("cuda") # Remove or comment out this line if using CPU
23
+
24
+ sg = None
25
+
26
+ with gr.Blocks(css='style.css') as image_gen_block:
27
+ gr.Markdown(DESCRIPTION)
28
+
29
+ with gr.Tabs():
30
+ with gr.TabItem('Character'):
31
+ with gr.Row():
32
+ with gr.Column():
33
+ model_name = gr.Dropdown(list(model.MODEL_NAME_DICT.keys()),
34
+ value='FemaleHero-256-T',
35
+ label='Model')
36
+ seed = gr.Slider(0,
37
+ np.iinfo(np.uint32).max,
38
+ step=1,
39
+ value=0,
40
+ label='Seed')
41
+ psi = gr.Slider(0,
42
+ 2,
43
+ step=0.05,
44
+ value=0.7,
45
+ label='Truncation psi')
46
+ tx = gr.Slider(-1,
47
+ 1,
48
+ step=0.05,
49
+ value=0,
50
+ label='Translate X')
51
+ ty = gr.Slider(-1,
52
+ 1,
53
+ step=0.05,
54
+ value=0,
55
+ label='Translate Y')
56
+ angle = gr.Slider(-180,
57
+ 180,
58
+ step=5,
59
+ value=0,
60
+ label='Angle')
61
+ run_button = gr.Button('Run')
62
+ with gr.Column():
63
+ result = gr.Image(label='Result', elem_id='result')
64
+
65
+ # City generation tab
66
+ with gr.TabItem('City'):
67
+ with gr.Row():
68
+ generate_city_button = gr.Button("Generate City")
69
+ with gr.Row():
70
+ city_output = gr.Image(label="Generated City", elem_id="city_output")
71
+
72
+ with gr.TabItem('Story'):
73
+ with gr.Row():
74
+ api_key_input_sg = gr.Textbox(label="OpenAI API Key")
75
+ prompt_input = gr.Textbox(label='Prompt')
76
+ generate_button = gr.Button('Generate Story')
77
+
78
+ with gr.Row():
79
+ story_output = gr.Textbox(label='Generated Story',
80
+ placeholder='Click "Generate Story" to see the story',
81
+ readonly=True)
82
+
83
+
84
+
85
+ def generate_story(prompt, api_key):
86
+ sg = StoryGenerator(api_key)
87
+ return sg.generate_story(prompt)
88
+
89
+ def generate_city():
90
+ model_id = "runwayml/stable-diffusion-v1-5"
91
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
92
+ pipe = pipe.to("cuda")
93
+ city_prompt = f"A metropolis city, HD"
94
+ city_image = pipe(city_prompt).images[0]
95
+ return city_image
96
+
97
+ generate_city_button.click(fn=generate_city, inputs=None, outputs=city_output)
98
+ model_name.change(fn=model.set_model, inputs=model_name, outputs=None)
99
+ run_button.click(fn=model.generate_image,
100
+ inputs=[
101
+ model_name,
102
+ seed,
103
+ psi,
104
+ tx,
105
+ ty,
106
+ angle,
107
+ ],
108
+ outputs=result)
109
+ generate_button.click(fn=generate_story, inputs=[prompt_input, api_key_input_sg], outputs=story_output)
110
+
111
+
112
+
113
+
114
+ image_gen_block.queue().launch(show_api=False)
115
+
116
+
model.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import pathlib
5
+ import pickle
6
+ import sys
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ current_dir = pathlib.Path(__file__).parent
13
+ submodule_dir = current_dir / 'stylegan3'
14
+ sys.path.insert(0, submodule_dir.as_posix())
15
+
16
+ class Model:
17
+ MODEL_NAME_DICT = {'FemaleHero-256-T': 'network-snapshot-000120.pkl'}
18
+
19
+ def __init__(self):
20
+ self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
21
+ self._download_all_models()
22
+ self.model_name = 'FemaleHero-256-T'
23
+ self.model = self._load_model(self.model_name)
24
+
25
+ def _load_model(self, model_name: str) -> nn.Module:
26
+ file_name = self.MODEL_NAME_DICT[model_name]
27
+ with open(file_name, 'rb') as f:
28
+ model = pickle.load(f)['G_ema']
29
+ model.eval()
30
+ model.to(self.device)
31
+ return model
32
+
33
+ def set_model(self, model_name: str) -> None:
34
+ if model_name == self.model_name:
35
+ return
36
+ self.model_name = model_name
37
+ self.model = self._load_model(model_name)
38
+
39
+ def _download_all_models(self):
40
+ pass
41
+
42
+ @staticmethod
43
+ def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
44
+ mat = np.eye(3)
45
+ sin = np.sin(angle / 360 * np.pi * 2)
46
+ cos = np.cos(angle / 360 * np.pi * 2)
47
+ mat[0][0] = cos
48
+ mat[0][1] = sin
49
+ mat[0][2] = translate[0]
50
+ mat[1][0] = -sin
51
+ mat[1][1] = cos
52
+ mat[1][2] = translate[1]
53
+ return mat
54
+
55
+ def generate_z(self, seed: int) -> torch.Tensor:
56
+ seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
57
+ z = np.random.RandomState(seed).randn(1, self.model.z_dim)
58
+ return torch.from_numpy(z).float().to(self.device)
59
+
60
+ def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
61
+ tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
62
+ return tensor.cpu().numpy()
63
+
64
+ def set_transform(self, tx: float, ty: float, angle: float) -> None:
65
+ mat = self.make_transform((tx, ty), angle)
66
+ mat = np.linalg.inv(mat)
67
+ self.model.synthesis.input.transform.copy_(torch.from_numpy(mat))
68
+
69
+ @torch.inference_mode()
70
+ def generate(self, z: torch.Tensor, label: torch.Tensor, truncation_psi: float) -> torch.Tensor:
71
+ return self.model(z, label, truncation_psi=truncation_psi)
72
+
73
+ def generate_image(self, model_name: str, seed: int, truncation_psi: float, tx: float, ty: float, angle: float) -> np.ndarray:
74
+ self.set_model(model_name)
75
+ self.set_transform(tx, ty, angle)
76
+ z = self.generate_z(seed)
77
+ label = torch.zeros([1, self.model.c_dim], device=self.device)
78
+ out = self.generate(z, label, truncation_psi=truncation_psi)
79
+ out = self.postprocess(out)
80
+ return out[0]
network-snapshot-000120.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:494b89ca839757d1c76a9a1b5eb0be5ce01e0d66eee8a55e1563deba9734c7a6
3
+ size 343475301
requirements.txt ADDED
Binary file (7.59 kB). View file
 
story_generator.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+
3
+ class StoryGenerator:
4
+ def __init__(self, api_key):
5
+
6
+ openai.api_key = api_key
7
+ self.completion_model = "text-davinci-002"
8
+
9
+
10
+ def generate_story(self, prompt):
11
+ prompt = f"Once upon a time, {prompt}"
12
+ response = openai.Completion.create(
13
+ engine=self.completion_model,
14
+ prompt=prompt,
15
+ max_tokens=1024,
16
+ n=1,
17
+ stop=None,
18
+ temperature=0.5,
19
+ )
20
+ return response.choices[0].text
style.css ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ div#result {
5
+ max-width: 600px;
6
+ max-height: 600px;
7
+ }
8
+ img#visitor-badge {
9
+ display: block;
10
+ margin: auto;
11
+ }
12
+