Spaces:
Runtime error
Runtime error
File size: 2,326 Bytes
568499b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
from typing import Dict, List
import ast
import pandas as pd
import sentence_transformers
import streamlit as st
from findkit import feature_extractors, indexes, retrieval_pipeline
from toolz import partial
import config
def truncate_description(description, length=50):
    """Return at most the first ``length`` whitespace-separated words.

    The text is split on runs of whitespace, so leading/trailing whitespace
    is dropped and interior runs collapse to single spaces in the result.
    """
    words = description.split()
    kept_words = words[:length]
    return " ".join(kept_words)
def get_repos_with_descriptions(repos_df, repos):
    """Select the rows of ``repos_df`` whose index labels are in ``repos``.

    This is a label-based ``.loc`` lookup, so rows come back in the order
    the labels are given in ``repos``.
    """
    selected_rows = repos_df.loc[repos]
    return selected_rows
def search_f(
    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
    query: str,
    k: int,
    description_length: int,
    doc_col: List[str],
):
    """Run a similarity search and shape the hits into a display table.

    Calls ``retrieval_pipe.find_similar(query, k)``, adds a ``link`` column
    with the GitHub URL built from the ``repo`` column, and shortens every
    column named in ``doc_col`` to ``description_length`` words.

    The result frame must provide ``repo``, ``tasks`` and ``distance``
    columns in addition to the ones listed in ``doc_col``.

    Returns the frame with a fresh 0..n-1 index, restricted to
    ``repo``, ``tasks``, ``link``, ``distance`` followed by ``doc_col``.
    """

    def shorten(text):
        return truncate_description(text, description_length)

    hits = retrieval_pipe.find_similar(query, k)
    # The repository name is read from the "repo" column, not from the index.
    hits["link"] = "https://github.com/" + hits["repo"]
    for column in doc_col:
        hits[column] = hits[column].apply(shorten)

    display_columns = ["repo", "tasks", "link", "distance"] + doc_col
    return hits.reset_index(drop=True)[display_columns]
def merge_text_list_cols(retrieval_df, text_list_cols):
    """Turn stringified list columns into single space-joined strings.

    Each cell in the columns named by ``text_list_cols`` is parsed with
    ``ast.literal_eval`` (e.g. ``"['a', 'b']"``) and its items are joined
    with spaces. The work is done on a copy; ``retrieval_df`` itself is
    left unmodified.
    """

    def join_literal(cell):
        return " ".join(ast.literal_eval(cell))

    merged = retrieval_df.copy()
    for column in text_list_cols:
        merged[column] = merged[column].apply(join_literal)
    return merged
def setup_pipeline(
    extractor: feature_extractors.SentenceEncoderFeatureExtractor,
    documents_df: pd.DataFrame,
    text_col: str,
):
    """Build and return a retrieval pipeline over one text column.

    The previous version called ``RetrievalPipelineFactory.build`` directly
    on the class, never used ``extractor`` and discarded the result, so the
    function always returned ``None``. It now creates a factory from
    ``extractor`` and returns the pipeline built from
    ``documents_df[text_col]``, with the full frame passed as metadata.

    NOTE(review): the factory arguments mirror ``setup_retrieval_pipeline``
    in this file (same NMSLIB cosine index). ``extractor`` is used for both
    documents and queries because only one extractor is supplied here;
    confirm this matches the intended findkit usage.
    """
    factory = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=extractor,
        query_feature_extractor=extractor,
        index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
    )
    return factory.build(documents_df[text_col], metadata=documents_df)
@st.cache(allow_output_mutation=True)
def setup_retrieval_pipeline(
    query_encoder_path, document_encoder_path, documents, metadata
):
    """Load both sentence encoders on CPU and build the retrieval pipeline.

    Separate SentenceTransformer models are loaded for documents and for
    queries, each wrapped in a ``SentenceEncoderFeatureExtractor``. The
    pipeline is built over ``documents`` with ``metadata`` attached, using
    an ``NMSLIBIndex`` created with ``distance="cosinesimil"``.

    The result is wrapped in Streamlit's ``st.cache``.
    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource`` / ``st.cache_data``; check the pinned
    Streamlit version before upgrading.
    """
    document_model = sentence_transformers.SentenceTransformer(
        document_encoder_path, device="cpu"
    )
    document_extractor = feature_extractors.SentenceEncoderFeatureExtractor(
        document_model
    )

    query_model = sentence_transformers.SentenceTransformer(
        query_encoder_path, device="cpu"
    )
    query_extractor = feature_extractors.SentenceEncoderFeatureExtractor(
        query_model
    )

    cosine_index_builder = partial(
        indexes.NMSLIBIndex.build, distance="cosinesimil"
    )
    factory = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=document_extractor,
        query_feature_extractor=query_extractor,
        index_factory=cosine_index_builder,
    )
    return factory.build(documents, metadata=metadata)
|