S-Dreamer commited on
Commit
352ff58
·
verified ·
1 Parent(s): f75c091

Create generation_fast.py

Browse files
Files changed (1) hide show
  1. generation_fast.py +29 -0
generation_fast.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+ class CodeGenerator:
6
+ def __init__(self, model_name):
7
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ self.model.to(self.device)
11
+
12
+ def generate_code(self, nl_input, max_length=256, num_beams=4, early_stopping=True):
13
+ inputs = self.tokenizer(nl_input, return_tensors="pt").to(self.device)
14
+ outputs = self.model.generate(
15
+ **inputs,
16
+ max_length=max_length,
17
+ num_beams=num_beams,
18
+ early_stopping=early_stopping,
19
+ )
20
+ generated_code = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
21
+ return generated_code
22
+
23
+ if __name__ == "__main__":
24
+ model_name = "S-Dreamer/PyCodeT5"
25
+ generator = CodeGenerator(model_name)
26
+
27
+ nl_input = "Write a Python function to calculate the factorial of a number."
28
+ generated_code = generator.generate_code(nl_input)
29
+ print(generated_code)