julien-c HF staff commited on
Commit
32ac4ff
·
1 Parent(s): 3458085

Migrate model card from transformers-repo

Browse files

Read announcement at https://discuss.huggingface.co/t/announcement-all-model-cards-will-be-migrated-to-hf-co-model-repos/2755
Original file history: https://github.com/huggingface/transformers/commits/master/model_cards/gaochangkuan/model_dir/README.md

Files changed (1) hide show
  1. README.md +66 -0
README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Generating Chinese poetry by topic.
2
+
3
+ ```python
4
+ from transformers import *
5
+
6
+ tokenizer = BertTokenizer.from_pretrained("gaochangkuan/model_dir")
7
+
8
+ model = AutoModelWithLMHead.from_pretrained("gaochangkuan/model_dir")
9
+
10
+
11
+ prompt= '''<s>田园躬耕'''
12
+
13
+ length= 84
14
+ stop_token='</s>'
15
+
16
+ temperature = 1.2
17
+
18
+ repetition_penalty=1.3
19
+
20
+ k= 30
21
+ p= 0.95
22
+
23
+ device ='cuda'
24
+ seed=2020
25
+ no_cuda=False
26
+
27
+ prompt_text = prompt if prompt else input("Model prompt >>> ")
28
+
29
+ encoded_prompt = tokenizer.encode(
30
+ '<s>'+prompt_text+'<sep>',
31
+ add_special_tokens=False,
32
+ return_tensors="pt"
33
+ )
34
+
35
+ encoded_prompt = encoded_prompt.to(device)
36
+
37
+ output_sequences = model.generate(
38
+ input_ids=encoded_prompt,
39
+ max_length=length,
40
+ min_length=10,
41
+ do_sample=True,
42
+ early_stopping=True,
43
+ num_beams=10,
44
+ temperature=temperature,
45
+ top_k=k,
46
+ top_p=p,
47
+ repetition_penalty=repetition_penalty,
48
+ bad_words_ids=None,
49
+ bos_token_id=tokenizer.bos_token_id,
50
+ pad_token_id=tokenizer.pad_token_id,
51
+ eos_token_id=tokenizer.eos_token_id,
52
+ length_penalty=1.2,
53
+ no_repeat_ngram_size=2,
54
+ num_return_sequences=1,
55
+ attention_mask=None,
56
+ decoder_start_token_id=tokenizer.bos_token_id,)
57
+
58
+
59
+ generated_sequence = output_sequences[0].tolist()
60
+ text = tokenizer.decode(generated_sequence)
61
+
62
+
63
+ text = text[: text.find(stop_token) if stop_token else None]
64
+
65
+ print(''.join(text).replace(' ','').replace('<pad>','').replace('<s>',''))
66
+ ```