christopher committed on
Commit
1e88db6
·
1 Parent(s): 3763669

Fix config usage and include pytorch weights

Browse files
Files changed (1) hide show
  1. clone_sentdex_model_tokenizer.py +14 -3
clone_sentdex_model_tokenizer.py CHANGED
@@ -1,9 +1,20 @@
1
  #!/usr/bin/env python
2
- from transformers import AutoTokenizer, TFAutoModelForCausalLM
3
  import tensorflow as tf
4
 
 
 
 
 
 
 
 
 
5
  tokenizer = AutoTokenizer.from_pretrained("Sentdex/GPyT")
6
- model = TFAutoModelForCausalLM.from_pretrained("Sentdex/GPyT", from_pt=True)
 
7
 
 
8
  tokenizer.save_pretrained(save_directory='./')
9
- model.save_pretrained(save_directory='./', saved_model=True)
 
 
1
  #!/usr/bin/env python
2
+ from transformers import AutoTokenizer, TFAutoModelForCausalLM, AutoModelForCausalLM, GPT2Config
3
  import tensorflow as tf
4
 
5
+ task_specific_params = {
6
+ "text-generation": {
7
+ "do_sample": False,
8
+ "max_length": 50
9
+ }
10
+ }
11
+
12
+ config = GPT2Config.from_pretrained("Sentdex/GPyT", _name_or_path='prophetikai/code-gpt', use_cache=True, task_specific_params=task_specific_params)
13
  tokenizer = AutoTokenizer.from_pretrained("Sentdex/GPyT")
14
+ tf_model = TFAutoModelForCausalLM.from_pretrained("Sentdex/GPyT", config=config)
15
+ pytorch_model = AutoModelForCausalLM.from_pretrained("Sentdex/GPyT", config=config)
16
 
17
+ config.save_pretrained('./')
18
  tokenizer.save_pretrained(save_directory='./')
19
+ tf_model.save_pretrained(save_directory='./', saved_model=True, version='sentdex')
20
+ pytorch_model.save_pretrained(save_directory='./')