Paarth committed on
Commit
2fbd2b9
·
1 Parent(s): 323b84a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -9
app.py CHANGED
@@ -9,15 +9,18 @@ from pytorch_lightning.callbacks import ModelCheckpoint
9
  from pytorch_lightning.loggers import TensorBoardLogger
10
  from datasets.dataset_dict import DatasetDict
11
  from transformers import AdamW, T5ForConditionalGeneration, T5TokenizerFast
 
 
 
 
12
  import warnings
13
  warnings.simplefilter('ignore')
14
 
15
- from summarizer import SummarizerModel
16
- from transformers import AutoTokenizer
17
  MODEL_NAME = 'Salesforce/codet5-base-multi-sum'
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
19
  model = SummarizerModel(MODEL_NAME)
20
- model.load_state_dict(torch.load('codet5-base-1_epoch-val_loss-0.80.pth'))
 
21
 
22
  def summarize(text: str,
23
  tokenizer = tokenizer,
@@ -26,7 +29,7 @@ def summarize(text: str,
26
  Summarizes a given code in text format.
27
  Args:
28
  text: The code in string format that needs to be summarized.
29
- tokenizer: The tokenizer used in the trained T5 model.
30
  trained_model: A SummarizerModel fine-tuned instance of
31
  T5 model family.
32
  """
@@ -53,9 +56,20 @@ def summarize(text: str,
53
  for gen_id in generated_ids]
54
  return "".join(preds)
55
 
 
 
 
 
 
 
 
 
56
  outputs = gr.outputs.Textbox()
57
- iface = gr.Interface(fn=summarize,
58
- inputs=['text'],
59
- outputs=outputs,
60
- description="Demo for ForgeT5 | Input: A python code | Output: The code summarization")
61
- iface.launch(inline = False)
 
 
 
 
9
  from pytorch_lightning.loggers import TensorBoardLogger
10
  from datasets.dataset_dict import DatasetDict
11
  from transformers import AdamW, T5ForConditionalGeneration, T5TokenizerFast
12
+ from tqdm.auto import tqdm
13
+ from models.summarizer import SummarizerModel
14
+ from transformers import AutoTokenizer
15
+ from sentence_transformers import SentenceTransformer
16
  import warnings
17
  warnings.simplefilter('ignore')
18
 
 
 
19
  MODEL_NAME = 'Salesforce/codet5-base-multi-sum'
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
  model = SummarizerModel(MODEL_NAME)
22
+ model.load_state_dict(torch.load('/content/drive/MyDrive/PlageBERT/Models/codet5-base-1_epoch-val_loss-0.80.pth'))
23
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
24
 
25
  def summarize(text: str,
26
  tokenizer = tokenizer,
 
29
  Summarizes a given code in text format.
30
  Args:
31
  text: The code in string format that needs to be summarized.
32
+ tokenizer: The tokenizer used in the trained T5 model.
33
  trained_model: A SummarizerModel fine-tuned instance of
34
  T5 model family.
35
  """
 
56
  for gen_id in generated_ids]
57
  return "".join(preds)
58
 
59
+ def find_similarity_score(code_1, code_2, model = embedding_model):
60
+ summary_code_1 = summarize(text = code_1)
61
+ summary_code_2 = summarize(text = code_2)
62
+ embedding_1 = model.encode(summary_code_1)
63
+ embedding_2 = model.encode(summary_code_2)
64
+ score = np.dot(embedding_1, embedding_2)/(np.linalg.norm(embedding_1) * np.linalg.norm(embedding_2))
65
+ return summary_code_1, summary_code_2, round(score, 2)
66
+
67
  outputs = gr.outputs.Textbox()
68
+ iface = gr.Interface(fn=find_similarity_score,
69
+ inputs=[gr.Textbox(label = 'First Code snippet'),
70
+ gr.Textbox(label = 'Second Code snippet')],
71
+ outputs=[gr.Textbox(label = 'Summary of first Code snippet'),
72
+ gr.Textbox(label = 'Summary of second Code snippet'),
73
+ gr.Textbox(label = 'The similarity score')],
74
+ description='The similarity score')
75
+ iface.launch()