TURX commited on
Commit
d8a9bd8
โ€ข
1 Parent(s): c6bdc06
Files changed (4) hide show
  1. Dockerfile +11 -0
  2. README.md +13 -5
  3. main.py +192 -0
  4. requirements.txt +3 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ COPY . .
10
+
11
+ CMD ["funix", "main.py", "--host", "0.0.0.0", "--port", "7860", "--no-browser"]
README.md CHANGED
@@ -1,10 +1,18 @@
1
  ---
2
- title: Japanese Lm
3
- emoji: ๐Ÿ†
4
- colorFrom: purple
5
- colorTo: green
6
  sdk: docker
 
 
 
 
 
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
1
  ---
2
+ title: Japanese LM
 
 
 
3
  sdk: docker
4
+ app_port: 7860
5
+ python_version: 3.10
6
+ models:
7
+ - ku-nlp/gpt2-medium-japanese-char
8
+ - TURX/chj-gpt2
9
+ - TURX/wakagpt
10
  pinned: false
11
  ---
12
 
13
+ Japanese Language Models
14
+ ---
15
+
16
+ Final Project, STAT 453 Spring 2024, University of Wisconsin-Madison
17
+
18
+ Author: Ruixuan Tu ([email protected], https://turx.asia)
main.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # To run: funix main.py
2
+
3
+ from transformers import AutoTokenizer
4
+ from transformers import AutoModelForCausalLM
5
+ import typing
6
+ from funix import funix
7
+ from funix.hint import HTML
8
+
9
+ low_memory = True # Set to True to run on mobile devices
10
+
11
+ ku_gpt_tokenizer = AutoTokenizer.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
12
+ chj_gpt_tokenizer = AutoTokenizer.from_pretrained("TURX/chj-gpt2")
13
+ wakagpt_tokenizer = AutoTokenizer.from_pretrained("TURX/wakagpt")
14
+ ku_gpt_model = AutoModelForCausalLM.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
15
+ chj_gpt_model = AutoModelForCausalLM.from_pretrained("TURX/chj-gpt2")
16
+ wakagpt_model = AutoModelForCausalLM.from_pretrained("TURX/wakagpt")
17
+
18
+ print("Models loaded successfully.")
19
+
20
+ model_name_map = {
21
+ "Kyoto University GPT-2 (Modern)": "ku-gpt2",
22
+ "CHJ GPT-2 (Classical)": "chj-gpt2",
23
+ "Waka GPT": "wakagpt",
24
+ }
25
+
26
+ waka_type_map = {
27
+ "kana": "[ไปฎๅ]",
28
+ "original": "[ๅŽŸๆ–‡]",
29
+ "aligned": "[ๆ•ดๅฝข]",
30
+ }
31
+
32
+
33
+ @funix(
34
+ title=" Home",
35
+ description="""
36
+ <h1>Japanese Language Models</h1><hr>
37
+ Final Project, STAT 453 Spring 2024, University of Wisconsin-Madison<br>
38
+ Author: Ruixuan Tu ([email protected], https://turx.asia)<hr>
39
+ Navigate the apps using the left sidebar.
40
+ """
41
+ )
42
+ def home():
43
+ return
44
+
45
+
46
+ @funix(disable=True)
47
+ def __generate(tokenizer: AutoTokenizer, model: AutoModelForCausalLM, prompt: str,
48
+ do_sample: bool, num_beams: int, num_beam_groups: int, max_new_tokens: int, temperature: float, top_k: int, top_p: float, repetition_penalty: float, num_return_sequences: int
49
+ ) -> str:
50
+ global low_memory
51
+ inputs = tokenizer(prompt, return_tensors="pt").input_ids
52
+ outputs = model.generate(inputs, low_memory=low_memory, do_sample=do_sample, num_beams=num_beams, num_beam_groups=num_beam_groups, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, num_return_sequences=num_return_sequences)
53
+ return tokenizer.batch_decode(outputs, skip_special_tokens=True)
54
+
55
+
56
+ @funix(
57
+ title="Custom Prompt Japanese GPT-2",
58
+ description="""
59
+ <h1>Japanese GPT-2</h1><hr>
60
+ Let a GPT-2 model to complete a Japanese sentence for you.
61
+ """,
62
+ argument_labels={
63
+ "prompt": "Prompt in Japanese",
64
+ "model_type": "Model Type",
65
+ "max_new_tokens": "Max New Tokens to Generate",
66
+ "do_sample": "Do Sample",
67
+ "num_beams": "Number of Beams",
68
+ "num_beam_groups": "Number of Beam Groups",
69
+ "max_new_tokens": "Max New Tokens",
70
+ "temperature": "Temperature",
71
+ "top_k": "Top K",
72
+ "top_p": "Top P",
73
+ "repetition_penalty": "Repetition Penalty",
74
+ "num_return_sequences": "Number of Sequences to Return",
75
+ },
76
+ widgets={
77
+ "num_beams": "slider[1,10,1]",
78
+ "num_beam_groups": "slider[1,5,1]",
79
+ "max_new_tokens": "slider[1,512,1]",
80
+ "temperature": "slider[0.0,1.0,0.01]",
81
+ "top_k": "slider[1,100,0.1]",
82
+ "top_p": "slider[0.0,1.0,0.01]",
83
+ "repetition_penalty": "slider[1.0,2.0,0.01]",
84
+ "num_return_sequences": "slider[1,5,1]",
85
+ }
86
+ )
87
+ def prompt(prompt: str = "ใ“ใ‚“ใซใกใฏใ€‚", model_type: typing.Literal["Kyoto University GPT-2 (Modern)", "CHJ GPT-2 (Classical)", "Waka GPT"] = "Kyoto University GPT-2 (Modern)",
88
+ do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1, max_new_tokens: int = 100, temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.0, num_return_sequences: int = 1
89
+ ) -> HTML:
90
+ model_name = model_name_map[model_type]
91
+ if model_name == "ku-gpt2":
92
+ tokenizer = ku_gpt_tokenizer
93
+ model = ku_gpt_model
94
+ elif model_name == "chj-gpt2":
95
+ tokenizer = chj_gpt_tokenizer
96
+ model = chj_gpt_model
97
+ elif model_name == "wakagpt":
98
+ tokenizer = wakagpt_tokenizer
99
+ model = wakagpt_model
100
+ else:
101
+ raise NotImplementedError(f"Unsupported model: {model_name}")
102
+ generated = __generate(tokenizer, model, prompt, do_sample, num_beams, num_beam_groups, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_return_sequences)
103
+ return HTML("".join([f"<p>{i}</p>" for i in generated]))
104
+
105
+
106
+ @funix(
107
+ title="WakaGPT Poem Composer",
108
+ description="""
109
+ <h1>WakaGPT Poem Composer</h1><hr>
110
+ Generate a Japanese waka poem in 5-7-5-7-7 form using WakaGPT. A sample poem (Kokinshu 169) is provided below:<br>
111
+ Preface: ็ง‹็ซ‹ใคๆ—ฅใ‚ˆใ‚ใ‚‹<br>
112
+ Author: ๆ•่กŒ ่—คๅŽŸๆ•่กŒๆœ่‡ฃ (018)<br>
113
+ Kana (Kana only with Separator): ใ‚ใใใฌใจโˆ’ใ‚ใซใฏใ•ใ‚„ใ‹ใซโˆ’ใฟใˆใญใจใ‚‚โˆ’ใ‹ใ›ใฎใŠใจใซใโˆ’ใŠใจใ‚ใ‹ใ‚Œใฌใ‚‹<br>
114
+ Original (Kana + Kanji without Separator): ใ‚ใใใฌใจใ‚ใซใฏใ•ใ‚„ใ‹ใซ่ฆ‹ใˆใญใจใ‚‚้ขจใฎใŠใจใซใใŠใจใ‚ใ‹ใ‚Œใฌใ‚‹<br>
115
+ Aligned (Kana + Kanji with Separator): ใ‚ใใใฌใจโˆ’ใ‚ใซใฏใ•ใ‚„ใ‹ใซโˆ’่ฆ‹ใˆใญใจใ‚‚โˆ’้ขจใฎใŠใจใซใโˆ’ใŠใจใ‚ใ‹ใ‚Œใฌใ‚‹
116
+ """,
117
+ argument_labels={
118
+ "preface": "Preface (Kotobagaki) in Japanese (optional)",
119
+ "author": "Author Name in Japanese (optional)",
120
+ "first_line": "First Line of Poem in Japanese (optional)",
121
+ "type": "Waka Type",
122
+ "remaining_lines": "Remaining Lines of Poem",
123
+ "do_sample": "Do Sample",
124
+ "num_beams": "Number of Beams",
125
+ "num_beam_groups": "Number of Beam Groups",
126
+ "temperature": "Temperature",
127
+ "top_k": "Top K",
128
+ "top_p": "Top P",
129
+ "repetition_penalty": "Repetition Penalty",
130
+ "num_return_sequences": "Number of Sequences to Return (at Maximum)",
131
+ },
132
+ widgets={
133
+ "remaining_lines": "slider[1,5,1]",
134
+ "num_beams": "slider[1,10,1]",
135
+ "num_beam_groups": "slider[1,5,1]",
136
+ "temperature": "slider[0.0,1.0,0.01]",
137
+ "top_k": "slider[1,100,0.1]",
138
+ "top_p": "slider[0.0,1.0,0.01]",
139
+ "repetition_penalty": "slider[1.0,2.0,0.01]",
140
+ "num_return_sequences": "slider[1,5,1]",
141
+ }
142
+ )
143
+ def waka(preface: str = "", author: str = "", first_line: str = "ใ‚ใใใฌใจโˆ’ใ‚ใซใฏใ•ใ‚„ใ‹ใซโˆ’ใฟใˆใญใจใ‚‚", type: typing.Literal["Kana", "Original", "Aligned"] = "Kana", remaining_lines: int = 2,
144
+ do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1, temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.0, num_return_sequences: int = 1
145
+ ) -> HTML:
146
+ waka_prompt = ""
147
+ if preface:
148
+ waka_prompt += "[่ฉžๆ›ธ] " + preface + "\n"
149
+ if author:
150
+ waka_prompt += "[ไฝœ่€…] " + author + "\n"
151
+ token_counts = [5, 7, 5, 7, 7]
152
+ max_new_tokens = sum(token_counts[-remaining_lines:])
153
+ first_line = first_line.strip()
154
+
155
+ # add separators
156
+ if type.lower() in ["kana", "aligned"]:
157
+ if first_line == "":
158
+ max_new_tokens += 4
159
+ else:
160
+ first_line += "โˆ’" if first_line[-1] != "โˆ’" else first_line
161
+ max_new_tokens += remaining_lines - 1 # remaining separators
162
+
163
+ waka_prompt += waka_type_map[type.lower()] + " " + first_line
164
+ info = f"""
165
+ Prompt: {waka_prompt}<br>
166
+ Max New Tokens: {max_new_tokens}<br>
167
+ """
168
+ yield info + "Generating Poem..."
169
+ generated = __generate(wakagpt_tokenizer, wakagpt_model, waka_prompt, do_sample, num_beams, num_beam_groups, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_return_sequences)
170
+
171
+ removed = 0
172
+ checked_generated = []
173
+ if type.lower() in ["kana", "aligned"]:
174
+ def check(seq):
175
+ poem = first_line + seq[len(waka_prompt) - 1:]
176
+ parts = poem.split("โˆ’")
177
+ if len(parts) == 5 and all(len(part) == token_counts[i] for i, part in enumerate(parts)):
178
+ checked_generated.append(poem)
179
+ else:
180
+ nonlocal removed
181
+ removed += 1
182
+ for i in generated:
183
+ check(i)
184
+ else:
185
+ checked_generated = [first_line + i[len(waka_prompt) - 1:] for i in generated]
186
+
187
+ generated = [f"<p>{i}</p>" for i in checked_generated]
188
+ yield info + f"Removed Malformed: {removed}<br>Results:<br>{''.join(generated)}"
189
+
190
+
191
+ if __name__ == "__main__":
192
+ print(prompt("ใ“ใ‚“ใซใกใฏ", "Kyoto University GPT-2 (Modern)", num_beams=5, num_return_sequences=5))
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ funix