import logging
import time
from typing import Optional

import pandas as pd
import polars as pl
import torch
from datasets import Dataset
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


def sts(modelname: str, data1: str, data2: str, score: float) -> Optional[pl.DataFrame]:
    """
    Calculate semantic textual similarity between two sets of sentences.

    Args:
        modelname: Name of the SentenceTransformer model to use
        data1: Path to the first input file (tab-separated, with a "text" column)
        data2: Path to the second input file (tab-separated, with a "text" column)
        score: Minimum similarity score threshold; only pairs scoring above it are kept

    Returns:
        Optional[pl.DataFrame]: DataFrame with similarity results, or None if an error occurs
    """
    try:
        st = time.time()

        # Prefer the GPU when one is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        model = SentenceTransformer(
            modelname,
            device=device,
            trust_remote_code=True,
        )

        # Input files are tab-separated; malformed lines are skipped
        sentences1 = Dataset.from_pandas(pd.read_csv(data1, on_bad_lines="skip", header=0, sep="\t"))
        sentences2 = Dataset.from_pandas(pd.read_csv(data2, on_bad_lines="skip", header=0, sep="\t"))

        if sentences1.num_rows == 0 or sentences2.num_rows == 0:
            logger.error("Empty input data found")
            return None

logger.info("Generating embeddings for first set...") |
|
embeddings1 = model.encode( |
|
sentences1["text"], |
|
normalize_embeddings=True, |
|
batch_size=1024, |
|
show_progress_bar=True |
|
) |
|
|
|
logger.info("Generating embeddings for second set...") |
|
embeddings2 = model.encode( |
|
sentences2["text"], |
|
normalize_embeddings=True, |
|
batch_size=1024, |
|
show_progress_bar=True |
|
) |
|
|
|
|
|
logger.info("Calculating similarity matrix...") |
|
similarity_matrix = model.similarity(embeddings1, embeddings2) |
|
|
|
|
|
        # model.similarity returns a torch.Tensor; move it to the CPU and go through
        # NumPy/pandas so Polars ends up with string column names ("0", "1", ...)
        df_pd = pd.DataFrame(similarity_matrix.cpu().numpy())
        df_pd.columns = df_pd.columns.astype(str)
        df = pl.from_pandas(df_pd)

        # Reshape the matrix to long format: one (row_index, column_index, score) row per pair
        df_matrix_with_index = df.with_row_index(name="row_index").with_columns(
            pl.col("row_index").cast(pl.UInt64)
        )
        df_long = df_matrix_with_index.unpivot(
            index="row_index",
            variable_name="column_index",
            value_name="score",
        ).with_columns(pl.col("column_index").cast(pl.UInt64))

        # Index the original sentences so they can be joined back onto the scores
        df_sentences1 = pl.DataFrame(sentences1.to_pandas()).with_row_index(name="row_index").with_columns(
            pl.col("row_index").cast(pl.UInt64)
        )
        df_sentences2 = pl.DataFrame(sentences2.to_pandas()).with_row_index(name="column_index").with_columns(
            pl.col("column_index").cast(pl.UInt64)
        )

        # The second join duplicates the "text" column, which Polars suffixes as "text_right"
        df_long = (
            df_long
            .with_columns(pl.col("score").round(4).cast(pl.Float32))
            .join(df_sentences1, on="row_index")
            .join(df_sentences2, on="column_index")
        )

        df_long = df_long.rename({
            "text": "sentences1",
            "text_right": "sentences2",
        }).drop(["row_index", "column_index"])

        result_df = df_long.filter(pl.col("score") > score).sort("score", descending=True)

        elapsed_time = time.time() - st
        logger.info(f'Execution time: {time.strftime("%H:%M:%S", time.gmtime(elapsed_time))}')
        logger.info(f"Found {len(result_df)} pairs above score threshold {score}")

        return result_df

    except Exception as e:
        # logger.exception also records the traceback, not just the message
        logger.exception(f"Error in STS process: {e}")
        return None
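

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of how sts() might be wired up as a script entry point.
# The model name, file paths, and threshold below are hypothetical
# placeholders; swap in your own.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    result = sts(
        modelname="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical model choice
        data1="sentences_a.tsv",  # hypothetical path; tab-separated with a "text" column
        data2="sentences_b.tsv",  # hypothetical path; tab-separated with a "text" column
        score=0.8,  # keep only pairs scoring above 0.8
    )
    if result is not None:
        print(result.head(10))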