|
import solara |
|
|
|
import numpy as np |
|
import pandas as pd |
|
from sentence_transformers import SentenceTransformer |
|
from huggingface_hub import snapshot_download |
|
from umap import UMAP |
|
from annoy import AnnoyIndex |
|
from cluestar import plot_text |
|
|
|
# Load the example news dataset; only its "title" column is used below.
news = pd.read_csv('https://raw.githubusercontent.com/alonsosilvaallende/fake-and-real-news-titles/main/example.csv')

# Keep every non-missing title as a plain string. Missing values are removed
# with pandas' own NA handling rather than by stringifying each value and
# comparing it against the text 'nan'.
texts = [str(title) for title in news["title"].dropna()]
|
|
|
# Example sentences; not referenced by any other code visible in this file.
sentences = ["This is an example sentence", "Each sentence is converted"]

# Download the model snapshot, restricted to its JSON files and the PyTorch
# weights file, and get the local directory it was stored in.
model_path = snapshot_download(
    repo_id="TaylorAI/gte-tiny", allow_patterns=["*.json", "pytorch_model.bin"]
)

embedder2 = SentenceTransformer(model_path)

# Embed at most the first 500 titles in one batched call. Slicing (instead of
# indexing with range(500)) avoids an IndexError when the dataset holds fewer
# than 500 titles; `texts` already contains only strings.
embeddings2 = embedder2.encode(texts[:500])

# Fit the UMAP projection on the title embeddings; the fitted reducer is
# reused later to project new vectors with reducer.transform.
reducer = UMAP()
X2 = reducer.fit_transform(embeddings2)

# Build an Annoy index (angular metric) whose item ids are the positions of
# the titles in `texts`, so a neighbour id can be used to look up its title.
f = len(embeddings2[0])  # embedding dimensionality
t = AnnoyIndex(f, 'angular')
for i, embedded_text in enumerate(embeddings2):
    t.add_item(i, embedded_text)
t.build(1000)  # number of trees in the index
|
|
|
# Reactive state holding the current search query; the literal is its initial value.
query = solara.reactive("What did Nancy Pelosi said about Obamacare?")


@solara.component
def Page():
    """Render the embedding explorer page.

    Shows a text input bound to ``query``. When the query is non-empty it is
    embedded, its 10 nearest neighbours are fetched from the Annoy index
    ``t``, and the query is projected with the already-fitted ``reducer``.
    The page then shows a scatter plot of all projected titles plus the
    query, coloured by category ("texts", "neighbors", "query"), followed by
    a table of the neighbours and their distances. When the query is empty,
    only the projected titles are plotted.
    """
    # One plotted point per row of X2. Deriving the count from X2 (instead of
    # repeating a literal 500) keeps coordinates, labels and colours the same
    # length in both branches.
    n_points = len(X2)
    plotted_texts = texts[:n_points]

    with solara.Column(margin=10):
        solara.Markdown("#Embeddings")
        solara.InputText("Enter some query:", query, continuous_update=True)

        if query.value != "":
            embedded_query = embedder2.encode(query.value)
            idx, distances = t.get_nns_by_vector(
                embedded_query, 10, include_distances=True
            )
            df_neighbors = pd.DataFrame(
                {
                    "neighbors": [texts[i] for i in idx],
                    "distances": distances,
                }
            )

            # Project the query into the same 2-D space as the titles.
            x = reducer.transform([embedded_query])

            # Set membership keeps the per-point category check O(1).
            neighbor_ids = set(idx)
            color_array = [
                "neighbors" if i in neighbor_ids else "texts"
                for i in range(n_points)
            ] + ["query"]

            # The query is appended as the last point, matching the trailing
            # "query" entry in color_array.
            chart = plot_text(
                np.vstack((X2, x)),
                plotted_texts + [query.value],
                color_array=color_array,
            ).configure_range(category=['#0000ff', '#ff0000', '#a0aab4'])
            solara.AltairChart(chart)

            solara.DataFrame(df_neighbors, items_per_page=10)
            solara.Markdown("Dataset: 'Fake and real news' from [kaggle](https://www.kaggle.com/datasets/clmentbisaillon/fake-and-real-news-dataset)")
        else:
            color_array = ["texts"] * n_points
            solara.AltairChart(plot_text(X2, plotted_texts, color_array=color_array))
|
|