File size: 4,060 Bytes
ad6c924
 
 
 
 
 
 
 
 
79de90d
 
ad6c924
 
 
 
 
 
 
 
 
 
0f36fb3
bac7e94
ad6c924
bac7e94
ad6c924
 
bac7e94
ad6c924
 
 
 
 
 
0f36fb3
bac7e94
ad6c924
 
 
 
79de90d
ad6c924
 
 
 
 
 
 
 
 
 
 
bac7e94
79de90d
bac7e94
79de90d
bac7e94
 
 
 
 
 
 
 
 
 
 
 
 
 
79de90d
 
bac7e94
79de90d
bac7e94
79de90d
ad6c924
 
 
 
 
bac7e94
ad6c924
 
 
 
bac7e94
 
 
 
ad6c924
bac7e94
ad6c924
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
### LIBRARIES ###
## Data
import pandas as pd
pd.options.display.float_format = '${:,.2f}'.format

# Analysis

# App & Visualization
import streamlit as st
from bokeh.models import CustomJS, ColumnDataSource, TextInput, DataTable, TableColumn
from bokeh.plotting import figure
from bokeh.transform import factor_cmap
from bokeh.palettes import Category20c_20
from bokeh.layouts import column, row

# utils

def datasets_explorer_viz(df):
    """Render the interactive datasets explorer into the Streamlit page.

    Builds a Bokeh scatter plot of ``df`` coloured by ``task``, a search box
    that highlights matching points in red, and a table listing the points
    currently selected on the plot, then passes the combined layout to
    ``st.bokeh_chart``.

    Args:
        df: pandas DataFrame providing the ``x``, ``y``, ``task`` and
            ``dataset_id`` columns read by the glyphs, tooltips, table and
            search callback. (``x``/``y`` are presumably a 2-D UMAP
            projection, judging by the callback argument naming -- confirm
            against the code that produces the parquet file.)

    Returns:
        None. The chart is rendered as a side effect.
    """
    plot_source = ColumnDataSource(df)
    tooltips = [("dataset_id", "@dataset_id"), ("task", "@task")]

    # One colour per distinct task, in order of first appearance in ``df``.
    task_color = factor_cmap(
        'task', palette=Category20c_20, factors=list(df['task'].unique())
    )

    # ``width``/``height`` replace ``plot_width``/``plot_height``, which were
    # deprecated in Bokeh 2.4 and removed in Bokeh 3.0.
    p = figure(
        width=1000,
        height=800,
        tools="hover,wheel_zoom,pan,box_select",
        tooltips=tooltips,
        toolbar_location="above",
    )
    p.scatter(
        'x', 'y',
        size=5,
        source=plot_source,
        alpha=0.8,
        marker='circle',
        fill_color=task_color,
        line_color=task_color,
        legend_field='task',
    )
    p.legend.location = "bottom_right"
    p.legend.click_policy = "mute"
    p.legend.label_text_font_size = "8pt"

    # Start both auxiliary sources with the (empty) columns their consumers
    # reference, so the glyph and the table never point at missing columns
    # before the first callback runs.
    table_source = ColumnDataSource(data=dict(task=[], dataset_id=[]))
    selection_source = ColumnDataSource(
        data=dict(x=[], y=[], dataset_id=[], task=[])
    )

    columns = [
        TableColumn(field="task", title="Task"),
        TableColumn(field="dataset_id", title="Dataset ID"),
    ]
    data_table = DataTable(source=table_source, columns=columns, width=500)

    # Red overlay for search hits; added after the main glyph so it is drawn
    # on top of it.
    p.scatter(
        'x', 'y',
        source=selection_source,
        size=5,
        marker='circle',
        color='red',
    )

    # Mirror the plot selection (e.g. from box_select) into the table.
    plot_source.selected.js_on_change('indices', CustomJS(args=dict(umap_source=plot_source, table_source=table_source), code="""
            const inds = cb_obj.indices;
            const umapData = umap_source.data;
            const tasks = [];
            const datasetIds = [];

            for (let i = 0; i < inds.length; i++) {
                tasks.push(umapData['task'][inds[i]]);
                datasetIds.push(umapData['dataset_id'][inds[i]]);
            }
            // Assign a fresh object so the table always sees a data change.
            table_source.data = {task: tasks, dataset_id: datasetIds};
            table_source.change.emit();
    """
    ))

    text_input = TextInput(value="", title="Search")
    # Highlight every point whose dataset id or task contains the query.
    text_input.js_on_change('value', CustomJS(args=dict(plot_source=plot_source, selection_source=selection_source), code="""
        const plot_data = plot_source.data;
        const query = cb_obj.value.trim().toLowerCase();
        const matches = {x: [], y: [], dataset_id: [], task: []};

        // Every string "includes" the empty string, so an empty query must
        // clear the highlights instead of marking all points.
        if (query.length > 0) {
            const ids = plot_data['dataset_id'];
            const tasks = plot_data['task'];
            for (let i = 0; i < ids.length; i++) {
                // Coerce first: null or non-string cells have no .includes().
                const idText = ids[i] == null ? '' : String(ids[i]).toLowerCase();
                const taskText = tasks[i] == null ? '' : String(tasks[i]).toLowerCase();
                if (idText.includes(query) || taskText.includes(query)) {
                    matches['x'].push(plot_data['x'][i]);
                    matches['y'].push(plot_data['y'][i]);
                    matches['dataset_id'].push(ids[i]);
                    matches['task'].push(tasks[i]);
                }
            }
        }
        selection_source.data = matches;
        selection_source.change.emit();
    """))

    st.bokeh_chart(row(column(text_input, p), data_table))


if __name__ == "__main__":
    ### STREAMLIT APP CONFIG ###
    st.set_page_config(layout="wide", page_title="Datasets Explorer")
    st.title('Interactive Datasets Explorer')

    ### USAGE NOTES ###
    # Each entry is rendered as its own markdown bullet inside the expander.
    usage_notes = (
        "* Each point in the plot represents a HF hub dataset categorized by their `task_id`.",
        "* Every dataset is emebdded using the [SPECTER](https://github.com/allenai/specter#advanced-training-your-own-model) embedding of its corresponding paper abstract.",
        "* You can either search for a dataset or drag and select to peek into the cluster content.",
    )
    with st.expander("How to interact with the plot:"):
        for note in usage_notes:
            st.markdown(note)

    ### LOAD DATA AND RENDER ###
    # Read the datasets table from the bundled parquet file, show the banner,
    # then draw the explorer for it.
    datasets_df = pd.read_parquet('./assets/data/datasets_df.parquet')
    st.warning("Hugging Face 🤗 Datasets Explorer")
    datasets_explorer_viz(datasets_df)