# albertmartinez's picture
# Upgrade gradio
# 3ce1088
import time
import pandas as pd
import polars as pl
import torch
import logging
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from typing import Optional
logger = logging.getLogger(__name__)
def _load_sentences(path: str) -> Dataset:
    """Read a tab-separated file with a header row into a ``Dataset``.

    Malformed lines are skipped instead of aborting the whole read.
    """
    frame = pd.read_csv(path, on_bad_lines="skip", header=0, sep="\t")
    return Dataset.from_pandas(frame)


def _with_position(sentences: Dataset, index_name: str) -> pl.DataFrame:
    """Return *sentences* as a polars frame plus a UInt64 row-position column."""
    return (
        pl.DataFrame(sentences.to_pandas())
        .with_row_index(name=index_name)
        .with_columns(pl.col(index_name).cast(pl.UInt64))
    )


def sts(
    modelname: str,
    data1: str,
    data2: str,
    score: float,
    *,
    batch_size: int = 1024,
) -> Optional[pl.DataFrame]:
    """
    Calculate semantic textual similarity between two sets of sentences.

    Every sentence of the first file is compared with every sentence of the
    second file; only pairs whose rounded similarity is strictly greater than
    ``score`` are returned, best match first.

    Args:
        modelname: Name of the model to use
        data1: Path to first input file (tab-separated, header row, "text" column)
        data2: Path to second input file (same layout as ``data1``)
        score: Minimum similarity score threshold (exclusive)
        batch_size: Number of sentences encoded per batch

    Returns:
        Optional[pl.DataFrame]: One row per retained pair. The "text" columns
        of the two inputs appear as "sentences1" and "sentences2", next to a
        Float32 "score" rounded to 4 decimals. None if an input is empty,
        lacks a "text" column, or any error occurs.
    """
    try:
        started = time.perf_counter()

        # Validate the inputs first so that a bad file never triggers the
        # (potentially slow) model download / initialisation.
        sentences1 = _load_sentences(data1)
        sentences2 = _load_sentences(data2)
        if sentences1.num_rows == 0 or sentences2.num_rows == 0:
            logger.error("Empty input data found")
            return None
        if "text" not in sentences1.column_names or "text" not in sentences2.column_names:
            logger.error('Input data must contain a "text" column')
            return None

        # Initialize model
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        # NOTE(security): trust_remote_code=True lets the model repository run
        # arbitrary Python on this machine. Only pass model names you trust.
        model = SentenceTransformer(
            modelname,
            device=device,
            trust_remote_code=True,
        )

        # Generate embeddings
        logger.info("Generating embeddings for first set...")
        embeddings1 = model.encode(
            sentences1["text"],
            normalize_embeddings=True,
            batch_size=batch_size,
            show_progress_bar=True,
        )
        logger.info("Generating embeddings for second set...")
        embeddings2 = model.encode(
            sentences2["text"],
            normalize_embeddings=True,
            batch_size=batch_size,
            show_progress_bar=True,
        )

        # Calculate similarity matrix (rows: first set, columns: second set)
        logger.info("Calculating similarity matrix...")
        similarity_matrix = model.similarity(embeddings1, embeddings2)

        # Hand the matrix to polars directly. Columns are named "0", "1", ...
        # so that they can be cast back to integer positions after unpivoting.
        scores = similarity_matrix.detach().cpu().numpy()
        wide = pl.from_numpy(
            scores,
            schema=[str(position) for position in range(scores.shape[1])],
            orient="row",
        )

        # Transform matrix to long format and apply the threshold BEFORE
        # joining the text back in: the joins then only touch surviving pairs
        # instead of copying sentence text for all n x m combinations.
        pairs = (
            wide.with_row_index(name="row_index")
            .with_columns(pl.col("row_index").cast(pl.UInt64))
            .unpivot(
                index="row_index",
                variable_name="column_index",
                value_name="score",
            )
            .with_columns(
                pl.col("column_index").cast(pl.UInt64),
                pl.col("score").round(4).cast(pl.Float32),
            )
            .filter(pl.col("score") > score)
        )

        # Join with original text; the second "text" column collides with the
        # first and is therefore suffixed "_right" by polars.
        result_df = (
            pairs.join(_with_position(sentences1, "row_index"), on="row_index")
            .join(_with_position(sentences2, "column_index"), on="column_index")
            .rename({
                "text": "sentences1",
                "text_right": "sentences2",
            })
            .drop(["row_index", "column_index"])
            .sort(["score"], descending=True)
        )

        elapsed_time = time.perf_counter() - started
        elapsed_text = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
        logger.info(f"Execution time: {elapsed_text}")
        logger.info(f"Found {len(result_df)} pairs above score threshold {score}")
        return result_df
    except Exception as e:
        # Boundary handler: callers expect None on failure. Use
        # logger.exception so the traceback is not lost.
        logger.exception("Error in STS process: %s", e)
        return None