karanravindra commited on
Commit
7b9568f
1 Parent(s): 25dbf42

make space

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +104 -0
  3. requirements.txt +68 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .venv
2
+ model
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download
from namegenerator import Model, ModelConfig

# Inference-only app: disable autograd globally.
torch.set_grad_enabled(False)

# Vocabulary: special tokens first (the "0"/"1" tokens encode sex), then a-z.
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>", "0", "1"]
tokens = special_tokens + list("abcdefghijklmnopqrstuvwxyz")
idx_to_char = dict(enumerate(tokens))
char_to_idx = {char: idx for idx, char in idx_to_char.items()}

# Fetch the trained weights from the Hub into ./model/model.pth.
hf_hub_download(
    "karanravindra/namegenerator", "model.pth", subfolder="model", local_dir="."
)
model = Model(
    ModelConfig(
        vocab_size=len(tokens),
        embedding_dim=48,
        num_layers=6,
        max_length=24,  # not padding to nearest 32 because max length of names is 17 - bump this for `theoretically` better performance
        q_heads=12,
        kv_heads=4,
        m=4,
        tie_weights=False,
    )
)
model.load_state_dict(
    torch.load("model/model.pth", map_location="cpu", weights_only=True)
)
model.eval()
33
+
34
+
35
def decode(encoded_name: list[int], strip_special_tokens: bool = True) -> str:
    """Map a list of token indices back to a string.

    When ``strip_special_tokens`` is true, <sos>, <eos> and <pad> indices
    are dropped before decoding (the "0"/"1" sex tokens are kept).
    """
    if strip_special_tokens:
        skip = {char_to_idx["<sos>"], char_to_idx["<eos>"], char_to_idx["<pad>"]}
        encoded_name = [idx for idx in encoded_name if idx not in skip]
    return "".join(idx_to_char[idx] for idx in encoded_name)
44
+
45
+
46
def decode_batch(
    encoded_names: torch.Tensor, strip_special_tokens: bool = True
) -> list[str]:
    """Decode each row of a batch of encoded names into a string."""
    decoded: list[str] = []
    for row in encoded_names:
        decoded.append(decode(row.tolist(), strip_special_tokens))
    return decoded
53
+
54
+
55
def generate_names(n=16, gender=None, temperature=0.6):
    """Sample ``n`` names from the model.

    Args:
        n: number of names to generate.
        gender: ``None`` for a male/female split, otherwise ``0`` or ``1``
            (mapped through the "0"/"1" vocabulary tokens to condition
            generation on sex).
        temperature: softmax temperature; lower values are more conservative.

    Returns:
        list[str]: decoded names with special tokens stripped (each still
        starts with the gender digit — callers slice it off).
    """
    model.eval()
    if gender is None:
        # Half "0" and half "1" conditioning tokens. The second half uses
        # n - n // 2 so an odd n still yields exactly n names (previously
        # n // 2 + n // 2 silently produced n - 1).
        genders = torch.cat(
            [
                torch.tensor([[char_to_idx["0"]]]).repeat(n // 2, 1),
                torch.tensor([[char_to_idx["1"]]]).repeat(n - n // 2, 1),
            ],
            dim=0,
        )
    else:
        gender = char_to_idx[str(gender)]
        genders = torch.full((n, 1), gender)

    # Prompt: <sos> followed by the per-sequence gender token.
    start_token = torch.tensor([[char_to_idx["<sos>"]]]).repeat(n, 1)
    start_token = torch.cat([start_token, genders], dim=1)

    generated = start_token
    for _ in range(22):
        # Assumes model(generated) returns logits of shape
        # (batch, seq, vocab) — TODO confirm against namegenerator.Model.
        output = model(generated) / temperature

        token = torch.multinomial(F.softmax(output[:, -1], dim=1), 1)

        generated = torch.cat([generated, token], dim=1)

        # Stop once EVERY sequence emitted <pad> this step. The previous
        # `token.all() == char_to_idx["<pad>"]` compared an "all tokens
        # nonzero?" boolean against 0, so it broke as soon as ANY sequence
        # padded, truncating the rest of the batch.
        if (token == char_to_idx["<pad>"]).all():
            break

    return decode_batch(generated, strip_special_tokens=True)
84
+
85
+
86
def generate_name(gender: str, num_names: int, temperature: float):
    """Gradio callback: generate names and return them newline-joined.

    Each decoded name begins with its gender digit ("0"/"1"), so the first
    character is sliced off before capitalizing.
    """
    raw = generate_names(num_names, gender, temperature)
    cleaned = [candidate[1:].capitalize() for candidate in raw]
    return "\n".join(cleaned)
91
+
92
+
93
# Radio type="index" passes 0 (Male) / 1 (Female) to generate_name; the
# additional inputs supply its num_names and temperature parameters.
demo = gr.Interface(
    fn=generate_name,
    inputs=gr.Radio(["Male", "Female"], label="Sex", type="index"),
    outputs=gr.TextArea(lines=16, label="Generated Names"),
    additional_inputs=[
        gr.Number(16, label="Number of Names"),
        gr.Slider(0.1, 2, 0.6, label="Temperature", step=0.1),
    ],
    title="Name Generator",
    description="Generates names based on sex using a GPT-2 model trained on names.",
)
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ namegenerator @ git+https://github.com/karanravindra/namegenerator@main
2
+ gradio >= 5.0.0
3
+ huggingface_hub >= 0.26.0
4
+
5
+ # This file was autogenerated via `uv export`.
6
+ appnope==0.1.4 ; platform_system == 'Darwin'
7
+ asttokens==2.4.1
8
+ cffi==1.17.1 ; implementation_name == 'pypy'
9
+ colorama==0.4.6 ; sys_platform == 'win32' or platform_system == 'Windows'
10
+ comm==0.2.2
11
+ debugpy==1.8.7
12
+ decorator==5.1.1
13
+ einops==0.8.0
14
+ exceptiongroup==1.2.2 ; python_full_version < '3.11'
15
+ executing==2.1.0
16
+ filelock==3.16.1
17
+ fsspec==2024.10.0
18
+ ipykernel==6.29.5
19
+ ipython==8.29.0
20
+ ipywidgets==8.1.5
21
+ jedi==0.19.1
22
+ jupyter-client==8.6.3
23
+ jupyter-core==5.7.2
24
+ jupyterlab-widgets==3.0.13
25
+ matplotlib-inline==0.1.7
26
+ mpmath==1.3.0
27
+ nest-asyncio==1.6.0
28
+ networkx==3.4.2
29
+ numpy==2.1.2
30
+ nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and platform_system == 'Linux'
31
+ nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
32
+ nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
33
+ nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
34
+ nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and platform_system == 'Linux'
35
+ nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and platform_system == 'Linux'
36
+ nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and platform_system == 'Linux'
37
+ nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and platform_system == 'Linux'
38
+ nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and platform_system == 'Linux'
39
+ nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
40
+ nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
41
+ nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and platform_system == 'Linux'
42
+ packaging==24.1
43
+ parso==0.8.4
44
+ pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
45
+ platformdirs==4.3.6
46
+ polars==1.12.0
47
+ prompt-toolkit==3.0.48
48
+ psutil==6.1.0
49
+ ptyprocess==0.7.0 ; sys_platform != 'emscripten' and sys_platform != 'win32'
50
+ pure-eval==0.2.3
51
+ pycparser==2.22 ; implementation_name == 'pypy'
52
+ pygments==2.18.0
53
+ python-dateutil==2.9.0.post0
54
+ pywin32==308 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32'
55
+ pyzmq==26.2.0
56
+ setuptools==75.3.0 ; python_full_version >= '3.12'
57
+ six==1.16.0
58
+ stack-data==0.6.3
59
+ sympy==1.13.1
60
+ torch==2.5.0
61
+ torchinfo==1.8.0
62
+ tornado==6.4.1
63
+ tqdm==4.66.6
64
+ traitlets==5.14.3
65
+ triton==3.1.0 ; python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'
66
+ typing-extensions==4.12.2
67
+ wcwidth==0.2.13
68
+ widgetsnbextension==4.0.13