---
license: mit
language:
- en
pipeline_tag: text-generation
tags:
- legal
- news
library_name: transformers
---

# GPT-Neo-1.3B SimCTG for Conditional News Generation

A [SimCTG](https://github.com/yxuansu/SimCTG) model (released by Su et al. in this [paper](https://arxiv.org/abs/2202.06417)), built on the [GPT-Neo-1.3B](https://huggingface.co/EleutherAI/gpt-neo-1.3B) large language model.
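
For background, SimCTG trains the language model with a token-level contrastive objective on top of the standard MLE loss, encouraging distinct tokens in a sequence to have discriminative representations; this is what contrastive search decoding later exploits. Below is a minimal sketch of that loss, assuming last-layer hidden states and the paper's margin ρ = 0.5 (our own illustration, not this repository's training code):

```python
import torch
import torch.nn.functional as F

def simctg_contrastive_loss(hidden_states: torch.Tensor, margin: float = 0.5) -> torch.Tensor:
    """Token-level contrastive loss from the SimCTG paper (illustrative sketch).

    hidden_states: (batch, seq_len, dim) last-layer LM representations,
    with seq_len > 1; padding is ignored for brevity.
    """
    # Pairwise cosine similarities between all token representations.
    normed = F.normalize(hidden_states, dim=-1)   # (B, T, D)
    sim = normed @ normed.transpose(1, 2)         # (B, T, T)
    # Since s(h_i, h_i) == 1, the hinge term is max(0, margin - 1 + s(h_i, h_j)).
    hinge = torch.clamp(margin - 1.0 + sim, min=0.0)
    # Exclude the i == j pairs on the diagonal.
    T = sim.size(1)
    hinge = hinge * (1.0 - torch.eye(T, device=sim.device))
    # Average over the T * (T - 1) off-diagonal pairs, then over the batch.
    return hinge.sum(dim=(1, 2)).div(T * (T - 1)).mean()

# The full SimCTG objective adds this to the usual cross-entropy: L = L_MLE + L_CL.
```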

## Data Details

The model was trained on a large news corpus containing content from 19 different publishers. The detailed dataset composition is as follows:

| Publisher | Number of Articles |
| :--------------: | :---------: |
| Guardian | 250,000 |
| BBC | 240,872 |
| WashingtonPost | 167,401 |
| USAToday | 234,648 |
| Reuters | 822,110 |
| NYT (New York Times) | 245,150 |
| CNBC | 231,060 |
| Hill | 205,410 |
| People | 132,630 |
| CNN | 121,760 |
| Vice | 97,750 |
| Mashable | 91,100 |
| Refinery | 84,100 |
| BI (Business Insider) | 53,014 |
| TechCrunch | 49,040 |
| Verge | 48,327 |
| TMZ | 46,490 |
| Axios | 44,280 |
| Vox | 44,120 |

## Training Details

We use the prompt template `Publisher: {publisher} article: ` for training, where `{publisher}` is the lowercased publisher name (e.g. `vox`). We trained the model for about 3 epochs on 3 NVIDIA A40 GPUs.
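
To make the template concrete: during fine-tuning, each article would be serialized as the publisher-conditioned prompt followed by the article body. A minimal sketch (the helper name `build_training_text` is hypothetical, not part of this repository):

```python
# Hypothetical helper illustrating the training template; not from this repo.
def build_training_text(publisher: str, article: str) -> str:
    # The publisher name is lowercased, matching the usage example below.
    return f"Publisher: {publisher.lower()} article: {article}"

print(build_training_text("Vox", "The Senate voted on Tuesday to ..."))
# -> Publisher: vox article: The Senate voted on Tuesday to ...
```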

## How to use

```python
>>> from transformers import GPTNeoForCausalLM, AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("PahaII/gpt-neo-1.3b-simctg-NewsCtrlGen")
>>> model = GPTNeoForCausalLM.from_pretrained("PahaII/gpt-neo-1.3b-simctg-NewsCtrlGen")

>>> publisher = "Reuters"
>>> assert publisher in ["Reuters", "NYT", "CNBC", "Hill", "People", "CNN", "Vice", "Mashable", "Refinery", "BI", "TechCrunch", "Verge", "TMZ", "Axios", "Vox", "Guardian", "BBCNews", "WashingtonPost", "USAToday"]
>>> prompt = f"Publisher: {publisher.lower()} article: "

>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> out = model.generate(**inputs, penalty_alpha=0.6)  # penalty_alpha > 0 enables contrastive search
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```
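
Note that `penalty_alpha` triggers Hugging Face's contrastive search together with `top_k`, which defaults to 50; contrastive search is usually run with a much smaller candidate pool. A hedged variant (the exact `top_k` and generation length below are illustrative assumptions, not settings documented by this repository):

```python
>>> # Contrastive search: degeneration penalty alpha plus a small top-k candidate
>>> # pool; top_k=4 and max_new_tokens=256 are illustrative values only.
>>> out = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=256)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```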