import gradio as gr
import torch
from transformers import RobertaTokenizer, T5ForConditionalGeneration
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')  # required by nltk.word_tokenize on newer NLTK releases; harmless on older ones
# Load model and tokenizer
model_dir = "./codet5-base-multi-sum"
tokenizer = RobertaTokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def generate_comment(code_snippet, reference_comment):
    # Add the prefix CodeT5 expects for the summarization task
    prefixed_code = "summarize: " + code_snippet.strip()
    input_ids = tokenizer(
        prefixed_code, return_tensors="pt", truncation=True, max_length=512
    ).input_ids.to(device)

    # Generate the comment with beam search
    with torch.no_grad():
        generated_ids = model.generate(input_ids, max_length=64, num_beams=4, early_stopping=True)
    comment = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Tokenize and compute BLEU against the user-provided reference, if any
    if reference_comment.strip():
        ref_tokens = nltk.word_tokenize(reference_comment.lower())
        hyp_tokens = nltk.word_tokenize(comment.lower())
        bleu = sentence_bleu([ref_tokens], hyp_tokens, smoothing_function=SmoothingFunction().method1)
        bleu = round(bleu, 2)
    else:
        bleu = "N/A (No reference provided)"

    return comment, bleu
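
# Quick local sanity check (hypothetical snippet, not part of the app; the exact
# comment and BLEU value depend on the checkpoint in ./codet5-base-multi-sum):
#   comment, bleu = generate_comment("def add(a, b):\n    return a + b", "Add two numbers.")
#   print(comment, bleu)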
# Gradio UI
iface = gr.Interface(
    fn=generate_comment,
    inputs=[
        gr.Textbox(label="Enter Code Snippet", lines=4, placeholder="Paste your code here..."),
        gr.Textbox(label="Reference Comment (optional)", placeholder="Expected comment to compare against for the BLEU score"),
    ],
    outputs=[
        gr.Textbox(label="Generated Comment"),
        gr.Textbox(label="BLEU Score"),
    ],
    title="Code Comment Generator using CodeT5",
    description="Paste code to get a generated comment; provide a reference comment to also get a BLEU score.",
)

iface.launch()