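# Gradio demo: suggest a new name for a Java variable that has been replaced by
# "[MASK]", using a GraphCodeBERT masked-language model fine-tuned for identifier
# renaming (checkpoint 'model_26_2').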
import random

import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer


def greet(X, ny):
    """Suggest a name for the variable masked out with [MASK] in a Java snippet.

    X  -- Java source in which one variable has been replaced by "[MASK]".
    ny -- number of sub-tokens the predicted name should have; 0 means the
          length is sampled from the empirical distribution below.
    """
    # Treat an empty box as 0 so a missing token count does not crash the app.
    ny = int(ny) if str(ny).strip() else 0
    if ny == 0:
        # Cumulative probabilities of name lengths (in sub-tokens); sample a length.
        rand_no = random.random()
        tok_map = {2: 0.4363429005892416,
                   1: 0.6672580202327398,
                   4: 0.7476060740459144,
                   3: 0.9618703668504087,
                   6: 0.9701028532809564,
                   7: 0.9729244545819342,
                   8: 0.9739508754144756,
                   5: 0.9994508859743607,
                   9: 0.9997507867114407,
                   10: 0.9999112969650892,
                   11: 0.9999788802297832,
                   0: 0.9999831041838266,
                   12: 0.9999873281378701,
                   22: 0.9999957760459568,
                   14: 1.0000000000000002}
        for key in tok_map:
            if rand_no < tok_map[key]:
                num_sub_tokens_label = key
                break
    else:
        num_sub_tokens_label = ny
    # The tokenizer and base model are reloaded on every request; the fine-tuned
    # weights in 'model_26_2' are expected to sit next to this script.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
    model = AutoModelForMaskedLM.from_pretrained("microsoft/graphcodebert-base")
    # map_location keeps the load working on CPU-only hosts.
    model.load_state_dict(torch.load('model_26_2', map_location=torch.device('cpu')))
    model.eval()
    # Expand each [MASK] placeholder into the requested number of mask tokens.
    X_init = X.replace("[MASK]", " [MASK] ")
    X_init = X_init.replace("[MASK]", " ".join([tokenizer.mask_token] * num_sub_tokens_label))
    tokens = tokenizer.encode_plus(X_init, add_special_tokens=False, return_tensors='pt')
    # Split long inputs into 510-token chunks so each fits in a 512-token window.
    input_id_chunks = list(tokens['input_ids'][0].split(510))
    mask_chunks = list(tokens['attention_mask'][0].split(510))
    xi = torch.full((1,), fill_value=101)
    yi = torch.full((1,), fill_value=1)
    zi = torch.full((1,), fill_value=102)
    for r in range(len(input_id_chunks)):
        # Add start/end special-token ids (hard-coded here) around each chunk
        # and extend the attention mask accordingly.
        input_id_chunks[r] = torch.cat([xi, input_id_chunks[r]], dim=-1)
        input_id_chunks[r] = torch.cat([input_id_chunks[r], zi], dim=-1)
        mask_chunks[r] = torch.cat([yi, mask_chunks[r]], dim=-1)
        mask_chunks[r] = torch.cat([mask_chunks[r], yi], dim=-1)
    di = torch.full((1,), fill_value=0)
    for i in range(len(input_id_chunks)):
        # Pad every chunk (and its attention mask) out to 512 positions.
        pad_len = 512 - input_id_chunks[i].shape[0]
        if pad_len > 0:
            for p in range(pad_len):
                input_id_chunks[i] = torch.cat([input_id_chunks[i], di], dim=-1)
                mask_chunks[i] = torch.cat([mask_chunks[i], di], dim=-1)
    # Record, per chunk, the position of the first mask token of each masked name.
    maski = []
    for l in range(len(input_id_chunks)):
        masked_pos = []
        for i in range(len(input_id_chunks[l])):
            if input_id_chunks[l][i] == tokenizer.mask_token_id:
                if i != 0 and input_id_chunks[l][i - 1] == tokenizer.mask_token_id:
                    continue
                masked_pos.append(i)
        maski.append(masked_pos)
    input_ids = torch.stack(input_id_chunks)
    att_mask = torch.stack(mask_chunks)
    # Run the masked-language model; no gradients are needed at inference time.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=att_mask)
    # MLM logits of shape (num_chunks, 512, vocab_size); squeeze() drops the
    # chunk dimension when there is only one chunk.
    last_hidden_state = outputs[0].squeeze()
    # l_o_l_sa[t] collects, for sub-token position t, the logits of every
    # occurrence of the masked variable.
    l_o_l_sa = [[] for _ in range(num_sub_tokens_label)]
    sum_state = []
    if len(maski) == 1:
        masked_pos = maski[0]
        for k in masked_pos:
            for t in range(num_sub_tokens_label):
                l_o_l_sa[t].append(last_hidden_state[k + t])
    else:
        for p in range(len(maski)):
            masked_pos = maski[p]
            for k in masked_pos:
                for t in range(num_sub_tokens_label):
                    if (k + t) >= len(last_hidden_state[p]):
                        # The run of mask tokens spills over the chunk boundary;
                        # read the remainder from the next chunk.
                        l_o_l_sa[t].append(last_hidden_state[p + 1][k + t - len(last_hidden_state[p])])
                        continue
                    l_o_l_sa[t].append(last_hidden_state[p][k + t])
    # Average the logits over all occurrences, per sub-token position.
    for t in range(num_sub_tokens_label):
        sum_state.append(l_o_l_sa[t][0])
    for i in range(1, len(l_o_l_sa[0])):
        for t in range(num_sub_tokens_label):
            sum_state[t] = sum_state[t] + l_o_l_sa[t][i]
    yip = len(l_o_l_sa[0])

    er = ""
    for t in range(num_sub_tokens_label):
        sum_state[t] /= yip
        # Take the top-5 candidate tokens and keep the first purely alphabetic one.
        idx = torch.topk(sum_state[t], k=5, dim=0)[1]
        wor = [tokenizer.decode(i.item()).strip() for i in idx]
        for kl in wor:
            if kl and all(char.isalpha() for char in kl):
                er += kl
                break
    return er
title = "Rename a variable in a Java class"
description = """This demo uses a GraphCodeBERT model fine-tuned to suggest higher-quality variable
names for Java classes. Long classes are handled by chunking the input. Replace any variable name
with "[MASK]" to get a suggested new identifier.
"""
ex = ["""import java.io.*;
public class x {
    public static void main(String[] args) {
        String f = "file.txt";
        BufferedReader [MASK] = null;
        String l;
        try {
            [MASK] = new BufferedReader(new FileReader(f));
            while ((l = [MASK].readLine()) != null) {
                System.out.println(l);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                if ([MASK] != null) [MASK].close();
            } catch (IOException ex) {
                ex.printStackTrace();
            }
        }
    }
}""", """import java.net.*;
import java.io.*;

public class s {
    public static void main(String[] args) throws IOException {
        ServerSocket [MASK] = new ServerSocket(8000);
        try {
            Socket s = [MASK].accept();
            PrintWriter pw = new PrintWriter(s.getOutputStream(), true);
            BufferedReader br = new BufferedReader(new InputStreamReader(s.getInputStream()));
            String i;
            while ((i = br.readLine()) != null) {
                pw.println(i);
            }
        } finally {
            if ([MASK] != null) [MASK].close();
        }
    }
}""", """import java.io.*;
import java.util.*;

public class y {
    public static void main(String[] args) {
        String [MASK] = "data.csv";
        String l = "";
        String cvsSplitBy = ",";
        try (BufferedReader br = new BufferedReader(new FileReader([MASK]))) {
            while ((l = br.readLine()) != null) {
                String[] z = l.split(cvsSplitBy);
                System.out.println("Values [field-1= " + z[0] + " , field-2=" + z[1] + "]");
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}"""]

textbox = gr.Textbox(label="Type Java code snippet:",
                     placeholder="replace a variable with [MASK]",
                     lines=10)

gr.Interface(
    fn=greet,
    inputs=[
        textbox,
        gr.Textbox(label="Number of tokens in name:",
                   placeholder="0 for a randomly sampled number of tokens"),
    ],
    outputs="text",
    title=title,
    description=description,
    # Each example fills both inputs: the code snippet and the token count.
    examples=[[e, "0"] for e in ex],
).launch()
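# Assumed local setup (not spelled out in the original script): place the fine-tuned
# checkpoint 'model_26_2' next to this file and install gradio, torch and transformers;
# launch() then serves the demo on a local URL.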