updated model card examples
Browse files
README.md
CHANGED
@@ -18,11 +18,11 @@ set a seed for reproducibility:
|
|
18 |
>>> generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
19 |
>>> set_seed(42)
|
20 |
>>> generator("COVID-19 is", max_length=20, num_return_sequences=5, do_sample=True)
|
21 |
-
[{'generated_text': 'COVID-19 is a
|
22 |
-
{'generated_text': 'COVID-19 is
|
23 |
-
{'generated_text': 'COVID-19 is a
|
24 |
-
{'generated_text': 'COVID-19 is a
|
25 |
-
{'generated_text': 'COVID-19 is
|
26 |
```
|
27 |
|
28 |
Here is how to use this model to get the features of a given text in PyTorch:
|
@@ -39,6 +39,7 @@ output = model(**encoded_input)
|
|
39 |
Beam-search decoding:
|
40 |
|
41 |
```python
|
|
|
42 |
from transformers import BioGptTokenizer, BioGptLMHeadModel, set_seed
|
43 |
|
44 |
tokenizer = BioGptTokenizer.from_pretrained("kamalkraj/biogpt")
|
@@ -49,12 +50,13 @@ inputs = tokenizer(sentence, return_tensors="pt")
|
|
49 |
|
50 |
set_seed(42)
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
58 |
tokenizer.decode(beam_output[0], skip_special_tokens=True)
|
59 |
'COVID-19 is a global pandemic caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), the causative agent of coronavirus disease 2019 (COVID-19), which has spread to more than 200 countries and territories, including the United States (US), Canada, Australia, New Zealand, the United Kingdom (UK), and the United States of America (USA), as of March 11, 2020, with more than 800,000 confirmed cases and more than 800,000 deaths.'
|
60 |
```
|
|
|
18 |
>>> generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
|
19 |
>>> set_seed(42)
|
20 |
>>> generator("COVID-19 is", max_length=20, num_return_sequences=5, do_sample=True)
|
21 |
+
[{'generated_text': 'COVID-19 is a disease that spreads worldwide and is currently found in a growing proportion of the population'},
|
22 |
+
{'generated_text': 'COVID-19 is one of the largest viral epidemics in the world.'},
|
23 |
+
{'generated_text': 'COVID-19 is a common condition affecting an estimated 1.1 million people in the United States alone.'},
|
24 |
+
{'generated_text': 'COVID-19 is a pandemic, the incidence has been increased in a manner similar to that in other'},
|
25 |
+
{'generated_text': 'COVID-19 is transmitted via droplets, air-borne, or airborne transmission.'}]
|
26 |
```
|
27 |
|
28 |
Here is how to use this model to get the features of a given text in PyTorch:
|
|
|
39 |
Beam-search decoding:
|
40 |
|
41 |
```python
|
42 |
+
import torch
|
43 |
from transformers import BioGptTokenizer, BioGptLMHeadModel, set_seed
|
44 |
|
45 |
tokenizer = BioGptTokenizer.from_pretrained("kamalkraj/biogpt")
|
|
|
50 |
|
51 |
set_seed(42)
|
52 |
|
53 |
+
with torch.no_grad():
|
54 |
+
beam_output = model.generate(**inputs,
|
55 |
+
min_length=100,
|
56 |
+
max_length=1024,
|
57 |
+
num_beams=5,
|
58 |
+
early_stopping=True
|
59 |
+
)
|
60 |
tokenizer.decode(beam_output[0], skip_special_tokens=True)
|
61 |
'COVID-19 is a global pandemic caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), the causative agent of coronavirus disease 2019 (COVID-19), which has spread to more than 200 countries and territories, including the United States (US), Canada, Australia, New Zealand, the United Kingdom (UK), and the United States of America (USA), as of March 11, 2020, with more than 800,000 confirmed cases and more than 800,000 deaths.'
|
62 |
```
|