File size: 6,199 Bytes
e49d8aa 1c7a008 df513b0 e49d8aa 3344c31 e49d8aa 1c7a008 e49d8aa a9a2195 e49d8aa 1c7a008 e49d8aa 1c7a008 df513b0 1c7a008 df513b0 1c7a008 02bdc55 1c7a008 df513b0 02bdc55 df513b0 e49d8aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# lib/code_reviewer.py
# Import necessary libraries
import inspect
import io
import json
import os
import zipfile

import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# Custom Imports
from typing import List, Dict
class CodeReviewer:
    """Reviews source code with a causal language model.

    Builds a review prompt from a JSON checklist of code standards,
    generates the review with the model, and supports fine-tuning.
    """

    def __init__(self, model_name: str = "facebook/incoder-1B"):
        """
        Initializes the code reviewer with the specified language model.

        Args:
            model_name (str): The name of the pre-trained model to use.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        # Load the code standards checklist once so every review reuses it.
        self.code_standards = self.load_code_standards()

    def load_code_standards(self) -> Dict:
        """
        Loads the code standards checklist from a JSON file.

        The file is expected at <package root>/standards/code_standards.json,
        i.e. one directory above this module's directory.

        Returns:
            Dict: The code standards in dictionary form.

        Raises:
            FileNotFoundError: If the standards file does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        standards_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "standards",
            "code_standards.json",
        )
        # Explicit encoding avoids platform-dependent default encodings.
        with open(standards_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def generate_prompt(self, code: str) -> str:
        """
        Generates a review prompt for the input code based on the loaded standards.

        Args:
            code (str): The code to be reviewed.

        Returns:
            str: The prompt used for reviewing the code.
        """
        # Build prompt from code standards. Expects the loaded JSON to have a
        # top-level "code_standards" list of {category, standards} entries,
        # each standard carrying a "description".
        prompt = "You are an expert Ansible code reviewer. Review the following script thoroughly for the specified standards:\n\n"
        for category in self.code_standards["code_standards"]:
            prompt += f"{category['category']}:\n"
            for standard in category['standards']:
                prompt += f"- {standard['description']}\n"
        prompt += "\nHere is the code:\n"
        return prompt + code

    def review_code(self, code: str) -> str:
        """
        Uses the model to generate a review for the provided code.

        Args:
            code (str): The code to be reviewed.

        Returns:
            str: The review generated by the model (the echoed prompt tokens
            are stripped, so only newly generated text is returned).
        """
        prompt = self.generate_prompt(code)
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
        # Drop tokenizer outputs the model's forward() does not accept
        # (e.g. token_type_ids for some architectures). inspect.signature
        # follows functools wrappers, unlike peeking at __code__.co_varnames.
        accepted = set(inspect.signature(self.model.forward).parameters)
        inputs = {k: v for k, v in inputs.items() if k in accepted}
        prompt_len = inputs["input_ids"].shape[1]
        # max_new_tokens (not max_length) so a long prompt cannot consume the
        # entire generation budget and yield an empty or truncated review.
        output = self.model.generate(**inputs, max_new_tokens=512)
        # Decode only the tokens generated after the prompt.
        review_text = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
        return review_text

    def fine_tune_model(self, dataset, output_dir="./fine_tuned_incoder"):
        """
        Fine-tunes the model with a custom dataset.

        Args:
            dataset: The dataset used for fine-tuning; must provide "train"
                and "validation" splits (e.g. a datasets.DatasetDict).
            output_dir (str): Directory where the fine-tuned model will be saved.
        """
        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=4,
            num_train_epochs=3,
            logging_dir="./logs",
            save_steps=10_000,
            logging_steps=500,
            # NOTE(review): renamed to `eval_strategy` in recent transformers
            # releases — update if the pinned version is bumped.
            evaluation_strategy="steps",
            save_total_limit=2
        )
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset["train"],
            eval_dataset=dataset["validation"]
        )
        # Start fine-tuning
        trainer.train()
        # Persist both model and tokenizer so the directory is self-contained.
        self.model.save_pretrained(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        print(f"Fine-tuned model saved at {output_dir}")
class ReviewManager:
def __init__(self, reviewer: CodeReviewer):
"""
Initializes the review manager with a given reviewer.
Args:
reviewer (CodeReviewer): An instance of the CodeReviewer class.
"""
self.reviewer = reviewer
def download_repo(self, repo_url: str, branch: str, token: str, download_path: str):
"""
Downloads a GitHub repository as a ZIP file and extracts it.
Args:
repo_url (str): The GitHub repository URL.
branch (str): The branch or tag to download.
token (str): The GitHub personal access token for authentication.
download_path (str): The path to extract the downloaded repository.
"""
zip_url = f"{repo_url}/archive/refs/heads/{branch}.zip"
headers = {"Authorization": f"Bearer {token}"}
response = requests.get(zip_url, headers=headers)
if response.status_code == 200:
with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
zip_ref.extractall(download_path)
else:
raise Exception(f"Failed to download repository. Status code: {response.status_code}, Message: {response.text}")
def process_files(self, file_paths: List[str]) -> List[Dict[str, str]]:
"""
Processes multiple files for review.
Args:
file_paths (List[str]): List of file paths to be reviewed.
Returns:
List[Dict[str, str]]: A list containing review data for each file.
"""
reviews = []
for file_path in file_paths:
with open(file_path, 'r') as file:
code = file.read()
review = self.reviewer.review_code(code)
reviews.append({"filename": os.path.basename(file_path), "review": review})
return reviews
def save_reviews_to_json(self, reviews: List[Dict[str, str]], output_path: str):
"""
Saves the review data to a JSON file.
Args:
reviews (List[Dict[str, str]]): The list of reviews to save.
output_path (str): The path to save the JSON output.
"""
with open(output_path, 'w') as json_file:
json.dump(reviews, json_file, indent=4)
|