JunzhaoSun committed on
Commit
634b9bc
·
1 Parent(s): 14f33b1

对比多种搜索方式

Browse files
Files changed (1) hide show
  1. app.py +66 -15
app.py CHANGED
@@ -9,9 +9,11 @@ import os
9
  checkpoint = "gpt2-large"
10
  # checkpoint = "/innev/open-ai/huggingface/models/gpt2-large"
11
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
12
- model = AutoModelForCausalLM.from_pretrained(checkpoint)
 
13
 
14
- def generate(text):
 
15
  # text = 'Who was Jim Henson ? Jim Henson was a'
16
 
17
  # 编码一段文本
@@ -22,7 +24,6 @@ def generate(text):
22
  # shape为 torch.Size([1, 11])
23
  tokens_tensor = torch.tensor([indexed_tokens])
24
 
25
-
26
  # 设置为evaluation模式,去取消激活dropout等模块。
27
  # 在huggingface/transformers框架中,默认就是eval模式
28
  model.eval()
@@ -51,21 +52,61 @@ def generate(text):
51
 
52
  return predicted_text
53
 
54
-
55
- def doloop(prompts):
56
-
57
  text = prompts
58
  total = 1
59
  while text[-1] != "." and total < 20:
60
- text = generate(text)
61
  print("Index %s: %s" % (total, text))
62
  total = total + 1
63
 
64
  return text, total
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  title = "GPT2 large"
68
 
 
 
69
  description = """
70
  本例为使用GPT2模型的简单推测语句DEMO,输入前面的句子,推测出后面的句子。
71
 
@@ -73,21 +114,31 @@ description = """
73
  """
74
 
75
  examples = [
76
- ["Who was Jim Henson ? Jim Henson was a", None],
77
- ["My name is Julien and I like to", None],
78
- ["My name is Thomas and my main", None],
79
- ["My name is Mariama, my favorite", None],
80
- ["My name is Clara and I am", None],
 
81
  ]
82
 
 
 
 
 
 
83
  gr.Interface(
84
- fn=doloop,
85
- inputs=gr.Text(label="输入前置语句"),
 
 
 
86
  outputs=[
87
- gr.Text(label="补全后输出"),
88
  gr.Text(label="循环次数"),
89
  ],
90
  title=title,
91
  description=description,
 
92
  examples=examples,
93
  ).launch()
 
9
  checkpoint = "gpt2-large"
10
  # checkpoint = "/innev/open-ai/huggingface/models/gpt2-large"
11
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
12
+ # model = AutoModelForCausalLM.from_pretrained(checkpoint)
13
+ model = AutoModelForCausalLM.from_pretrained(checkpoint, pad_token_id=tokenizer.eos_token_id)
14
 
15
+ # 简单生成
16
+ def sampleGen(text):
17
  # text = 'Who was Jim Henson ? Jim Henson was a'
18
 
19
  # 编码一段文本
 
24
  # shape为 torch.Size([1, 11])
25
  tokens_tensor = torch.tensor([indexed_tokens])
26
 
 
27
  # 设置为evaluation模式,去取消激活dropout等模块。
28
  # 在huggingface/transformers框架中,默认就是eval模式
29
  model.eval()
 
52
 
53
  return predicted_text
54
 
55
# Keyword-prediction generation: repeatedly extend the prompt one step
# at a time until the text ends with "." or the iteration cap is hit.
def loopGen(prompts):
    """Iteratively complete *prompts* by calling sampleGen in a loop.

    Args:
        prompts: seed text to extend.

    Returns:
        (text, total): the extended text and the loop counter, which
        starts at 1 and is incremented once per generation pass.
    """
    text = prompts
    total = 1
    # Guard `text` truthiness so an empty string cannot raise IndexError
    # on text[-1]; cap at 20 rounds to avoid endless generation.
    while text and text[-1] != "." and total < 20:
        text = sampleGen(text)
        print("Index %s: %s" % (total, text))
        total += 1
    return text, total
65
 
66
# Greedy-search text generation.
def greedySearch(prompts, max_length=128):
    """Generate a continuation of *prompts* with greedy decoding.

    Args:
        prompts: input text to continue.
        max_length: maximum total token length of the generated
            sequence; default 128 matches the previous hard-coded value.

    Returns:
        (text, 1): decoded text plus a constant round count of 1 so the
        return shape matches loopGen's (text, total) contract.
    """
    input_ids = tokenizer(prompts, return_tensors='pt').input_ids

    # No sampling flags are passed, so model.generate performs greedy
    # (argmax) decoding.
    output = model.generate(input_ids, max_length=max_length)
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return text, 1
74
+
75
# Sampling-based (nucleus / top-p) text generation.
def randomSearch(prompts, max_length=128, seed=0):
    """Generate a continuation of *prompts* by random (top-p) sampling.

    Args:
        prompts: input text to continue.
        max_length: maximum total token length of the generated
            sequence; default 128 matches the previous hard-coded value.
        seed: RNG seed fixed before sampling so repeated calls produce
            the same output. Was `torch.manual_seed(0.)` — a float where
            an integer seed is meant; int 0 is the equivalent value.

    Returns:
        (text, 1): decoded text plus a constant round count of 1 so the
        return shape matches loopGen's (text, total) contract.
    """
    input_ids = tokenizer(prompts, return_tensors='pt').input_ids

    # Seed deterministically, then sample; top_k=0 disables top-k
    # filtering so only nucleus (top_p=0.95) sampling applies.
    torch.manual_seed(seed)
    output = model.generate(input_ids, do_sample=True, max_length=max_length, top_p=0.95, top_k=0)
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return text, 1
84
+
85
# Contrastive-search text generation.
def contrastiveSearch(prompts):
    """Generate a continuation of *prompts* using contrastive search.

    Returns a (text, 1) pair so the output shape matches the other
    generation helpers.
    """
    encoded = tokenizer(prompts, return_tensors='pt')

    # penalty_alpha together with top_k switches model.generate into
    # contrastive-search mode.
    generated = model.generate(
        encoded.input_ids,
        penalty_alpha=0.6,
        top_k=4,
        max_length=512,
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    return decoded, 1
94
+
95
def predict(searchType, prompts):
    """Dispatch *prompts* to the generator selected by *searchType*.

    Any unrecognized label falls back to the keyword-prediction loop
    (loopGen), preserving the original if/elif default branch.
    """
    # Table-driven dispatch instead of an if/elif chain.
    dispatch = {
        "贪心搜索": greedySearch,
        "随机方法": randomSearch,
        "对比搜索": contrastiveSearch,
    }
    handler = dispatch.get(searchType, loopGen)
    return handler(prompts)
104
+
105
 
106
  title = "GPT2 large"
107
 
108
+ searchMapping = ['关键词预测', '贪心搜索', '随机方法', '对比搜索']
109
+
110
  description = """
111
  本例为使用GPT2模型的简单推测语句DEMO,输入前面的句子,推测出后面的句子。
112
 
 
114
  """
115
 
116
  examples = [
117
+ [None, "DeepMind Company is", None],
118
+ [None, "Who was Jim Henson ? Jim Henson was a", None],
119
+ [None, "My name is Julien and I like to", None],
120
+ [None, "My name is Thomas and my main", None],
121
+ [None, "My name is Mariama, my favorite", None],
122
+ [None, "My name is Clara and I am", None],
123
  ]
124
 
125
+ article = """
126
+ ## 文章参考
127
+ - [在 Transformers 中使用对比搜索生成可媲美人类水平的文本 🤗](https://mp.weixin.qq.com/s/mydQLDlGUzFJuNBCIYc3CA)
128
+ """
129
+
130
# Build the Gradio demo — a radio selector for the search strategy plus
# a text box for the seed sentence — then launch it.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(label="搜索方法", choices=searchMapping, value="关键词预测"),
        gr.Text(label="输入前置语句"),
    ],
    outputs=[
        gr.Text(label="生成文本"),
        gr.Text(label="循环次数"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()