import ast
from functools import partial
from typing import List

import pandas as pd
import sentence_transformers
import streamlit as st
from findkit import feature_extractors, indexes, retrieval_pipeline

import config


def truncate_description(description, length=50):
    """Return the first `length` whitespace-separated words of `description`."""
    return " ".join(description.split()[:length])


def get_repos_with_descriptions(repos_df, repos):
    """Select the rows of `repos_df` whose index labels are given by `repos`."""
    return repos_df.loc[repos]


def search_f(
    retrieval_pipe: retrieval_pipeline.RetrievalPipeline,
    query: str,
    k: int,
    description_length: int,
    doc_col: List[str],
):
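    """Find the `k` documents most similar to `query` and prepare them for display.

    Adds a GitHub link per result and truncates every column in `doc_col`
    to `description_length` words.
    """
    results = retrieval_pipe.find_similar(query, k)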
    # results['repo'] = results.index
    results["link"] = "https://github.com/" + results["repo"]
    for col in doc_col:
        results[col] = results[col].apply(
            lambda desc: truncate_description(desc, description_length)
        )
    shown_cols = ["repo", "tasks", "link", "distance"]
    shown_cols = shown_cols + doc_col
    return results.reset_index(drop=True)[shown_cols]


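def merge_text_list_cols(retrieval_df, text_list_cols):
    """Join list-valued columns into space-separated strings.

    Each column in `text_list_cols` is expected to hold the string form of a
    Python list (for example "['a', 'b']"). The input frame is not modified.
    """
    retrieval_df = retrieval_df.copy()
    for col in text_list_cols:
        retrieval_df[col] = retrieval_df[col].apply(
            lambda t: " ".join(ast.literal_eval(t))
        )
    return retrieval_df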


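def setup_pipeline(
    extractor: feature_extractors.SentenceEncoderFeatureExtractor,
    documents_df: pd.DataFrame,
    text_col: str,
):
    # Assumption: the same extractor encodes both documents and queries, and
    # the index settings match those used in setup_retrieval_pipeline below.
    factory = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=extractor,
        query_feature_extractor=extractor,
        index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
    )
    # Return the built pipeline; the original built it and discarded the result.
    return factory.build(documents_df[text_col], metadata=documents_df)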


@st.cache(allow_output_mutation=True)
def setup_retrieval_pipeline(
    query_encoder_path, document_encoder_path, documents, metadata
):
    document_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
        sentence_transformers.SentenceTransformer(document_encoder_path, device="cpu")
    )
    query_encoder = feature_extractors.SentenceEncoderFeatureExtractor(
        sentence_transformers.SentenceTransformer(query_encoder_path, device="cpu")
    )
    retrieval_pipe = retrieval_pipeline.RetrievalPipelineFactory(
        feature_extractor=document_encoder,
        query_feature_extractor=query_encoder,
        index_factory=partial(indexes.NMSLIBIndex.build, distance="cosinesimil"),
    )
    return retrieval_pipe.build(documents, metadata=metadata)