File size: 2,266 Bytes
cf23c39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import solara

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from huggingface_hub import snapshot_download
from umap import UMAP
from annoy import AnnoyIndex
from cluestar import plot_text

# Load the example news-title dataset and keep every non-missing title as a string.
news = pd.read_csv('https://raw.githubusercontent.com/alonsosilvaallende/fake-and-real-news-titles/main/example.csv')
# Drop missing titles with pandas' own NA handling instead of stringifying each
# value and comparing against 'nan'.
texts = news["title"].dropna().astype(str).tolist()

# NOTE(review): `sentences` is not referenced anywhere else in this file; it is
# kept only so the module namespace is unchanged — consider removing it.
sentences = ["This is an example sentence", "Each sentence is converted"]

# Download only the files matching these patterns (JSON files and the PyTorch
# weights) from the Hugging Face Hub; returns the local snapshot directory.
model_path = snapshot_download(
    repo_id="TaylorAI/gte-tiny", allow_patterns=["*.json", "pytorch_model.bin"]
)

embedder2 = SentenceTransformer(model_path)

# Maximum number of titles to embed, project and plot. Slicing (instead of
# indexing range(500)) avoids an IndexError when the dataset is smaller.
MAX_TEXTS = 500
# One encode() call over the whole list lets the library batch the work;
# list() keeps the original "list of 1-D vectors" shape of embeddings2.
embeddings2 = list(embedder2.encode(texts[:MAX_TEXTS]))

# Fit a UMAP reducer on the title embeddings to get 2-D plot coordinates.
# The fitted reducer is kept at module level so queries can later be
# projected into the same 2-D space.
reducer = UMAP()
X2 = reducer.fit_transform(embeddings2)

# Build an Annoy nearest-neighbour index over the same embeddings, using
# Annoy's 'angular' metric. Item ids are the positions in embeddings2, so
# they also index into `texts`.
f = len(embeddings2[0])  # embedding dimensionality
t = AnnoyIndex(f, 'angular')
for item_id, vector in enumerate(embeddings2):
    t.add_item(item_id, vector)
t.build(1000)  # number of trees

# Reactive text of the search box, pre-filled with an example query.
query = solara.reactive("What did Nancy Pelosi said about Obamacare?")


@solara.component
def Page():
    """Render the embedding explorer.

    Always shows a 2-D UMAP scatter of the embedded news titles. When the
    query box holds non-blank text, the query is embedded, its 10 nearest
    titles are looked up in the Annoy index `t`, and the chart highlights
    those neighbours plus the query's projected position. A table then lists
    the neighbours with the distances Annoy returned.
    """
    # Number of titles that were embedded and projected. Deriving it from X2
    # instead of hard-coding 500 keeps labels and colours aligned with the
    # plotted points.
    n_points = len(X2)
    plotted_texts = texts[:n_points]

    with solara.Column(margin=10):
        # Markdown needs a space after '#' to render a heading.
        solara.Markdown("# Embeddings")
        solara.InputText("Enter some query:", query, continuous_update=True)
        # A blank or whitespace-only query carries no meaning, so treat it
        # like an empty box.
        if query.value.strip():
            embedded_query = embedder2.encode(query.value)
            idx, distances = t.get_nns_by_vector(embedded_query, 10, include_distances=True)
            df_neighbors = pd.DataFrame({
                "neighbors": [texts[i] for i in idx],
                "distances": distances,
            })
            query_point = reducer.transform([embedded_query])

            # Set membership keeps the colour lookup O(1) per point.
            neighbor_ids = set(idx)
            color_array = [
                "neighbors" if i in neighbor_ids else "texts" for i in range(n_points)
            ] + ["query"]
            chart = plot_text(
                np.vstack((X2, query_point)),
                plotted_texts + [query.value],
                color_array=color_array,
            ).configure_range(category=['#0000ff', '#ff0000', '#a0aab4'])
            solara.AltairChart(chart)
            solara.DataFrame(df_neighbors, items_per_page=10)
            solara.Markdown("Dataset: 'Fake and real news' from [kaggle](https://www.kaggle.com/datasets/clmentbisaillon/fake-and-real-news-dataset)")
        else:
            color_array = ["texts"] * n_points
            solara.AltairChart(plot_text(X2, plotted_texts, color_array=color_array))