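"""Gradio demo for microsoft/codereviewer: fetches a GitHub commit, splits its
diff into per-file hunks, and generates a review comment for each hunk."""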
import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Inputs are truncated/padded to this many tokens (see pad_assert).
MAX_SOURCE_LENGTH = 512

def prepare_models():
    """Load the CodeReviewer tokenizer and model and cache the special-token ids."""
    tokenizer = AutoTokenizer.from_pretrained("microsoft/codereviewer")
    # Cache the <e0>...<e99> sentinel ids and the diff-markup ids on the tokenizer
    # so later encoding steps don't need repeated vocabulary lookups.
    tokenizer.special_dict = {
        f"<e{i}>": tokenizer.get_vocab()[f"<e{i}>"] for i in range(99, -1, -1)
    }
    tokenizer.mask_id = tokenizer.get_vocab()["<mask>"]
    tokenizer.bos_id = tokenizer.get_vocab()["<s>"]
    tokenizer.pad_id = tokenizer.get_vocab()["<pad>"]
    tokenizer.eos_id = tokenizer.get_vocab()["</s>"]
    tokenizer.msg_id = tokenizer.get_vocab()["<msg>"]
    tokenizer.keep_id = tokenizer.get_vocab()["<keep>"]
    tokenizer.add_id = tokenizer.get_vocab()["<add>"]
    tokenizer.del_id = tokenizer.get_vocab()["<del>"]
    tokenizer.start_id = tokenizer.get_vocab()["<start>"]
    tokenizer.end_id = tokenizer.get_vocab()["<end>"]
    model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/codereviewer")
    model.eval()
    return tokenizer, model

def pad_assert(tokenizer, source_ids):
    """Truncate to the model window, wrap in <s>...</s>, and right-pad with <pad>."""
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids

def encode_diff(tokenizer, diff, msg, source):
    """Serialize one diff hunk, the commit message, and the file source into input ids."""
    difflines = diff.split("\n")[1:]  # drop the @@ hunk header
    difflines = [line for line in difflines if len(line.strip()) > 0]
    # Label each line by its diff marker: deleted, added, or unchanged.
    map_dic = {"-": 0, "+": 1, " ": 2}
    labels = [map_dic.get(line[0], 2) for line in difflines]
    difflines = [line[1:].strip() for line in difflines]
    # Build the input layout CodeReviewer expects:
    # <s> source </s> <msg> message, then each diff line tagged <add>/<del>/<keep>.
    inputstr = "<s>" + source + "</s>"
    inputstr += "<msg>" + msg
    for label, line in zip(labels, difflines):
        if label == 1:
            inputstr += "<add>" + line
        elif label == 0:
            inputstr += "<del>" + line
        else:
            inputstr += "<keep>" + line
    # encode() adds <s> and </s> itself; strip them and let pad_assert re-add and pad.
    source_ids = tokenizer.encode(inputstr, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    source_ids = pad_assert(tokenizer, source_ids)
    return source_ids

class FileDiffs(object):
    """One file's section of a "diff --git" patch: its path and its @@ hunks."""
    def __init__(self, diff_string):
        diff_array = diff_string.split("\n")
        self.file_name = diff_array[0]
        self.file_path = self.file_name.split("a/", 1)[1].rsplit(" b/", 1)[0]
        self.diffs = list()
        # Skip the index/---/+++ header lines; a new hunk starts at each @@ line,
        # and the hunk string begins with that header so encode_diff can drop it.
        for line in diff_array[4:]:
            if line.startswith("@@"):
                self.diffs.append(line)
            elif self.diffs:
                self.diffs[-1] += "\n" + line

def review_commit(user="p4vv37", repository="ueflow", commit="610a8c7b02b946bc9e5e26e6dacbba0e2abba259"):
    """Fetch a commit from GitHub and generate a review comment for each diff hunk."""
    tokenizer, model = prepare_models()

    # Get diff and commit metadata from the GitHub API
    commit_metadata = requests.get(f"https://api.github.com/repos/{user}/{repository}/commits/{commit}").json()
    msg = commit_metadata["commit"]["message"]
    diff_data = requests.get(f"https://api.github.com/repos/{user}/{repository}/commits/{commit}",
                             headers={"Accept": "application/vnd.github.diff"})
    code_diff = diff_data.text

    # Parse the diff into per-file FileDiffs objects
    files_diffs = list()
    for file in code_diff.split("diff --git"):
        if len(file) > 0:
            fd = FileDiffs(file)
            files_diffs.append(fd)
    # Generate comments for each hunk of each changed file
    output = ""
    # Reviews are written against the pre-change code, so read each file from the
    # commit's first parent (raw.githubusercontent.com cannot resolve "commit^").
    parent_sha = commit_metadata["parents"][0]["sha"]
    for fd in files_diffs:
        output += f"File: {fd.file_path}\n"
        source = requests.get(f"https://raw.githubusercontent.com/{user}/{repository}/{parent_sha}/{fd.file_path}").text
        for diff in fd.diffs:
            inputs = torch.tensor([encode_diff(tokenizer, diff, msg, source)], dtype=torch.long).to("cpu")
            inputs_mask = inputs.ne(tokenizer.pad_id)
            preds = model.generate(inputs,
                                   attention_mask=inputs_mask,
                                   use_cache=True,
                                   num_beams=5,
                                   early_stopping=True,
                                   max_length=100,
                                   num_return_sequences=2)
            preds = list(preds.cpu().numpy())
            # Drop the first two generated ids (decoder start tokens) before decoding.
            pred_nls = [tokenizer.decode(_id[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
                        for _id in preds]
            output += diff + "\n#######\nComment:\n#######\n" + pred_nls[0] + "\n#######\n"
    return output
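
# Example invocation (the default example commit; requires network access):
#   print(review_commit("p4vv37", "ueflow", "610a8c7b02b946bc9e5e26e6dacbba0e2abba259"))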

description = "An interface for running " \
              "\"Microsoft CodeReviewer: Pre-Training for Automating Code Review Activities\" " \
              "(microsoft/codereviewer) on GitHub commits."
examples = [
    ["p4vv37", "ueflow", "610a8c7b02b946bc9e5e26e6dacbba0e2abba259"],
    ["microsoft", "vscode", "378b0d711f6b82ac59b47fb246906043a6fb995a"],
]

iface = gr.Interface(fn=review_commit,
                     description=description,
                     inputs=["text", "text", "text"],
                     outputs="text",
                     examples=examples)
iface.launch()