Spaces:
Sleeping
Sleeping
File size: 4,947 Bytes
14f33b1 634b9bc 14f33b1 634b9bc 14f33b1 634b9bc 14f33b1 634b9bc 14f33b1 634b9bc 9254ad0 634b9bc 14f33b1 634b9bc 14f33b1 634b9bc 3af6189 14f33b1 634b9bc 14f33b1 634b9bc 14f33b1 634b9bc 14f33b1 634b9bc 14f33b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
#!/usr/local/bin/python3
#-*- coding:utf-8 -*-
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
checkpoint = "gpt2-large"
# checkpoint = "/innev/open-ai/huggingface/models/gpt2-large"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, pad_token_id=tokenizer.eos_token_id)
# 简单生成
def sampleGen(text):
# text = 'Who was Jim Henson ? Jim Henson was a'
# 编码一段文本
# 编码后为[8241, 373, 5395, 367, 19069, 5633, 5395, 367, 19069, 373, 257]
indexed_tokens = tokenizer.encode(text)
# 转换为pytorch tensor
# tensor([[ 8241, 373, 5395, 367, 19069, 5633, 5395, 367, 19069, 373, 257]])
# shape为 torch.Size([1, 11])
tokens_tensor = torch.tensor([indexed_tokens])
# 设置为evaluation模式,去取消激活dropout等模块。
# 在huggingface/transformers框架中,默认就是eval模式
model.eval()
# 预测所有token
with torch.no_grad():
# 将输入tensor输入,就得到了模型的输出,非常简单
# outputs是一个元组,所有huggingface/transformers模型的输出都是元组
# 本初的元组有两个,第一个是预测得分(没经过softmax之前的,也叫作logits),
# 第二个是past,里面的attention计算的key value值
# 此时我们需要的是第一个值
outputs = model(tokens_tensor)
# predictions shape为 torch.Size([1, 11, 50257]),
# 也就是11个词每个词的预测得分(没经过softmax之前的)
# 也叫做logits
predictions = outputs[0]
# 我们需要预测下一个单词,所以是使用predictions第一个batch,最后一个词的logits去计算
# predicted_index = 582,通过计算最大得分的索引得到的
predicted_index = torch.argmax(predictions[0, -1, :]).item()
# 反向解码为我们需要的文本
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
# predicted_text = tokenizer.decode([predicted_index])
# 解码后的文本:'Who was Jim Henson? Jim Henson was a man'
# 成功预测出单词 'man'
return predicted_text
# 关键词预测 生成文本
def loopGen(prompts):
text = prompts
total = 1
while text[-1] != "." and total < 20:
text = sampleGen(text)
print("Index %s: %s" % (total, text))
total = total + 1
return text, total
# 贪心搜索 生成文本
def greedySearch(prompts):
input_ids = tokenizer(prompts, return_tensors='pt').input_ids
# generate the result with greedy search
output = model.generate(input_ids, max_length=128)
text = tokenizer.decode(output[0], skip_special_tokens=True)
return text, 1
# 随机方法 生成文本
def randomSearch(prompts):
input_ids = tokenizer(prompts, return_tensors='pt').input_ids
# generate the result with random search
torch.manual_seed(0.)
output = model.generate(input_ids, do_sample=True, max_length=128, top_p=0.95, top_k=0)
text = tokenizer.decode(output[0], skip_special_tokens=True)
return text, 1
# 对比搜索 生成文本
def contrastiveSearch(prompts):
input_ids = tokenizer(prompts, return_tensors='pt').input_ids
# generate the result with contrastive search
output = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)
text = tokenizer.decode(output[0], skip_special_tokens=True)
return text, 1
def predict(searchType, prompts='Who was Jim Henson ? Jim Henson was a'):
if searchType == "贪心搜索":
return greedySearch(prompts)
elif searchType == "随机方法":
return randomSearch(prompts)
elif searchType == "对比搜索":
return contrastiveSearch(prompts)
else:
return loopGen(prompts)
title = "GPT2 large"
searchMapping = ['关键词预测', '贪心搜索', '随机方法', '对比搜索']
description = """
本例为使用GPT2模型的简单推测语句DEMO,输入前面的句子,推测出后面的句子。
使用原始模型,未经过微调。只支持英文输入输出。
"""
examples = [
[None, "DeepMind Company is", None],
[None, "Who was Jim Henson ? Jim Henson was a", None],
[None, "China is", None]
]
article = """
## 文章参考
- [在 Transformers 中使用对比搜索生成可媲美人类水平的文本 🤗](https://mp.weixin.qq.com/s/mydQLDlGUzFJuNBCIYc3CA)
"""
gr.Interface(
fn=predict,
inputs=[
gr.Radio(label="搜索方法", choices=searchMapping, value="关键词预测"),
gr.Text(label="输入前置语句"),
],
outputs=[
gr.Text(label="生成文本"),
gr.Text(label="循环次数"),
],
title=title,
description=description,
article=article,
examples=examples,
).launch() |