import gradio as gr
import requests
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


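# Maximum number of tokens in a serialized model input, including <s> and </s>.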
MAX_SOURCE_LENGTH = 512


def prepare_models():
    # Load the CodeReviewer tokenizer and attach the ids of the special tokens
    # that encode_diff() uses to serialize a diff for the model.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/codereviewer")

    vocab = tokenizer.get_vocab()
    tokenizer.special_dict = {f"<e{i}>": vocab[f"<e{i}>"] for i in range(99, -1, -1)}
    tokenizer.mask_id = vocab["<mask>"]
    tokenizer.bos_id = vocab["<s>"]
    tokenizer.pad_id = vocab["<pad>"]
    tokenizer.eos_id = vocab["</s>"]
    tokenizer.msg_id = vocab["<msg>"]
    tokenizer.keep_id = vocab["<keep>"]
    tokenizer.add_id = vocab["<add>"]
    tokenizer.del_id = vocab["<del>"]
    tokenizer.start_id = vocab["<start>"]
    tokenizer.end_id = vocab["<end>"]

    model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/codereviewer")

    model.eval()
    return tokenizer, model


def pad_assert(tokenizer, source_ids):
    # Truncate to leave room for <s> and </s>, wrap, then pad to a fixed length.
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids
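
# For illustration: with MAX_SOURCE_LENGTH = 8 and source_ids = [5, 6, 7],
# pad_assert() returns [bos_id, 5, 6, 7, eos_id, pad_id, pad_id, pad_id].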


def encode_diff(tokenizer, diff, msg, source):
    # Drop blank lines and hunk headers ("@@ -a,b +c,d @@"); keep content lines only.
    difflines = [line for line in diff.split("\n")
                 if len(line.strip()) > 0 and not line.startswith("@@")]
    map_dic = {"-": 0, "+": 1, " ": 2}

    def f(s):
        if s in map_dic:
            return map_dic[s]
        else:
            return 2

    labels = [f(line[0]) for line in difflines]
    difflines = [line[1:].strip() for line in difflines]
    inputstr = "<s>" + source + "</s>"
    inputstr += "<msg>" + msg
    for label, line in zip(labels, difflines):
        if label == 1:
            inputstr += "<add>" + line
        elif label == 0:
            inputstr += "<del>" + line
        else:
            inputstr += "<keep>" + line
    source_ids = tokenizer.encode(inputstr, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    source_ids = pad_assert(tokenizer, source_ids)
    return source_ids
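
# For illustration, a hunk such as:
#     @@ -1 +1 @@
#     -print("hello")
#     +print("hello, world")
# is serialized by encode_diff() (before tokenization) roughly as
#     <s>{source}</s><msg>{msg}<del>print("hello")<add>print("hello, world")
# and then truncated and padded to MAX_SOURCE_LENGTH by pad_assert().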


class FileDiffs(object):
    def __init__(self, diff_string):
        diff_array = diff_string.split("\n")
        self.file_name = diff_array[0]
        # The header line looks like " a/path/to/file b/path/to/file".
        self.file_path = self.file_name.split("a/", 1)[1].rsplit("b/", 1)[0].strip()
        self.diffs = list()
        for line in diff_array[4:]:
            if line.startswith("@@"):
                self.diffs.append(str())
            if self.diffs:  # skip any remaining header lines before the first hunk
                self.diffs[-1] += "\n" + line
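
# For illustration, the chunk for "diff --git a/app.py b/app.py" yields
# file_path == "app.py" and one entry in .diffs per "@@ ... @@" hunk.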


def review_commit(user="p4vv37", repository="ueflow", commit="610a8c7b02b946bc9e5e26e6dacbba0e2abba259"):
    tokenizer, model = prepare_models()

    # Get the commit metadata and the unified diff from the GitHub API
    commit_metadata = requests.get(f"https://api.github.com/repos/{user}/{repository}/commits/{commit}").json()
    msg = commit_metadata["commit"]["message"]
    diff_data = requests.get(f"https://api.github.com/repos/{user}/{repository}/commits/{commit}",
                             headers={"Accept": "application/vnd.github.diff"})
    code_diff = diff_data.text

    # Parse diff into FileDiffs objects
    files_diffs = list()
    for file in code_diff.split("diff --git"):
        if len(file) > 0:
            fd = FileDiffs(file)
            files_diffs.append(fd)

    # Generate comments for each diff
    output = ""
    for fd in files_diffs:
        output += f"File: {fd.file_path}\n"
        # Fetch the file as it was before the commit. raw.githubusercontent.com
        # cannot resolve git revision expressions such as "commit^", so look up
        # the parent commit's SHA in the metadata instead.
        parent_sha = commit_metadata["parents"][0]["sha"]
        source = requests.get(f"https://raw.githubusercontent.com/{user}/{repository}/{parent_sha}/{fd.file_path}").text

        for diff in fd.diffs:
            inputs = torch.tensor([encode_diff(tokenizer, diff, msg, source)], dtype=torch.long).to("cpu")
            inputs_mask = inputs.ne(tokenizer.pad_id)
            preds = model.generate(inputs,
                                   attention_mask=inputs_mask,
                                   use_cache=True,
                                   num_beams=5,
                                   early_stopping=True,
                                   max_length=100,
                                   num_return_sequences=2
                                   )
            preds = list(preds.cpu().numpy())
            pred_nls = [tokenizer.decode(_id[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
                        for _id in preds]
            output += diff + "\n#######\nComment:\n#######\n" + pred_nls[0] + "\n#######\n"
    return output
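
# Example usage (requires network access):
#     print(review_commit("p4vv37", "ueflow", "610a8c7b02b946bc9e5e26e6dacbba0e2abba259"))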


description = "An interface for running Microsoft's CodeReviewer model (microsoft/codereviewer), " \
              "from \"CodeReviewer: Pre-Training for Automating Code Review Activities,\" " \
              "on GitHub commits."
examples = [
    ["p4vv37", "ueflow", "610a8c7b02b946bc9e5e26e6dacbba0e2abba259"],
    ["microsoft", "vscode", "378b0d711f6b82ac59b47fb246906043a6fb995a"],
]
iface = gr.Interface(fn=review_commit,
                     description=description,
                     inputs=["text", "text", "text"],
                     outputs="text",
                     examples=examples)
iface.launch()