# Snapshot metadata: commit 4f08d2c, file size 2,446 bytes.
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Union
import os
def delete_file(file_pt: Path) -> None:
    """Remove ``file_pt`` from disk; a file that is already gone is not an error.

    Only a missing file is tolerated. Any other OS error (e.g. permission
    denied, or ``file_pt`` being a directory) propagates to the caller.
    """
    file_pt.unlink(missing_ok=True)
def full_path(inp_dir_or_path: Union[str, Path]) -> Path:
    """Return ``inp_dir_or_path`` as an absolute, resolved ``Path``.

    A leading ``~`` is expanded to the user's home directory, then the path
    is made absolute and resolved (``..`` components and symlinks are
    resolved; the path does not need to exist).

    Args:
        inp_dir_or_path: file or directory path, as a ``str`` or ``Path``.

    Returns:
        The fully resolved ``Path``.
    """
    return Path(inp_dir_or_path).expanduser().resolve()
def mkdir_p(inp_dir_or_path: Union[str, Path]) -> Path:
    """Ensure the directories needed for a file or directory path exist.

    The path is first normalized with ``full_path``. A path that has a
    suffix (e.g. ``out/report.csv``) is treated as a file, so only its
    parent directories are created. A path without a suffix is treated as
    a directory and is created itself. Note that this is a name-based
    heuristic: a directory called e.g. ``cache.d`` is treated as a file.

    Returns:
        The resolved input path (not the directory that was created).
    """
    resolved = full_path(inp_dir_or_path)
    # Suffix present -> looks like a file, so only its parent must exist.
    target_dir = resolved.parent if resolved.suffix else resolved
    target_dir.mkdir(parents=True, exist_ok=True)
    return resolved
def similarity_between_sent(sent1_encoded, sent2_encoded):
    """Report the average cosine similarity between paired sentence encodings.

    Args:
        sent1_encoded: sequence of sentence encodings (vectors).
        sent2_encoded: sequence of sentence encodings, paired index-by-index
            with ``sent1_encoded``; must have the same length.

    Returns:
        A tuple ``(mean_score, scores)``, where ``scores`` is a list with the
        cosine similarity of each pair and ``mean_score`` is
        ``np.mean(scores)``. For empty inputs, ``np.mean`` of an empty list
        yields nan (numpy also emits a RuntimeWarning).

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    if len(sent1_encoded) != len(sent2_encoded):
        # Previously an index-based loop raised a bare IndexError when the
        # second sequence was shorter, and silently ignored surplus items
        # when it was longer.
        raise ValueError(
            "sent1_encoded and sent2_encoded must have the same length, "
            f"got {len(sent1_encoded)} and {len(sent2_encoded)}"
        )
    similarity_scores = [
        cosine_similarity(vec1, vec2)
        for vec1, vec2 in zip(sent1_encoded, sent2_encoded)
    ]
    return np.mean(similarity_scores), similarity_scores
def cosine_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    Uses the dot-product definition ``dot(a, b) / (||a|| * ||b||)``, with
    Euclidean norms from ``np.linalg.norm``. No guard is applied for
    zero-length vectors, so the denominator is zero in that case.
    """
    magnitude_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / magnitude_product
def load_data(path):
    """Load a delimited text file into a pandas DataFrame.

    Files whose name ends in ``.csv`` (case-insensitive) are read as
    comma-separated. Any other file (e.g. ``.tsv``) is read as
    tab-separated.

    Args:
        path: file path, as a ``str`` or path-like object such as
            ``pathlib.Path``.

    Returns:
        pandas.DataFrame with the file contents.
    """
    # os.fspath lets Path objects through; the old str-only `.endswith`
    # call raised AttributeError for them.
    path = os.fspath(path)
    # Case-insensitive so "data.CSV" is not misread as tab-separated.
    sep = "," if path.lower().endswith(".csv") else "\t"
    # Without chunksize/iterator, read_csv always returns a DataFrame, so
    # the former isinstance check could never fire and has been dropped.
    return pd.read_csv(path, sep=sep)
def read_data(dataset, data_dir="/home/yash/EMNLP-2024/data"):
    """Load one of the supported paraphrase datasets as a DataFrame.

    Args:
        dataset: one of ``"mrpc"``, ``"qqp"``, or ``"paws"``/``"paw"``/``"wiki"``.
        data_dir: directory containing the dataset files. It defaults to the
            previously hard-coded location, so existing callers are unaffected.

    Returns:
        pandas.DataFrame loaded via ``load_data``. For ``"qqp"``, rows with
        any missing value are dropped, column names are whitespace-stripped,
        and ``is_duplicate``/``question1``/``question2`` are renamed to
        ``label``/``sentence1``/``sentence2``.

    Raises:
        ValueError: if ``dataset`` is not one of the names above.
    """
    if dataset == "mrpc":
        data = load_data(os.path.join(data_dir, "mrpc.csv")).copy()
    elif dataset == "qqp":
        # NOTE(review): file name is spelled "qoura.csv" in the original
        # path; kept as-is, presumably matching the file on disk.
        data = load_data(os.path.join(data_dir, "qoura.csv")).copy().dropna()
        # handling irregularities in column names
        data.columns = data.columns.str.strip()
        data = data.rename(
            columns={
                "is_duplicate": "label",
                "question1": "sentence1",
                "question2": "sentence2",
            }
        )
    elif dataset in ("paws", "paw", "wiki"):
        data = load_data(os.path.join(data_dir, "paw_wiki.tsv")).copy()
    else:
        # Previously the exception was constructed but never raised, so the
        # function fell through to an UnboundLocalError on `return data`.
        raise ValueError("No dataset found.")
    return data