Spaces:
Runtime error
Runtime error
import os | |
from typing import Dict, List | |
import ast | |
import pandas as pd | |
import sentence_transformers | |
import streamlit as st | |
from findkit import feature_extractors, indexes, retrieval_pipeline | |
from toolz import partial | |
import config | |
def truncate_description(description, length=50):
    """Return the first ``length`` whitespace-separated words of ``description``.

    Runs of whitespace (spaces, tabs, newlines) collapse to single spaces
    because the kept words are re-joined with " ".

    Missing values -- ``None`` or a float NaN, as pandas typically produces
    for empty cells in text columns -- yield an empty string. Previously they
    raised ``AttributeError`` on ``.split``, which aborted the whole search.
    """
    import math

    if description is None or (
        isinstance(description, float) and math.isnan(description)
    ):
        return ""
    return " ".join(description.split()[:length])
def get_repos_with_descriptions(repos_df, repos):
    """Select rows of ``repos_df`` by index label using ``.loc``.

    When ``repos`` is a list of labels, the returned rows follow the order
    of ``repos``.
    """
    selected_rows = repos_df.loc[repos]
    return selected_rows
def search_f(
    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
    query: str,
    k: int,
    description_length: int,
    doc_col: List[str],
):
    """Query ``retrieval_pipe`` and format the top ``k`` hits for display.

    Adds a GitHub ``link`` column built from ``repo``. Each column named in
    ``doc_col`` is shortened to ``description_length`` words. Returns a
    frame with a fresh 0..n-1 index holding ``repo``, ``tasks``, ``link``,
    ``distance`` and then the ``doc_col`` columns, in that order.
    """
    hits = retrieval_pipe.find_similar(query, k)
    hits["link"] = "https://github.com/" + hits["repo"]

    def _shorten(text):
        return truncate_description(text, description_length)

    for column_name in doc_col:
        hits[column_name] = hits[column_name].apply(_shorten)

    visible_columns = ["repo", "tasks", "link", "distance"] + doc_col
    return hits.reset_index(drop=True)[visible_columns]
def merge_text_list_cols(retrieval_df, text_list_cols):
    """Return a copy of ``retrieval_df`` with list-literal columns flattened.

    Each cell in the named columns is parsed with ``ast.literal_eval``, which
    evaluates Python literals only, not arbitrary code. The resulting
    sequence of strings is joined with single spaces. The input frame is
    left unmodified.
    """
    merged = retrieval_df.copy()

    def _join_tokens(serialized):
        return " ".join(ast.literal_eval(serialized))

    for column_name in text_list_cols:
        merged[column_name] = merged[column_name].apply(_join_tokens)
    return merged
def setup_pipeline(
    extractor: feature_extractors.SentenceEncoderFeatureExtractor,
    documents_df: pd.DataFrame,
    text_col: str,
):
    """Build a retrieval pipeline over one text column of ``documents_df``.

    ``extractor`` encodes both documents and queries. The whole
    ``documents_df`` is attached as metadata. The index uses NMSLIB with
    cosine similarity, the same configuration as ``setup_retrieval_pipeline``.

    Returns:
        The pipeline produced by ``RetrievalPipelineFactory.build``.
    """
    # The previous version had three defects:
    # - it called ``build`` on the factory class, passing the text column
    #   as ``self``;
    # - it ignored ``extractor``;
    # - it discarded the result.
    factory = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=extractor,
        query_feature_extractor=extractor,
        index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
    )
    return factory.build(documents_df[text_col], metadata=documents_df)
def setup_retrieval_pipeline(
    query_encoder_path, document_encoder_path, documents, metadata
):
    """Build a retrieval pipeline with separate query and document encoders.

    Both encoders are SentenceTransformer models loaded on the CPU and
    wrapped in findkit's ``SentenceEncoderFeatureExtractor``. Documents are
    indexed with NMSLIB using cosine similarity, and ``metadata`` is attached
    to the built pipeline.
    """

    def _cpu_encoder(model_path):
        # Load the model on CPU and wrap it as a findkit feature extractor.
        model = sentence_transformers.SentenceTransformer(model_path, device="cpu")
        return feature_extractors.SentenceEncoderFeatureExtractor(model)

    # The document encoder is loaded first, matching the previous load order.
    document_encoder = _cpu_encoder(document_encoder_path)
    query_encoder = _cpu_encoder(query_encoder_path)

    factory = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=document_encoder,
        query_feature_extractor=query_encoder,
        index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
    )
    return factory.build(documents, metadata=metadata)