vsagar100 committed on
Commit
1c7a008
·
verified ·
1 Parent(s): 3344c31

Update lib/code_reviewer.py

Browse files
Files changed (1) hide show
  1. lib/code_reviewer.py +40 -5
lib/code_reviewer.py CHANGED
@@ -4,7 +4,7 @@
4
  import os
5
  import json
6
  import torch
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
  import requests
9
  import zipfile
10
  import io
@@ -33,7 +33,7 @@ class CodeReviewer:
33
  Returns:
34
  Dict: The code standards in dictionary form.
35
  """
36
- standards_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "standards", "ansible_code_standards.json")
37
  with open(standards_path, 'r') as f:
38
  return json.load(f)
39
 
@@ -74,6 +74,40 @@ class CodeReviewer:
74
  review_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
75
  return review_text
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class ReviewManager:
78
  def __init__(self, reviewer: CodeReviewer):
79
  """
@@ -84,17 +118,19 @@ class ReviewManager:
84
  """
85
  self.reviewer = reviewer
86
 
87
- def download_repo(self, repo_url: str, token: str, download_path: str):
88
  """
89
  Downloads a GitHub repository as a ZIP file and extracts it.
90
 
91
  Args:
92
  repo_url (str): The GitHub repository URL.
 
93
  token (str): The GitHub personal access token for authentication.
94
  download_path (str): The path to extract the downloaded repository.
95
  """
 
96
  headers = {"Authorization": f"Bearer {token}"}
97
- response = requests.get(repo_url, headers=headers)
98
  if response.status_code == 200:
99
  with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
100
  zip_ref.extractall(download_path)
@@ -129,4 +165,3 @@ class ReviewManager:
129
  """
130
  with open(output_path, 'w') as json_file:
131
  json.dump(reviews, json_file, indent=4)
132
-
 
4
  import os
5
  import json
6
  import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
8
  import requests
9
  import zipfile
10
  import io
 
33
  Returns:
34
  Dict: The code standards in dictionary form.
35
  """
36
+ standards_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "standards", "code_standards.json")
37
  with open(standards_path, 'r') as f:
38
  return json.load(f)
39
 
 
74
  review_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
75
  return review_text
76
 
77
+ def fine_tune_model(self, dataset, output_dir="./fine_tuned_incoder"):
78
+ """
79
+ Fine-tunes the model with a custom dataset.
80
+
81
+ Args:
82
+ dataset: The dataset used for fine-tuning.
83
+ output_dir (str): Directory where the fine-tuned model will be saved.
84
+ """
85
+ training_args = TrainingArguments(
86
+ output_dir=output_dir,
87
+ per_device_train_batch_size=4,
88
+ num_train_epochs=3,
89
+ logging_dir="./logs",
90
+ save_steps=10_000,
91
+ logging_steps=500,
92
+ evaluation_strategy="steps",
93
+ save_total_limit=2
94
+ )
95
+
96
+ trainer = Trainer(
97
+ model=self.model,
98
+ args=training_args,
99
+ train_dataset=dataset["train"],
100
+ eval_dataset=dataset["validation"]
101
+ )
102
+
103
+ # Start fine-tuning
104
+ trainer.train()
105
+
106
+ # Save the fine-tuned model
107
+ self.model.save_pretrained(output_dir)
108
+ self.tokenizer.save_pretrained(output_dir)
109
+ print(f"Fine-tuned model saved at {output_dir}")
110
+
111
  class ReviewManager:
112
  def __init__(self, reviewer: CodeReviewer):
113
  """
 
118
  """
119
  self.reviewer = reviewer
120
 
121
+ def download_repo(self, repo_url: str, branch: str, token: str, download_path: str):
122
  """
123
  Downloads a GitHub repository as a ZIP file and extracts it.
124
 
125
  Args:
126
  repo_url (str): The GitHub repository URL.
127
+ branch (str): The branch or tag to download.
128
  token (str): The GitHub personal access token for authentication.
129
  download_path (str): The path to extract the downloaded repository.
130
  """
131
+ zip_url = f"{repo_url}/archive/refs/heads/{branch}.zip"
132
  headers = {"Authorization": f"Bearer {token}"}
133
+ response = requests.get(zip_url, headers=headers)
134
  if response.status_code == 200:
135
  with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
136
  zip_ref.extractall(download_path)
 
165
  """
166
  with open(output_path, 'w') as json_file:
167
  json.dump(reviews, json_file, indent=4)