JunzhaoSun committed · Commit 634b9bc · 1 Parent(s): 14f33b1
Compare multiple search methods
app.py
CHANGED
@@ -9,9 +9,11 @@ import os
 checkpoint = "gpt2-large"
 # checkpoint = "/innev/open-ai/huggingface/models/gpt2-large"
 tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-model = AutoModelForCausalLM.from_pretrained(checkpoint)
+# model = AutoModelForCausalLM.from_pretrained(checkpoint)
+model = AutoModelForCausalLM.from_pretrained(checkpoint, pad_token_id=tokenizer.eos_token_id)
 
-
+# simple generation
+def sampleGen(text):
 # text = 'Who was Jim Henson ? Jim Henson was a'
 
 # encode a piece of text
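A note on the new pad_token_id=tokenizer.eos_token_id argument (my gloss, not part of the commit): GPT-2 ships without a padding token, so transformers falls back to eos_token_id on every generate() call and logs a warning each time; setting it once at load time silences that. A minimal sketch, using the small "gpt2" checkpoint to keep the download light:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tok.eos_token_id)

    ids = tok("Who was Jim Henson ?", return_tensors="pt").input_ids
    out = lm.generate(ids, max_length=20)  # greedy by default; no pad_token_id warning
    print(tok.decode(out[0], skip_special_tokens=True))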
@@ -22,7 +24,6 @@ def generate(text):
     # shape is torch.Size([1, 11])
     tokens_tensor = torch.tensor([indexed_tokens])
 
-
     # switch to evaluation mode, which deactivates modules such as dropout
     # (in the huggingface/transformers framework, eval mode is already the default)
     model.eval()
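The comment kept above is worth verifying, since the code relies on it: from_pretrained() already returns the model in eval mode, so the explicit model.eval() is a harmless no-op. A quick check:

    from transformers import AutoModelForCausalLM

    m = AutoModelForCausalLM.from_pretrained("gpt2")
    print(m.training)  # False: dropout and friends are already deactivated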
@@ -51,21 +52,61 @@ def generate(text):
 
     return predicted_text
 
-
-def …
-
+# keyword prediction: generate text in a loop
+def loopGen(prompts):
     text = prompts
     total = 1
     while text[-1] != "." and total < 20:
-        text = generate(text)
+        text = sampleGen(text)
         print("Index %s: %s" % (total, text))
         total = total + 1
 
     return text, total
 
+# greedy search: generate text
+def greedySearch(prompts):
+    input_ids = tokenizer(prompts, return_tensors='pt').input_ids
+
+    # generate the result with greedy search
+    output = model.generate(input_ids, max_length=128)
+    text = tokenizer.decode(output[0], skip_special_tokens=True)
+    return text, 1
+
+# random sampling: generate text
+def randomSearch(prompts):
+    input_ids = tokenizer(prompts, return_tensors='pt').input_ids
+
+    # generate the result with random search
+    torch.manual_seed(0.)
+    output = model.generate(input_ids, do_sample=True, max_length=128, top_p=0.95, top_k=0)
+    text = tokenizer.decode(output[0], skip_special_tokens=True)
+    return text, 1
+
+# contrastive search: generate text
+def contrastiveSearch(prompts):
+    input_ids = tokenizer(prompts, return_tensors='pt').input_ids
+
+    # generate the result with contrastive search
+    output = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)
+    text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    return text, 1
+
+def predict(searchType, prompts):
+    if searchType == "贪心搜索":
+        return greedySearch(prompts)
+    elif searchType == "随机方法":
+        return randomSearch(prompts)
+    elif searchType == "对比搜索":
+        return contrastiveSearch(prompts)
+    else:
+        return loopGen(prompts)
+
 
 title = "GPT2 large"
 
+searchMapping = ['关键词预测', '贪心搜索', '随机方法', '对比搜索']
+
 description = """
 本例为使用GPT2模型的简单推测语句DEMO,输入前面的句子,推测出后面的句子。
 
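The four radio options in searchMapping gloss as: 关键词预测 "keyword prediction" (the loopGen loop above), 贪心搜索 "greedy search", 随机方法 "random sampling", and 对比搜索 "contrastive search"; the description string reads, roughly, "A simple sentence-completion demo using the GPT2 model: enter the start of a sentence and it predicts the rest." The three new functions differ only in the keyword arguments they pass to model.generate(): greedy decoding is the default, do_sample=True with top_p=0.95/top_k=0 gives nucleus sampling, and penalty_alpha with a small top_k enables contrastive search (transformers >= 4.24). A self-contained sketch comparing the three side by side ("gpt2" stands in for "gpt2-large"):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tok.eos_token_id)
    ids = tok("DeepMind Company is", return_tensors="pt").input_ids

    strategies = {
        "greedy": dict(max_length=64),                                         # argmax at each step
        "sampling": dict(do_sample=True, top_p=0.95, top_k=0, max_length=64),  # nucleus sampling
        "contrastive": dict(penalty_alpha=0.6, top_k=4, max_length=64),        # contrastive search
    }
    torch.manual_seed(0)  # makes the sampling run repeatable
    for name, kwargs in strategies.items():
        out = lm.generate(ids, **kwargs)
        print("==", name, "==")
        print(tok.decode(out[0], skip_special_tokens=True), "\n")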
@@ -73,21 +114,31 @@ description = """
 """
 
 examples = [
-    ["…
-    ["…
-    ["My name is …
-    ["My name is …
-    ["My name is …
+    [None, "DeepMind Company is", None],
+    [None, "Who was Jim Henson ? Jim Henson was a", None],
+    [None, "My name is Julien and I like to", None],
+    [None, "My name is Thomas and my main", None],
+    [None, "My name is Mariama, my favorite", None],
+    [None, "My name is Clara and I am", None],
 ]
 
+article = """
+## 文章参考
+- [在 Transformers 中使用对比搜索生成可媲美人类水平的文本 🤗](https://mp.weixin.qq.com/s/mydQLDlGUzFJuNBCIYc3CA)
+"""
+
 gr.Interface(
-    fn=…
-    inputs=…
+    fn=predict,
+    inputs=[
+        gr.Radio(label="搜索方法", choices=searchMapping, value="关键词预测"),
+        gr.Text(label="输入前置语句"),
+    ],
     outputs=[
-        gr.Text(label="…
+        gr.Text(label="生成文本"),
         gr.Text(label="循环次数"),
     ],
     title=title,
     description=description,
+    article=article,
     examples=examples,
 ).launch()
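Glosses for the remaining UI strings (mine, for non-Chinese readers): 搜索方法 "search method", 输入前置语句 "input leading sentence", 生成文本 "generated text", 循环次数 "loop count", 文章参考 "references"; the linked article title translates to "Generating Human-level Text with Contrastive Search in Transformers 🤗". A stand-in sketch of the same Gradio wiring with English labels and a dummy callback, runnable without downloading the model (the real Space passes predict from app.py instead):

    import gradio as gr

    # dummy callback with predict's (radio choice, prompt) -> (text, count) signature
    def fake_predict(search_type, prompt):
        return f"[{search_type}] {prompt} ...", 1

    gr.Interface(
        fn=fake_predict,
        inputs=[
            gr.Radio(label="search method",
                     choices=["keyword prediction", "greedy search",
                              "random sampling", "contrastive search"],
                     value="keyword prediction"),
            gr.Text(label="input leading sentence"),
        ],
        outputs=[gr.Text(label="generated text"), gr.Text(label="loop count")],
        title="GPT2 large",
    ).launch()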