|
import heapq

import gradio as gr
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
|
|
|
def save_embeddings(sentences, filename):
    """Encode *sentences* and write the resulting tensor to *filename*.

    The module-level ``model`` performs the encoding; the result is
    requested as a tensor and serialized with ``torch.save``.

    Args:
        sentences: Input passed straight to ``model.encode``.
        filename: Destination passed straight to ``torch.save``.
    """
    torch.save(model.encode(sentences, convert_to_tensor=True), filename)
|
|
|
def load_embeddings(filename, map_location=None):
    """Load embeddings previously written with ``torch.save``.

    Args:
        filename: Path (or file-like object) handed to ``torch.load``.
        map_location: Optional device remapping forwarded to ``torch.load``.
            The default ``None`` keeps torch's own behaviour (restore to the
            device the data was saved from). Pass e.g. ``"cpu"`` to load a
            file that was saved on a GPU onto a machine without one.

    Returns:
        Whatever object was stored in *filename*.
    """
    # NOTE(security): torch.load unpickles the file; only load embedding
    # files that come from a trusted source.
    return torch.load(filename, map_location=map_location)
|
|
|
def preprocess_model_descriptions(file_path):
    """Read a headerless, tab-separated file of ``model_id<TAB>description``.

    The file is decoded by trying each candidate encoding in turn; the first
    one that decodes without error wins. Rows whose description is missing
    are dropped.

    Args:
        file_path: Location of the TSV file, passed to ``pandas.read_csv``.

    Returns:
        A ``(model_ids, descriptions)`` tuple of two equal-length lists.

    Raises:
        UnicodeDecodeError: If none of the candidate encodings can decode
            the file.
    """
    # NOTE: latin-1 maps every possible byte, so once it is tried it will
    # not raise a decode error; 'utf-16' is kept only as a last-resort entry.
    encodings = ('utf-8', 'latin-1', 'utf-16')
    last_error = None

    for encoding in encodings:
        try:
            # The encoding must be passed explicitly, otherwise every
            # attempt would use read_csv's default and the loop is pointless.
            df = pd.read_csv(
                file_path,
                sep='\t',
                header=None,
                names=['model_id', 'description'],
                encoding=encoding,
            )
        except UnicodeDecodeError as err:
            last_error = err
            continue
        break
    else:
        # UnicodeDecodeError requires (encoding, object, start, end, reason);
        # constructing it from a lone message would raise TypeError instead.
        raise UnicodeDecodeError(
            encodings[-1],
            b"",
            0,
            1,
            "Unable to decode the file using the available encodings.",
        ) from last_error

    df = df.dropna(subset=['description'])
    return df['model_id'].tolist(), df['description'].tolist()
|
|
|
def perform_similarity_search(query_embeddings, embeddings, model_ids, descriptions, top_k=10):
    """Return the best-matching corpus entries for each query.

    Cosine similarity between the queries and the corpus is computed with
    ``util.cos_sim``. For every row of the resulting score matrix the
    ``top_k`` highest-scoring entries are kept, best first, and the
    per-query results are concatenated in query order.

    Args:
        query_embeddings: Query embedding(s) handed to ``util.cos_sim``.
        embeddings: Corpus embeddings handed to ``util.cos_sim``.
        model_ids: Identifiers paired by position with the corpus scores.
        descriptions: Descriptions paired by position with the corpus scores.
        top_k: Maximum number of results kept per query.

    Returns:
        A ``pandas.DataFrame`` with columns ``model_id``, ``description``
        and ``score``.
    """
    score_rows = util.cos_sim(query_embeddings, embeddings).tolist()

    results = []
    # Walk the score matrix itself instead of enumerating query_embeddings:
    # iterating a single 1-D query vector would yield one item per vector
    # component and index past the available score rows.
    for scores in score_rows:
        # zip() stops at the shortest input, so entries without a matching
        # id/description/score are ignored. nlargest() yields the same
        # ordering as sorted(..., reverse=True)[:top_k] without a full sort.
        results.extend(
            heapq.nlargest(
                top_k,
                zip(model_ids, descriptions, scores),
                key=lambda row: row[2],
            )
        )

    return pd.DataFrame(results, columns=["model_id", "description", "score"])
|
|
|
# Sentence-embedding model shared by save_embeddings() and the search handler.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Corpus metadata and embeddings loaded from disk.
# NOTE(review): the embeddings file is presumably built from the same
# descriptions, in the same order, after rows with missing descriptions were
# dropped - confirm, since results are paired with scores purely by position.
model_ids, descriptions = preprocess_model_descriptions('model_descriptions.tsv')
embeddings = load_embeddings('embeddings_model_descriptions.pt')

with gr.Blocks() as demo:
    # Named query_box rather than `input` so the builtin is not shadowed.
    query_box = gr.Textbox(label="Enter your query")
    # The text shown on a button is its value, not a `label` keyword.
    button = gr.Button("Search")

    df_output = gr.Dataframe(label="Similarity Results", wrap=True)

    def search(query):
        """Embed *query* and return the ranked similarity results table."""
        query_embedding = model.encode([query], convert_to_tensor=True)
        return perform_similarity_search(query_embedding, embeddings, model_ids, descriptions)

    # Run the same search on Enter in the textbox and on a button click.
    query_box.submit(search, inputs=query_box, outputs=df_output)
    button.click(search, inputs=query_box, outputs=df_output)

demo.launch()
|
|