import logging
from pathlib import Path
from typing import Dict, List, Optional

import torch
from transformers import RobertaModel, RobertaTokenizer, logging as hf_logging

from code_summarizer.language_parsers import SUPPORTED_EXTENSIONS, extract_code_snippets

log = logging.getLogger(__name__)

# Silence transformers' own warnings; load status is logged explicitly below.
hf_logging.set_verbosity_error()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.info(f"Summarizer using device: {device}")

# Load CodeBERT once at import time. If loading fails (no network, no cached
# weights, etc.), MODEL_LOADED stays False and get_embedding() degrades
# gracefully by returning None instead of raising.
MODEL_LOADED = False
tokenizer = None
model = None
try:
    log.info("Loading CodeBERT tokenizer/model...")
    tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
    model = RobertaModel.from_pretrained("microsoft/codebert-base")
    model = model.to(device)
    model.eval()  # inference only: disables dropout
    MODEL_LOADED = True
    log.info("CodeBERT model loaded successfully.")
except Exception as e:
    log.error(f"Failed to load CodeBERT model: {e}", exc_info=True)


def get_embedding(code: str) -> Optional[List[float]]:
    """Embed a code snippet with CodeBERT, or return None if the model is unavailable."""
    if not MODEL_LOADED or tokenizer is None or model is None:
        return None
    try:
        # Truncate to CodeBERT's 512-token context window.
        inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512, padding=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # Mean-pool the last hidden states over the token dimension to get one
        # fixed-size vector for the whole snippet.
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
        return embedding.tolist()
    except Exception as e:
        log.warning(f"Failed to generate embedding: {e}. Snippet start: {code[:50]}...")
        return None
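
# Illustration only (assumes the model loaded): codebert-base is a RoBERTa-base
# encoder, so the pooled vector has 768 dimensions.
#   vec = get_embedding("def add(a, b):\n    return a + b")
#   assert vec is None or len(vec) == 768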


def generate_summary(snippet: str) -> str:
    """Build a cheap heuristic summary: the first non-empty, non-comment line of the snippet."""
    try:
        lines = snippet.strip().split("\n")
        header = next(
            (line.strip() for line in lines
             if line.strip() and not line.strip().startswith(("#", "//", "/*"))),
            "",
        )
        header = (header[:100] + "...") if len(header) > 100 else header
        return f"Function/method starting with `{header}`." if header else "N/A Summary"
    except Exception:
        return "Summary generation failed."
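
# Illustration only:
#   generate_summary("def add(a, b):\n    return a + b")
#   -> "Function/method starting with `def add(a, b):`."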


def summarize_file(file_path: Path, repo_url: str) -> List[Dict]:
    """Extract snippets from one file and attach a summary (and embedding, if available) to each."""
    language, snippets = extract_code_snippets(file_path)
    if not snippets:
        return []
    results = []
    log.debug(f"Summarizing {len(snippets)} snippets from {file_path}...")
    for snippet in snippets:
        if not snippet or snippet.isspace():
            continue
        embedding = get_embedding(snippet)
        summary = generate_summary(snippet)
        summary_data = {
            "repo_url": repo_url,
            "file_path": file_path.as_posix(),
            "language": language,
            "function_code": snippet,
            "summary": summary,
        }
        # Omit the key entirely when embedding failed, rather than storing null.
        if embedding is not None:
            summary_data["embedding"] = embedding
        results.append(summary_data)
    return results


def summarize_repo(repo_dir: Path, repo_url: str) -> List[Dict]:
    """Walk a checked-out repository and summarize every supported source file."""
    all_results: List[Dict] = []
    log.info(f"Starting summarization for repository: {repo_url}")
    supported_extensions = set(SUPPORTED_EXTENSIONS.keys())
    files_processed_count = 0
    for file in repo_dir.rglob("*"):
        if file.is_file() and file.suffix.lower() in supported_extensions:
            log.debug(f"Processing file: {file}")
            try:
                file_results = summarize_file(file, repo_url)
                if file_results:
                    all_results.extend(file_results)
                    files_processed_count += 1
            except Exception as e:
                log.error(f"Failed to process file {file}: {e}", exc_info=True)
    log.info(
        f"Summarization complete for {repo_url}. "
        f"Processed {files_processed_count} files, found {len(all_results)} functions."
    )
    return all_results
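

# Minimal manual smoke test; a sketch, not part of the pipeline. It assumes a
# repository has already been cloned to ./repo; both the local path and the
# URL below are placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_results = summarize_repo(Path("repo"), "https://example.com/placeholder-repo")
    for item in demo_results[:3]:
        print(item["file_path"], "->", item["summary"])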