import os
import subprocess
from datetime import datetime
from itertools import combinations
from tempfile import NamedTemporaryFile

import colorcet as cc
import datasets
import hdbscan
import networkx as nx
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import umap
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure
from matplotlib.colors import rgb2hex
from pandas.api.types import (
    is_categorical_dtype,
    is_datetime64_any_dtype,
    is_numeric_dtype,
    is_object_dtype,
)
from PIL import Image
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

model_dir = "./models/sbert.net_models_sentence-transformers_clip-ViT-B-32-multilingual-v1"


@st.cache_data(show_spinner=True)
def download_models():
    # Download and unpack the multilingual CLIP text encoder used by load_txt_model.
    os.makedirs("models", exist_ok=True)
    subprocess.run(["wget", "--no-check-certificate", "https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/clip-ViT-B-32-multilingual-v1.zip"], check=True)
    subprocess.run(["unzip", "-q", "clip-ViT-B-32-multilingual-v1.zip", "-d", model_dir], check=True)


token_ = st.secrets["token"]


@st.cache_data(show_spinner=True)
def load_dataset():
    dataset = datasets.load_dataset('rjadr/ditaduranuncamais', split='train', use_auth_token=token_)
    dataset.add_faiss_index(column="txt_embs")
    dataset.add_faiss_index(column="img_embs")
    dataset = dataset.remove_columns(['Post Created Date', 'Post Created Time', 'Like and View Counts Disabled', 'Link', 'Download URL', 'Views'])
    return dataset


@st.cache_data(show_spinner=False)
def load_dataframe(_dataset):
    dataframe = _dataset.remove_columns(['txt_embs', 'img_embs']).to_pandas()

    # Extract hashtags from both the caption and any text detected in the image.
    dataframe['Hashtags'] = dataframe.apply(lambda row: f"{row['Description']} {row['Image Text']}", axis=1)
    dataframe['Hashtags'] = dataframe['Hashtags'].str.lower().str.findall(r'#(\w+)').apply(set)

    # Move the most relevant columns to the front.
    dataframe = dataframe[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in dataframe.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
    return dataframe


@st.cache_resource(show_spinner=True)
def load_img_model():
    return SentenceTransformer('clip-ViT-B-32')


@st.cache_resource(show_spinner=True)
def load_txt_model():
    return SentenceTransformer(model_dir)


def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds a UI on top of a dataframe to let viewers filter columns.

    Args:
        df (pd.DataFrame): Original dataframe

    Returns:
        pd.DataFrame: Filtered dataframe
    """
    modify = st.checkbox("Add filters")

    if not modify:
        return df

    df = df.copy()

    # Try to convert object columns to datetimes, and strip timezones so
    # comparisons against date_input values work.
    for col in df.columns:
        if is_object_dtype(df[col]):
            try:
                df[col] = pd.to_datetime(df[col])
            except Exception:
                pass

        if is_datetime64_any_dtype(df[col]):
            df[col] = df[col].dt.tz_localize(None)

    modification_container = st.container()

    with modification_container:
        to_filter_columns = st.multiselect("Filter dataframe on", df.columns)
        for column in to_filter_columns:
            left, right = st.columns((1, 20))
            left.write("↳")

            if is_categorical_dtype(df[column]) or df[column].nunique() < 10:
                user_cat_input = right.multiselect(
                    f"Values for {column}",
                    df[column].unique(),
                    default=list(df[column].unique()),
                )
                df = df[df[column].isin(user_cat_input)]
            elif is_numeric_dtype(df[column]):
                _min = float(df[column].min())
                _max = float(df[column].max())
                step = (_max - _min) / 100
                user_num_input = right.slider(
                    f"Values for {column}",
                    _min,
                    _max,
                    (_min, _max),
                    step=step,
                )
                df = df[df[column].between(*user_num_input)]
            elif is_datetime64_any_dtype(df[column]):
                user_date_input = right.date_input(
                    f"Values for {column}",
                    value=(
                        df[column].min(),
                        df[column].max(),
                    ),
                )
                if len(user_date_input) == 2:
                    user_date_input = tuple(map(pd.to_datetime, user_date_input))
                    start_date, end_date = user_date_input
                    df = df.loc[df[column].between(start_date, end_date)]
            else:
                user_text_input = right.text_input(
                    f"Substring or regex in {column}",
                )
                if user_text_input:
                    df = df[df[column].str.contains(user_text_input)]

    return df


@st.cache_data
def get_image_embs(image):
    """
    Get image embeddings.

    Parameters:
        image (str): Path to an image file

    Returns:
        img_emb (np.array): Image embeddings
    """
    img_emb = image_model.encode(Image.open(image))
    return img_emb


@st.cache_data(show_spinner=False)
def get_text_embs(text):
    """
    Get text embeddings.

    Parameters:
        text (str): Text to encode

    Returns:
        txt_emb (np.array): Text embeddings
    """
    txt_emb = text_model.encode(text)
    return txt_emb
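
# The two encoders project into the same CLIP space, so a text embedding can be
# scored directly against an image embedding. A minimal sketch (not executed by
# the app; the query string and file path are illustrative only):
#
#   from sentence_transformers import util
#   txt = get_text_embs("nunca mais")
#   img = get_image_embs("some_photo.jpg")
#   similarity = util.cos_sim(txt, img)  # cosine similarity in [-1, 1]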


@st.cache_data
def postprocess_results(scores, samples):
    """
    Postprocess search results into a dataframe of samples with scores.

    Parameters:
        scores (np.array): FAISS distances (lower is closer)
        samples (datasets.Dataset): Samples

    Returns:
        samples_df (pd.DataFrame): Samples with a 0-100 score column
    """
    samples_df = pd.DataFrame.from_dict(samples)
    # Invert and min-max scale the distances so the closest match scores 100.
    samples_df["score"] = scores
    samples_df["score"] = (1 - (samples_df["score"] - samples_df["score"].min()) / (
        samples_df["score"].max() - samples_df["score"].min())) * 100
    samples_df["score"] = samples_df["score"].astype(int)
    samples_df.reset_index(inplace=True, drop=True)
    samples_df = samples_df[['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name'] + [col for col in samples_df.columns if col not in ['Post Created', 'image', 'Description', 'Image Text', 'Account', 'User Name']]]
    return samples_df.drop(columns=['txt_embs', 'img_embs'])


@st.cache_data
def text_to_text(text, k=5):
    """
    Text-to-text search.

    Parameters:
        text (str): Input text
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Top-k posts with similarity scores
    """
    text_emb = get_text_embs(text)
    scores, samples = dataset.get_nearest_examples('txt_embs', text_emb, k=k)
    return postprocess_results(scores, samples)


@st.cache_data
def image_to_text(image, k=5):
    """
    Image-to-text search: queries the text index with an image embedding.

    Parameters:
        image: Temporary file object whose .name points to the uploaded image
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Top-k posts with similarity scores
    """
    img_emb = get_image_embs(image.name)
    scores, samples = dataset.get_nearest_examples('txt_embs', img_emb, k=k)
    return postprocess_results(scores, samples)


@st.cache_data
def text_to_image(text, k=5):
    """
    Text-to-image search: queries the image index with a text embedding.

    Parameters:
        text (str): Input text
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Top-k posts with similarity scores
    """
    text_emb = get_text_embs(text)
    scores, samples = dataset.get_nearest_examples('img_embs', text_emb, k=k)
    return postprocess_results(scores, samples)


@st.cache_data
def image_to_image(image, k=5):
    """
    Image-to-image search.

    Parameters:
        image: Temporary file object whose .name points to the uploaded image
        k (int): Number of top results to return

    Returns:
        pd.DataFrame: Top-k posts with similarity scores
    """
    img_emb = get_image_embs(image.name)
    scores, samples = dataset.get_nearest_examples('img_embs', img_emb, k=k)
    return postprocess_results(scores, samples)


def disparity_filter(g: nx.Graph, weight: str = 'weight', alpha: float = 0.05) -> nx.Graph:
    """
    Computes the backbone of the input graph using the disparity filter algorithm.

    The algorithm is proposed in:
    M. A. Serrano, M. Boguna, and A. Vespignani,
    "Extracting the Multiscale Backbone of Complex Weighted Networks",
    PNAS, 106(16), pp. 6483-6488 (2009).
    DOI: 10.1073/pnas.0808904106

    Implementation taken from https://groups.google.com/g/networkx-discuss/c/bCuHZ3qQ2po/m/QvUUJqOYDbIJ

    Parameters
    ----------
    g : NetworkX graph
        The input graph.
    weight : str, optional (default='weight')
        The name of the edge attribute to use as weight.
    alpha : float, optional (default=0.05)
        The statistical significance level for the disparity filter (p-value).

    Returns
    -------
    backbone_graph : NetworkX graph
        The backbone graph.
    """
    backbone_graph = nx.Graph()

    for node in g:
        k_n = len(g[node])

        # Nodes of degree 1 cannot be tested against the null model.
        if k_n > 1:
            sum_w = sum(g[node][neighbor][weight] for neighbor in g[node])

            for neighbor in g[node]:
                edge_weight = g[node][neighbor][weight]

                # Normalised weight of this edge at the current node.
                pij = float(edge_weight) / sum_w

                # Keep the edge if it is statistically significant under the
                # null model of uniformly distributed edge weights.
                if (1 - pij) ** (k_n - 1) < alpha:
                    backbone_graph.add_edge(node, neighbor, weight=edge_weight)

    return backbone_graph
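
# A minimal sketch of the filter on a toy graph (not executed by the app; node
# names and weights are illustrative only). The dominant a-b edge survives while
# the weak ties around it are pruned:
#
#   g = nx.Graph()
#   g.add_weighted_edges_from([("a", "b", 100), ("a", "c", 1), ("a", "d", 1),
#                              ("b", "c", 1), ("b", "d", 1)])
#   backbone = disparity_filter(g, alpha=0.05)
#   print(backbone.edges(data=True))  # [('a', 'b', {'weight': 100})]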


def assign_community_colors(G: nx.Graph, attr: str = 'community') -> dict:
    """
    Assigns a unique color to each community in the input graph.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').

    Returns
    -------
    dict
        A dictionary mapping each community to a unique color.
    """
    glasbey_colors = cc.glasbey_hv
    communities_ = set(nx.get_node_attributes(G, attr).values())
    return {community: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, community in enumerate(communities_)}


def generate_hover_text(G: nx.Graph, attr: str = 'community') -> list:
    """
    Generates hover text for each node in the input graph.

    Parameters
    ----------
    G : nx.Graph
        The input graph.
    attr : str, optional
        The node attribute of the community names or indexes (default is 'community').

    Returns
    -------
    list
        A list of strings containing the hover text for each node.
    """
    return [f"Node: {str(node)}<br>Community: {G.nodes[node][attr] + 1}<br># of connections: {len(adjacencies)}" for node, adjacencies in G.adjacency()]


def calculate_node_sizes(G: nx.Graph) -> list:
    """
    Calculates the size of each node in the input graph based on its degree.

    Parameters
    ----------
    G : nx.Graph
        The input graph.

    Returns
    -------
    list
        A list of node sizes.
    """
    degrees = dict(G.degree())
    max_degree = max(degrees.values())
    return [10 + 20 * (degrees[node] / max_degree) for node in G.nodes()]
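
# Sizes scale linearly from 10 up to 30 for the highest-degree node: e.g. with
# a maximum degree of 8, a node of degree 2 gets 10 + 20 * (2 / 8) = 15.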


@st.cache_data(show_spinner=True)
def plot_graph(_G: nx.Graph, layout: str = "fdp", community_names_lookup: dict = None):
    """
    Plots a 3D network graph with communities.

    Parameters
    ----------
    _G : nx.Graph
        The input graph. Each node must carry a 'community' attribute.
    layout : str, optional
        The layout algorithm to use (default is "fdp").
    community_names_lookup : dict, optional
        Maps community labels ("Community 1", ...) to user-edited display names.
    """
    pos = nx.spring_layout(_G, dim=3, seed=779)
    community_colors = assign_community_colors(_G)
    node_colors = [community_colors[_G.nodes[n]['community']] for n in _G.nodes]

    # Node coordinates.
    Xn = [pos[k][0] for k in _G.nodes()]
    Yn = [pos[k][1] for k in _G.nodes()]
    Zn = [pos[k][2] for k in _G.nodes()]

    # Edge coordinates; None breaks the line between consecutive edges.
    Xe = []
    Ye = []
    Ze = []
    for e in _G.edges():
        Xe += [pos[e[0]][0], pos[e[1]][0], None]
        Ye += [pos[e[0]][1], pos[e[1]][1], None]
        Ze += [pos[e[0]][2], pos[e[1]][2], None]

    trace1 = go.Scatter3d(x=Xe,
                          y=Ye,
                          z=Ze,
                          mode='lines',
                          line=dict(color='rgb(125,125,125)', width=1),
                          hoverinfo='none'
                          )

    community_names = {i: community_names_lookup[f"Community {i+1}"] for i in range(len(community_names_lookup))}

    hover_text = [f"{node} ({community_names[_G.nodes[node]['community']]})" for node in _G.nodes()]

    trace2 = go.Scatter3d(x=Xn,
                          y=Yn,
                          z=Zn,
                          mode='markers',
                          name='actors',
                          marker=dict(symbol='circle',
                                      size=7,
                                      color=node_colors,
                                      line=dict(color='rgb(50,50,50)', width=0.2)
                                      ),
                          text=hover_text,
                          hoverinfo='text'
                          )

    axis = dict(showbackground=False,
                showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title=''
                )

    fig_layout = go.Layout(
        title="3D Network Graph",
        width=1000,
        height=1000,
        showlegend=False,
        scene=dict(
            xaxis=dict(axis),
            yaxis=dict(axis),
            zaxis=dict(axis),
        ),
        margin=dict(
            t=100
        ),
        hovermode='closest',
    )

    fig = go.Figure(data=[trace1, trace2], layout=fig_layout)
    return fig


@st.cache_data(show_spinner=True)
def cluster_embeddings(embeddings, clustering_algo='KMeans', dim_reduction='PCA', n_clusters=5, min_cluster_size=5, n_components=2, n_neighbors=15, min_dist=0.0, random_state=42, min_samples=5):
    """
    A function to cluster embeddings.

    Args:
        embeddings (pd.Series): A series of numpy vectors.
        clustering_algo (str): The clustering algorithm to use. Either 'KMeans' or 'HDBSCAN'.
        dim_reduction (str): The dimensionality reduction method to use. Either 'PCA' or 'UMAP'.
        n_clusters (int): The number of clusters for KMeans.
        min_cluster_size (int): The minimum cluster size for HDBSCAN.
        n_components (int): The number of components for the dimensionality reduction method.
        n_neighbors (int): The number of neighbors for UMAP.
        min_dist (float): The minimum distance for UMAP.
        random_state (int): The seed used by the random number generator.
        min_samples (int): The minimum number of samples for HDBSCAN.

    Returns:
        tuple: An array of cluster labels and the reduced embeddings.
    """
    # Reduce dimensionality before clustering.
    if dim_reduction == 'PCA':
        reducer = PCA(n_components=n_components, random_state=random_state)
    elif dim_reduction == 'UMAP':
        reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, n_components=n_components, random_state=random_state)
    else:
        raise ValueError('Invalid dimensionality reduction method')

    reduced_embeddings = reducer.fit_transform(np.stack(embeddings))

    if clustering_algo == 'KMeans':
        clusterer = KMeans(n_clusters=n_clusters, random_state=random_state)
    elif clustering_algo == 'HDBSCAN':
        clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)
    else:
        raise ValueError('Invalid clustering algorithm')

    labels = clusterer.fit_predict(reduced_embeddings)

    return labels, reduced_embeddings
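
# A minimal sketch of the expected input/output (not executed by the app; the
# random vectors are illustrative only):
#
#   embs = [np.random.rand(512) for _ in range(100)]
#   labels, coords = cluster_embeddings(embs, clustering_algo='KMeans',
#                                       dim_reduction='PCA', n_clusters=5)
#   # labels -> array of 100 cluster ids in {0..4}; coords -> (100, 2) array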


st.title("#ditaduranuncamais Data Explorer")


def check_password():
    """Returns `True` if the user had the correct password."""

    def password_entered():
        """Checks whether a password entered by the user is correct."""
        if st.session_state["password"] == st.secrets["password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # Don't keep the password around.
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        # First run: show the password input.
        st.text_input(
            "Password", type="password", on_change=password_entered, key="password"
        )
        return False
    elif not st.session_state["password_correct"]:
        # Password wrong: show the input again plus an error.
        st.text_input(
            "Password", type="password", on_change=password_entered, key="password"
        )
        st.error("😕 Password incorrect")
        return False
    else:
        # Password correct.
        return True


if not check_password():
    st.stop()


if not os.path.exists(model_dir):
    download_models()

dataset = load_dataset()
df = load_dataframe(dataset)
image_model = load_img_model()
text_model = load_txt_model()

menu_options = ["Data exploration", "Semantic search", "Hashtags", "Clustering", "Stats"]

st.sidebar.markdown('# Menu')
selected_menu_option = st.sidebar.radio("Select a page", menu_options)

if selected_menu_option == "Data exploration":
    st.dataframe(
        data=filter_dataframe(df),
        column_config={
            "image": st.column_config.ImageColumn(
                "Image", help="Instagram image"
            ),
            "URL": st.column_config.LinkColumn(
                "Link", help="Instagram link", width="small"
            )
        },
        hide_index=True,
    )

elif selected_menu_option == "Semantic search":
    tabs = ["Text to Text", "Text to Image", "Image to Image", "Image to Text"]
    selected_tab = st.sidebar.radio("Select a search type", tabs)

    if selected_tab == "Text to Text":
        st.markdown('## Text to text search')
        text_to_text_input = st.text_input("Enter text")
        text_to_text_k_top = st.slider("Number of results", 1, 500, 20)
        if st.button("Search"):
            if not text_to_text_input:
                st.warning("Please enter text")
            else:
                st.dataframe(
                    data=text_to_text(text_to_text_input, text_to_text_k_top),
                    column_config={
                        "image": st.column_config.ImageColumn(
                            "Image", help="Instagram image"
                        ),
                        "URL": st.column_config.LinkColumn(
                            "Link", help="Instagram link", width="small"
                        )
                    },
                    hide_index=True,
                )

    elif selected_tab == "Text to Image":
        st.markdown('## Text to image search')
        text_to_image_input = st.text_input("Enter text")
        text_to_image_k_top = st.slider("Number of results", 1, 500, 20)
        if st.button("Search"):
            if not text_to_image_input:
                st.warning("Please enter some text")
            else:
                st.dataframe(
                    data=text_to_image(text_to_image_input, text_to_image_k_top),
                    column_config={
                        "image": st.column_config.ImageColumn(
                            "Image", help="Instagram image"
                        ),
                        "URL": st.column_config.LinkColumn(
                            "Link", help="Instagram link", width="small"
                        )
                    },
                    hide_index=True,
                )

    elif selected_tab == "Image to Image":
        st.markdown('## Image to image search')
        image_to_image_k_top = st.slider("Number of results", 1, 500, 20)
        image_to_image_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
        temp_file = NamedTemporaryFile(delete=False)
        if st.button("Search"):
            if not image_to_image_input:
                st.warning("Please upload an image")
            else:
                # Write the upload to disk and flush so PIL reads the full file.
                temp_file.write(image_to_image_input.getvalue())
                temp_file.flush()
                st.dataframe(
                    data=image_to_image(temp_file, image_to_image_k_top),
                    column_config={
                        "image": st.column_config.ImageColumn(
                            "Image", help="Instagram image"
                        ),
                        "URL": st.column_config.LinkColumn(
                            "Link", help="Instagram link", width="small"
                        )
                    },
                    hide_index=True,
                )

    elif selected_tab == "Image to Text":
        st.markdown('## Image to text search')
        image_to_text_k_top = st.slider("Number of results", 1, 500, 20)
        image_to_text_input = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
        temp_file = NamedTemporaryFile(delete=False)
        if st.button("Search"):
            if not image_to_text_input:
                st.warning("Please upload an image")
            else:
                temp_file.write(image_to_text_input.getvalue())
                temp_file.flush()
                st.dataframe(
                    data=image_to_text(temp_file, image_to_text_k_top),
                    column_config={
                        "image": st.column_config.ImageColumn(
                            "Image", help="Instagram image"
                        ),
                        "URL": st.column_config.LinkColumn(
                            "Link", help="Instagram link", width="small"
                        )
                    },
                    hide_index=True,
                )

elif selected_menu_option == "Hashtags":
    if 'dfx' not in st.session_state:
        st.session_state.dfx = df.copy()

    all_hashtags = list(set([item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]))

    st.sidebar.markdown('# Hashtag co-occurrence analysis options')

    hashtags_to_remove = st.sidebar.multiselect("Hashtags to remove", all_hashtags)

    col1, col2 = st.sidebar.columns(2)

    if col1.button("Remove hashtags"):
        st.session_state.dfx['Hashtags'] = st.session_state.dfx['Hashtags'].apply(lambda x: [item for item in x if item not in hashtags_to_remove])

    if col2.button("Reset"):
        st.session_state.dfx = df.copy()

    hashtags = [item for sublist in st.session_state.dfx['Hashtags'].tolist() for item in sublist]

    hashtag_freq = st.session_state.dfx.explode('Hashtags').groupby('Hashtags').size().reset_index(name='counts')
    hashtag_freq = hashtag_freq.sort_values(by='counts', ascending=False)

    hashtags_fig = px.scatter(hashtag_freq, x='Hashtags', y='counts', log_y=True,
                              labels={'Hashtags': 'Hashtags', 'counts': 'Frequency'},
                              title='Frequency of hashtags in #ditaduranuncamais posts on Instagram',
                              height=600)
    st.markdown("### Hashtag Frequency Distribution")
    st.markdown('Here we apply hashtag co-occurrence analysis for mnemonic community detection. Communities are detected by building a network of hashtag pairs (which hashtags are used together in which posts) and then running community detection algorithms on that network.')
    st.plotly_chart(hashtags_fig)

    weight_option = st.sidebar.radio(
        'Select weight definition',
        ('Number of users that use the hashtag pairs', 'Total number of occurrences')
    )

    # Build all unordered hashtag pairs per post, tagged with the posting user.
    hashtag_user_pairs = [(tuple(sorted(combination)), userid) for hashtags, userid in zip(st.session_state.dfx['Hashtags'], st.session_state.dfx['User Name']) for combination in combinations(hashtags, r=2)]
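
    # For example (illustrative only), a post by "ana" tagged
    # {"memoria", "ditadura", "25deabril"} yields the pairs:
    #   (("25deabril", "ditadura"), "ana"), (("25deabril", "memoria"), "ana"),
    #   (("ditadura", "memoria"), "ana")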

    hashtag_user_df = pd.DataFrame(hashtag_user_pairs, columns=['hashtag_pair', 'User Name'])
    if weight_option == 'Number of users that use the hashtag pairs':
        # Weight = number of distinct users that use the pair.
        hashtag_user_df = hashtag_user_df.groupby('hashtag_pair').agg({'User Name': 'nunique'}).reset_index()
    elif weight_option == 'Total number of occurrences':
        # Weight = total number of posts in which the pair occurs.
        hashtag_user_df = hashtag_user_df.groupby('hashtag_pair').size().reset_index(name='User Name')

    edge_list = hashtag_user_df.rename(columns={'hashtag_pair': 'hashtag1', 'User Name': 'weight'})
    edge_list[['hashtag1', 'hashtag2']] = pd.DataFrame(edge_list['hashtag1'].tolist(), index=edge_list.index)
    edge_list = edge_list[['hashtag1', 'hashtag2', 'weight']]

    st.markdown("### Edge List of Hashtag Pairs")

    G = nx.from_pandas_edgelist(edge_list, 'hashtag1', 'hashtag2', 'weight')
    G_backbone = disparity_filter(G, weight='weight', alpha=0.05)
    st.markdown(f'Number of nodes: {len(G_backbone.nodes)}')
    st.markdown(f'Number of edges: {len(G_backbone.edges)}')
    st.dataframe(edge_list.sort_values(by='weight', ascending=False).head(10).style.set_caption("Edge list of hashtag pairs with the highest weight"))

    communities = nx.community.louvain_communities(G_backbone, weight='weight', seed=1234)
    communities = list(communities)

    # Sort communities by size, largest first.
    communities.sort(key=len, reverse=True)

    # Tag every node with the index of its community.
    for i, community in enumerate(communities):
        for node in community:
            G_backbone.nodes[node]['community'] = i

    # Within each community, rank hashtags by weighted degree in the full graph.
    sorted_community_hashtags = [
        [
            hashtag
            for hashtag, degree in sorted(
                ((h, G.degree(h, weight='weight')) for h in community),
                key=lambda x: x[1],
                reverse=True
            )
        ]
        for community in communities
    ]
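
    # Illustrative shape of the result: one ranked list per community, e.g.
    #   [["memoria", "ditadura", ...], ["25deabril", ...], ...]
    # which the transpose below turns into one column per community.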

    sorted_community_hashtags = pd.DataFrame(sorted_community_hashtags).T

    sorted_community_hashtags.columns = [f'Community {i+1}' for i in range(len(sorted_community_hashtags.columns))]

    st.markdown("### Hashtag Communities")
    st.markdown(f'There are {len(communities)} communities in the graph.')
    st.dataframe(sorted_community_hashtags)

    st.markdown("### Community Names")
    st.markdown("Edit the names of the communities in the table below so they show up in the visualisations.")

    df_community_names = pd.DataFrame(sorted_community_hashtags.columns, columns=['community_names'], index=sorted_community_hashtags.columns)
    df_community_names = st.data_editor(df_community_names)

    st.download_button(
        label="Download community names as csv",
        data=df_community_names.to_csv().encode("utf-8"),
        file_name="community_names.csv",
        mime="text/csv",
    )

    community_names_lookup = df_community_names['community_names'].to_dict()

    st.markdown("### Community Size Over Time")
    st.markdown("Select communities to see their size over time.")

    selected_communities = st.multiselect('Select Communities', community_names_lookup.values(), default=community_names_lookup.values())

    resample_dict = {
        'Day': 'D',
        'Three Days': '3D',
        'Week': 'W',
        'Two Weeks': '2W',
        'Month': 'M',
        'Quarter': 'Q',
        'Year': 'Y'
    }

    resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()), index=4)

    df_communities = st.session_state.dfx.copy()

    def build_community_dict(communities):
        community_dict = {}
        for i, community in enumerate(communities):
            for node in community:
                community_dict[node] = community_names_lookup[f'Community {i+1}']
        return community_dict

    community_dict = build_community_dict(communities)

    df_communities['Communities'] = df_communities['Hashtags'].apply(lambda x: [community_dict[tag] for tag in x if tag in community_dict])

    df_communities = df_communities[['Post Created', 'Communities']].explode('Communities')
    df_communities = df_communities.dropna(subset=['Communities'])

    min_date = df_communities['Post Created'].min().date()
    max_date = df_communities['Post Created'].max().date()

    date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date))

    df_communities = df_communities[(df_communities['Post Created'].dt.date >= date_range[0]) & (df_communities['Post Created'].dt.date <= date_range[1])]

    df_communities['Post Created'] = df_communities['Post Created'].dt.to_period(resample_dict[resample_time])
    df_community_sizes = df_communities.groupby(['Post Created', 'Communities']).size().unstack(fill_value=0)
    df_community_sizes.index = df_community_sizes.index.to_timestamp()

    df_community_sizes = df_community_sizes[selected_communities]

    st.plotly_chart(px.line(df_community_sizes, title='Community Size Over Time', labels={'value': 'Number of posts', 'index': 'Date', 'variable': 'Community'}))

    st.markdown("### Hashtag Network Graph")
    st.plotly_chart(plot_graph(G_backbone, layout="fdp", community_names_lookup=community_names_lookup))


elif selected_menu_option == "Clustering":
    st.markdown("## Clustering")
    st.markdown("Select the type of embeddings to cluster, the clustering algorithm, and the dimensionality reduction method in the sidebar. Clustering runs automatically and may take some time.")
    st.sidebar.markdown("# Clustering Options")
    type_embeddings = st.sidebar.selectbox("Type of embeddings to cluster", ["Text", "Image"])
    clustering_algo = st.sidebar.selectbox("Clustering algorithm", ["HDBSCAN", "KMeans"])
    dim_reduction = st.sidebar.selectbox("Dimensionality reduction method", ["UMAP", "PCA"])
    if clustering_algo == "KMeans":
        st.sidebar.markdown("### KMeans Options")
        n_clusters = st.sidebar.slider("Number of clusters", 2, 20, 5)
        min_cluster_size = None
        min_samples = None
    elif clustering_algo == "HDBSCAN":
        st.sidebar.markdown("### HDBSCAN Options")
        min_cluster_size = st.sidebar.slider("[Minimum cluster size](https://hdbscan.readthedocs.io/en/latest/parameter_selection.html#selecting-min-cluster-size)", 2, 200, 5)
        min_samples = st.sidebar.slider("[Minimum samples](https://hdbscan.readthedocs.io/en/latest/parameter_selection.html#selecting-min-samples)", 2, 50, 5)
        n_clusters = None
    if dim_reduction == "UMAP":
        st.sidebar.markdown("### UMAP Options")
        n_components = st.sidebar.slider("[Number of components](https://umap-learn.readthedocs.io/en/latest/parameters.html#n-components)", 2, 80, 50)
        n_neighbors = st.sidebar.slider("[Number of neighbors](https://umap-learn.readthedocs.io/en/latest/parameters.html#n-neighbors)", 2, 20, 15)
        min_dist = st.sidebar.slider("[Minimum distance](https://umap-learn.readthedocs.io/en/latest/parameters.html#min-dist)", 0.0, 1.0, 0.0)
    else:
        st.sidebar.markdown("### PCA Options")
        n_components = st.sidebar.slider("[Number of components](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html)", 2, 80, 2)
        n_neighbors = None
        min_dist = None

    st.markdown("### Clustering Results")
    if type_embeddings == "Text":
        embeddings = dataset['txt_embs']
    elif type_embeddings == "Image":
        embeddings = dataset['img_embs']

    labels, reduced_embeddings = cluster_embeddings(embeddings, clustering_algo=clustering_algo, dim_reduction=dim_reduction, n_clusters=n_clusters, min_cluster_size=min_cluster_size, n_components=n_components, n_neighbors=n_neighbors, min_dist=min_dist, min_samples=min_samples)
    st.markdown(f"Clustering {type_embeddings} embeddings with {clustering_algo} after {dim_reduction} dimensionality reduction yields **{len(set(labels))}** clusters.")

    df_clustered = df.copy()
    df_clustered['cluster'] = labels
    df_clustered = df_clustered.set_index('cluster').reset_index()
    st.dataframe(
        data=filter_dataframe(df_clustered),
        column_config={
            "image": st.column_config.ImageColumn(
                "Image", help="Instagram image"
            ),
            "URL": st.column_config.LinkColumn(
                "Link", help="Instagram link", width="small"
            )
        },
        hide_index=True,
    )

    st.download_button(
        "Download dataset with labels",
        df_clustered.to_csv(index=False).encode('utf-8'),
        f'ditaduranuncamais_{datetime.now().strftime("%Y%m%d-%H%M%S")}.csv',
        "text/csv",
        key='download-csv'
    )

    st.markdown("### Cluster Plot")

    # Reduce to 2D for plotting if the clustering space has more dimensions.
    if n_components > 2:
        reducer = umap.UMAP(n_components=2, random_state=42)
        reduced_embeddings = reducer.fit_transform(reduced_embeddings)

    descriptions = df_clustered['Description'].tolist()
    images = df_clustered['image'].tolist()
    glasbey_colors = cc.glasbey_hv
    color_dict = {n: rgb2hex(glasbey_colors[i % len(glasbey_colors)]) for i, n in enumerate(set(labels))}
    colors = [color_dict[label] for label in labels]

    source = ColumnDataSource(data=dict(
        x=reduced_embeddings[:, 0],
        y=reduced_embeddings[:, 1],
        desc=descriptions,
        imgs=images,
        colors=colors
    ))

    TOOLTIPS = """
    <div>
        <div>
            <img
                src="@imgs" height="100" alt="@imgs" width="100"
                style="float: left; margin: 0px 15px 15px 0px;"
                border="2"
            ></img>
        </div>
        <div>
            <span style="font-size: 12px; font-weight: bold;">@desc</span>
        </div>
    </div>
    """

    p = figure(width=800, height=800, tooltips=TOOLTIPS,
               title="Mouse over the dots")

    p.circle('x', 'y', size=10, source=source, color='colors', line_color=None)
    st.bokeh_chart(p)

    st.markdown("### Cluster Size")
    cluster_sizes = df_clustered.groupby('cluster').size().reset_index(name='counts')
    cluster_sizes = cluster_sizes.sort_values(by='counts', ascending=False)
    # Drop the HDBSCAN noise cluster (-1) from the size table.
    cluster_sizes = cluster_sizes[cluster_sizes['cluster'] != -1]
    cluster_sizes = cluster_sizes.set_index('cluster').reset_index()
    cluster_sizes = cluster_sizes.rename(columns={'cluster': 'Cluster', 'counts': 'Size'})
    st.dataframe(cluster_sizes)

    st.markdown("### Cluster Time Series")

    variable = st.selectbox('Select Variable', ['Likes', 'Comments', 'Followers at Posting', 'Total Interactions'])

    resample_dict = {
        'Day': 'D',
        'Three Days': '3D',
        'Week': 'W',
        'Two Weeks': '2W',
        'Month': 'M',
        'Quarter': 'Q',
        'Year': 'Y'
    }

    resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()))

    min_date = df_clustered['Post Created'].min().date()
    max_date = df_clustered['Post Created'].max().date()

    date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date))

    df_resampled = df_clustered[(df_clustered['Post Created'].dt.date >= date_range[0]) & (df_clustered['Post Created'].dt.date <= date_range[1])]
    df_resampled = df_resampled.set_index('Post Created')

    # Exclude the noise cluster and default to the five largest clusters.
    cluster_sizes = df_resampled[df_resampled['cluster'] != -1]['cluster'].value_counts()
    clusters = cluster_sizes.index

    default_clusters = cluster_sizes.sort_values(ascending=False).head(5).index.tolist()

    selected_clusters = st.multiselect('Select Clusters', options=clusters.tolist(), default=default_clusters)

    df_plot = pd.DataFrame()

    for cluster in selected_clusters:
        df_cluster = df_resampled[df_resampled['cluster'] == cluster][variable].resample(resample_dict[resample_time]).sum()
        df_plot = pd.concat([df_plot, df_cluster], axis=1)

    df_plot.columns = selected_clusters

    st.line_chart(df_plot)


elif selected_menu_option == "Stats":
    st.markdown("### Time Series Analysis")

    variable = st.selectbox('Select Variable', ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments'])

    resample_dict = {
        'Day': 'D',
        'Three Days': '3D',
        'Week': 'W',
        'Two Weeks': '2W',
        'Month': 'M',
        'Quarter': 'Q',
        'Year': 'Y'
    }

    resample_time = st.selectbox('Select Time Resampling', list(resample_dict.keys()))

    df_filtered = df.set_index('Post Created')

    min_date = df_filtered.index.min().date()
    max_date = df_filtered.index.max().date()

    date_range = st.slider('Select Date Range', min_value=min_date, max_value=max_date, value=(min_date, max_date))

    df_filtered = df_filtered[(df_filtered.index.date >= date_range[0]) & (df_filtered.index.date <= date_range[1])]

    df_resampled = df_filtered[variable].resample(resample_dict[resample_time]).sum()
    st.line_chart(df_resampled)

    st.markdown("### Correlation Analysis")

    options = ['Followers at Posting', 'Total Interactions', 'Likes', 'Comments']
    scatter_variable_1 = st.selectbox('Select Variable 1 for Scatter Plot', options)
    scatter_variable_2 = st.selectbox('Select Variable 2 for Scatter Plot', options)

    st.write(f"Scatter Plot of {scatter_variable_1} vs {scatter_variable_2}")

    scatter_fig = px.scatter(df_filtered, x=scatter_variable_1, y=scatter_variable_2)
    st.plotly_chart(scatter_fig)

    # Pearson correlation between the two selected variables.
    corr = df_filtered[scatter_variable_1].corr(df_filtered[scatter_variable_2])
    if corr > 0.7:
        st.write(f"The correlation coefficient is {corr:.2f}, indicating a strong positive relationship between {scatter_variable_1} and {scatter_variable_2}.")
    elif corr > 0.3:
        st.write(f"The correlation coefficient is {corr:.2f}, indicating a moderate positive relationship between {scatter_variable_1} and {scatter_variable_2}.")
    elif corr > -0.3:
        st.write(f"The correlation coefficient is {corr:.2f}, indicating a weak or no relationship between {scatter_variable_1} and {scatter_variable_2}.")
    elif corr > -0.7:
        st.write(f"The correlation coefficient is {corr:.2f}, indicating a moderate negative relationship between {scatter_variable_1} and {scatter_variable_2}.")
    else:
        st.write(f"The correlation coefficient is {corr:.2f}, indicating a strong negative relationship between {scatter_variable_1} and {scatter_variable_2}.")