Spaces:
Running
Running
karanravindra
committed on
Commit
•
7b9568f
1
Parent(s):
25dbf42
make space
Browse files- .gitignore +2 -0
- app.py +104 -0
- requirements.txt +68 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.venv
|
2 |
+
model
|
app.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import gradio as gr
|
4 |
+
from huggingface_hub import hf_hub_download
|
5 |
+
from namegenerator import Model, ModelConfig
|
6 |
+
|
# Inference-only app: disable autograd globally so no gradient graphs are built.
torch.set_grad_enabled(False)

# Vocabulary: special tokens first, then lowercase a-z.
# "0"/"1" are conditioning tokens fed right after <sos> (see generate_names);
# they appear to select the sex of the generated name — Male=0 / Female=1 per
# the gr.Radio(type="index") wiring below.
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>", "0", "1"]
tokens = special_tokens + list("abcdefghijklmnopqrstuvwxyz")
# Forward and reverse lookup tables between characters and token indices.
char_to_idx = {char: idx for idx, char in enumerate(tokens)}
idx_to_char = {idx: char for idx, char in enumerate(tokens)}

# Fetch pretrained weights from the Hub; lands at ./model/model.pth
# (subfolder is preserved under local_dir).
hf_hub_download(
    "karanravindra/namegenerator", "model.pth", subfolder="model", local_dir="."
)
model = Model(
    ModelConfig(
        vocab_size=len(tokens),
        embedding_dim=48,
        num_layers=6,
        max_length=24,  # not padding to nearest 32 because max length of names is 17 - bump this for `theoretically` better performance
        q_heads=12,
        kv_heads=4,
        m=4,
        tie_weights=False,
    )
)
# weights_only=True avoids unpickling arbitrary objects from the checkpoint.
model.load_state_dict(
    torch.load("model/model.pth", map_location="cpu", weights_only=True)
)
model.eval()
33 |
+
|
34 |
+
|
def decode(encoded_name: list[int], strip_special_tokens: bool = True) -> str:
    """Map a sequence of token indices back to a string.

    When strip_special_tokens is True, <sos>/<eos>/<pad> indices are dropped
    before decoding; every remaining index is looked up in idx_to_char.
    """
    if strip_special_tokens:
        skip = {char_to_idx["<sos>"], char_to_idx["<eos>"], char_to_idx["<pad>"]}
        encoded_name = [idx for idx in encoded_name if idx not in skip]
    return "".join(idx_to_char[idx] for idx in encoded_name)
44 |
+
|
45 |
+
|
def decode_batch(
    encoded_names: torch.Tensor, strip_special_tokens: bool = True
) -> list[str]:
    """Decode each row of a (batch, seq_len) index tensor into a string."""
    decoded = []
    for row in encoded_names:
        decoded.append(decode(row.tolist(), strip_special_tokens))
    return decoded
53 |
+
|
54 |
+
|
def generate_names(n=16, gender=None, temperature=0.6):
    """Sample `n` names autoregressively from the model.

    Args:
        n: number of names to generate.
        gender: None for a half-"0"/half-"1" conditioned batch; otherwise a
            value whose str() is a key of char_to_idx (i.e. 0 or 1).
        temperature: softmax temperature; lower values sample more greedily.

    Returns:
        list[str]: decoded names with special tokens stripped. Each string
        still begins with its gender digit ("0"/"1"); callers drop it.
    """
    model.eval()
    pad_idx = char_to_idx["<pad>"]

    if gender is None:
        # Half the batch conditioned on "0", the rest on "1".
        # Use n - n // 2 for the second half: with odd n, n // 2 + n // 2
        # is n - 1 rows, which would crash the dim=1 cat below (the original
        # code had this bug).
        genders = torch.cat(
            [
                torch.full((n // 2, 1), char_to_idx["0"]),
                torch.full((n - n // 2, 1), char_to_idx["1"]),
            ],
            dim=0,
        )
    else:
        genders = torch.full((n, 1), char_to_idx[str(gender)])

    # Prompt: <sos> followed by the gender-conditioning token.
    start = torch.full((n, 1), char_to_idx["<sos>"])
    generated = torch.cat([start, genders], dim=1)

    # 22 steps keeps total length at 2 prompt tokens + 22 = 24 == max_length.
    for _ in range(22):
        logits = model(generated) / temperature
        token = torch.multinomial(F.softmax(logits[:, -1], dim=1), 1)
        generated = torch.cat([generated, token], dim=1)

        # Stop only once EVERY sequence has emitted <pad>. The original
        # `token.all() == char_to_idx["<pad>"]` evaluated to True as soon as
        # ANY sequence produced <pad> (pad idx is 0, so one pad makes
        # token.all() False and False == 0 is True), truncating the other
        # names mid-generation.
        if (token == pad_idx).all():
            break

    return decode_batch(generated, strip_special_tokens=True)
84 |
+
|
85 |
+
|
def generate_name(gender: str, num_names: int, temperature: float):
    """Gradio callback: generate names and return them one per line.

    Each decoded name starts with its gender-conditioning digit; drop that
    first character and capitalize the remainder before joining.
    """
    raw_names = generate_names(num_names, gender, temperature)
    cleaned = []
    for raw in raw_names:
        cleaned.append(raw[1:].capitalize())
    return "\n".join(cleaned)
91 |
+
|
92 |
+
|
# Gradio UI. The Radio's type="index" delivers Male -> 0 / Female -> 1,
# which generate_name forwards as the gender-conditioning token.
demo = gr.Interface(
    fn=generate_name,
    inputs=gr.Radio(choices=["Male", "Female"], label="Sex", type="index"),
    outputs=gr.TextArea(lines=16, label="Generated Names"),
    additional_inputs=[
        gr.Number(value=16, label="Number of Names"),
        gr.Slider(minimum=0.1, maximum=2, value=0.6, label="Temperature", step=0.1),
    ],
    title="Name Generator",
    description="Generates names based on sex using a GPT-2 model trained on names.",
)
demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
namegenerator @ git+https://github.com/karanravindra/namegenerator@main
|
2 |
+
gradio >= 5.0.0
|
3 |
+
huggingface_hub >= 0.26.0
|
4 |
+
|
5 |
+
# This file was autogenerated via `uv export`.
|
6 |
+
appnope==0.1.4 ; platform_system == 'Darwin'
|
7 |
+
asttokens==2.4.1
|
8 |
+
cffi==1.17.1 ; implementation_name == 'pypy'
|
9 |
+
colorama==0.4.6 ; sys_platform == 'win32' or platform_system == 'Windows'
|
10 |
+
comm==0.2.2
|
11 |
+
debugpy==1.8.7
|
12 |
+
decorator==5.1.1
|
13 |
+
einops==0.8.0
|
14 |
+
exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
15 |
+
executing==2.1.0
|
16 |
+
filelock==3.16.1
|
17 |
+
fsspec==2024.10.0
|
18 |
+
ipykernel==6.29.5
|
19 |
+
ipython==8.29.0
|
20 |
+
ipywidgets==8.1.5
|
21 |
+
jedi==0.19.1
|
22 |
+
jupyter-client==8.6.3
|
23 |
+
jupyter-core==5.7.2
|
24 |
+
jupyterlab-widgets==3.0.13
|
25 |
+
matplotlib-inline==0.1.7
|
26 |
+
mpmath==1.3.0
|
27 |
+
nest-asyncio==1.6.0
|
28 |
+
networkx==3.4.2
|
29 |
+
numpy==2.1.2
|
30 |
+
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
31 |
+
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
32 |
+
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
33 |
+
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
34 |
+
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
35 |
+
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
36 |
+
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
37 |
+
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
38 |
+
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
39 |
+
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
40 |
+
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
41 |
+
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
|
42 |
+
packaging==24.1
|
43 |
+
parso==0.8.4
|
44 |
+
pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
45 |
+
platformdirs==4.3.6
|
46 |
+
polars==1.12.0
|
47 |
+
prompt-toolkit==3.0.48
|
48 |
+
psutil==6.1.0
|
49 |
+
ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
|
50 |
+
pure-eval==0.2.3
|
51 |
+
pycparser==2.22 ; implementation_name == 'pypy'
|
52 |
+
pygments==2.18.0
|
53 |
+
python-dateutil==2.9.0.post0
|
54 |
+
pywin32==308 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
|
55 |
+
pyzmq==26.2.0
|
56 |
+
setuptools==75.3.0 ; python_full_version >= '3.12'
|
57 |
+
six==1.16.0
|
58 |
+
stack-data==0.6.3
|
59 |
+
sympy==1.13.1
|
60 |
+
torch==2.5.0
|
61 |
+
torchinfo==1.8.0
|
62 |
+
tornado==6.4.1
|
63 |
+
tqdm==4.66.6
|
64 |
+
traitlets==5.14.3
|
65 |
+
triton==3.1.0 ; python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'
|
66 |
+
typing-extensions==4.12.2
|
67 |
+
wcwidth==0.2.13
|
68 |
+
widgetsnbextension==4.0.13
|