dnnsdunca commited on
Commit
38dfef4
·
verified ·
1 Parent(s): a4d6124

Create src/agent.py

Browse files
Files changed (1) hide show
  1. src/agent.py +36 -0
src/agent.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
+ class CodingAgent:
5
+ def __init__(self, model_path):
6
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ self.model = AutoModelForCausalLM.from_pretrained(model_path).to(self.device)
8
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
9
+
10
+ def generate_code(self, prompt, max_length=512, temperature=0.7, top_k=50, top_p=0.95):
11
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
12
+
13
+ with torch.no_grad():
14
+ outputs = self.model.generate(
15
+ **inputs,
16
+ max_length=max_length,
17
+ temperature=temperature,
18
+ top_k=top_k,
19
+ top_p=top_p,
20
+ do_sample=True,
21
+ num_return_sequences=1,
22
+ )
23
+
24
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
25
+
26
+ def answer_coding_question(self, question):
27
+ prompt = f"As a coding assistant, please answer the following question:\n\nQuestion: {question}\n\nAnswer:"
28
+ return self.generate_code(prompt)
29
+
30
+ def explain_code(self, code):
31
+ prompt = f"Please explain the following code:\n\n```python\n{code}\n```\n\nExplanation:"
32
+ return self.generate_code(prompt)
33
+
34
+ def suggest_improvements(self, code):
35
+ prompt = f"Please suggest improvements for the following code:\n\n```python\n{code}\n```\n\nSuggestions:"
36
+ return self.generate_code(prompt)