File size: 7,207 Bytes
9541eae
 
ae1692d
9541eae
 
 
 
 
 
 
 
 
 
 
ae1692d
9541eae
ae1692d
9541eae
 
 
 
0e6d7eb
9541eae
 
0e6d7eb
9541eae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae1692d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9541eae
ae1692d
 
9541eae
 
ae1692d
9541eae
 
 
3bfabca
 
 
 
ae1692d
9541eae
 
 
 
ae1692d
9541eae
 
ae1692d
9541eae
0e6d7eb
9541eae
 
 
 
 
 
 
 
 
 
 
 
0e6d7eb
9541eae
 
 
 
 
 
 
 
 
 
ae1692d
9541eae
 
 
 
 
 
 
 
ae1692d
9541eae
 
 
 
 
 
 
 
 
 
 
 
 
ae1692d
9541eae
 
 
 
 
ae1692d
9541eae
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import pandas as pd
from huggingface_hub.hf_api import create_repo, upload_folder, upload_file, HfApi
from huggingface_hub.repository import Repository
import subprocess
import os
import tempfile
from uuid import uuid4
import pickle
import sweetviz as sv
import dabl
import re


def analyze_datasets(dataset, dataset_name, token, column=None, pairwise="off"):
    """Build a sweetviz report for an uploaded CSV and publish it as a static Space.

    Args:
        dataset: Gradio file object; its ``.name`` is the path of the uploaded CSV.
        dataset_name: Name of the Space created under the token owner's namespace.
        token: Hugging Face Hub token used for ``whoami``, repo creation and uploads.
        column: Optional target column to compare the other features against.
            ``None`` or a blank/whitespace-only string means "no target".
        pairwise: sweetviz ``pairwise_analysis`` setting; a falsy value (e.g. an
            unselected radio button yielding ``None``) falls back to ``"off"``.

    Returns:
        A status message containing the URL returned by ``create_repo``.
    """
    df = pd.read_csv(dataset.name)
    username = HfApi().whoami(token=token)["name"]
    repo_id = f"{username}/{dataset_name}"

    # An empty Gradio text box yields "" rather than None; treat blank as "no target".
    target = column if column and column.strip() else None
    pairwise = pairwise or "off"

    if target is not None:
        analyze_report = sv.analyze(df, target_feat=target, pairwise_analysis=pairwise)
    else:
        analyze_report = sv.analyze(df, pairwise_analysis=pairwise)

    readme = f"---\ntitle: {dataset_name}\nemoji: ✨\ncolorFrom: green\ncolorTo: red\nsdk: static\npinned: false\ntags:\n- dataset-report\n---"

    # Use a private temp dir so separate requests never share (or overwrite)
    # files in the process working directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        report_path = os.path.join(tmpdir, "index.html")
        readme_path = os.path.join(tmpdir, "README.md")

        # Render the report for both branches above; a static Space serves index.html.
        analyze_report.show_html(report_path, open_browser=False)

        repo_url = create_repo(repo_id, repo_type="space", token=token, space_sdk="static", private=False)
        upload_file(path_or_fileobj=report_path, path_in_repo="index.html", repo_id=repo_id, repo_type="space", token=token)

        # Explicit UTF-8: the front matter contains an emoji.
        with open(readme_path, "w", encoding="utf-8") as f:
            f.write(readme)
        upload_file(path_or_fileobj=readme_path, path_in_repo="README.md", repo_id=repo_id, repo_type="space", token=token)

    return f"Your dataset report will be ready at {repo_url}"


from sklearn.utils import estimator_html_repr


def extract_estimator_config(model):
    """Render ``model.get_params(deep=True)`` as a two-column Markdown table.

    Cell text is escaped so the table stays well-formed: ``|`` becomes ``\\|``
    and line breaks (e.g. from wrapped estimator reprs) become ``<br/>``.

    Args:
        model: Any object exposing a scikit-learn style ``get_params(deep=True)``.

    Returns:
        The Markdown table as a string, each row terminated by ``"\\n"``.
    """

    def _cell(value):
        # Escape the column separator and keep each row on a single line.
        text = str(value).replace("|", "\\|")
        return "<br/>".join(text.splitlines())

    rows = ["| Hyperparameters | Value |", "| :-- | :-- |"]
    rows.extend(
        f"| {_cell(name)} | {_cell(value)} |"
        for name, value in model.get_params(deep=True).items()
    )
    return "\n".join(rows) + "\n"

def detect_training(df, column):
    """Pick a dabl baseline estimator from the detected type of the target column.

    Args:
        df: DataFrame containing the target column.
        column: Name of the target column.

    Returns:
        ``dabl.SimpleRegressor()`` when the target is detected as continuous or
        dirty float, ``dabl.SimpleClassifier()`` when it is categorical,
        low-cardinality int or free string.

    Raises:
        ValueError: if the target's detected type is none of the above.
    """
    # Run type detection once and read the row of flags for the target column.
    flags = dabl.detect_types(df).loc[column]

    if flags["continuous"] or flags["dirty_float"]:
        return dabl.SimpleRegressor()
    if flags["categorical"] or flags["low_card_int"] or flags["free_string"]:
        return dabl.SimpleClassifier()

    detected = [name for name, is_set in flags.items() if is_set]
    raise ValueError(
        f"Cannot train a baseline for target column {column!r}: "
        f"unsupported detected type(s) {detected}."
    )

def edit_types(df):
    """Clean ``df`` with dabl, overriding two ambiguous detected types.

    Columns that dabl flags as ``low_card_int`` are hinted as ``"categorical"``;
    columns flagged as ``dirty_float`` are hinted as ``"continuous"``. The
    dirty-float hint is applied last, so it wins if a column carries both flags.

    Returns:
        The DataFrame returned by ``dabl.clean``.
    """
    detected = dabl.detect_types(df)

    low_card_cols = detected.loc[detected["low_card_int"].eq(True)].index
    dirty_float_cols = detected.loc[detected["dirty_float"].eq(True)].index

    hints = {name: "categorical" for name in low_card_cols}
    hints.update((name, "continuous") for name in dirty_float_cols)

    return dabl.clean(df, type_hints=hints)

def train_baseline(dataset, dataset_name, token, column):
    """Train a dabl baseline on an uploaded CSV and push it to the Hub as a model repo.

    The pushed folder contains ``README.md`` (model card), ``clf.pkl`` (the
    pickled fitted dabl estimator) and ``logs.txt`` (stdout captured during fit).

    Args:
        dataset: Gradio file object; its ``.name`` is the path of the uploaded CSV.
        dataset_name: Name of the model repo created under the token owner's namespace.
        token: Hugging Face Hub token used for ``whoami``, repo creation and upload.
        column: Name of the target column to predict.

    Returns:
        A status message containing the URL returned by ``create_repo``.

    Raises:
        ValueError: if ``column`` is not a column of the uploaded CSV.
    """
    from contextlib import redirect_stdout

    # Resolve the namespace first so an invalid token fails before a
    # potentially long training run rather than after it.
    username = HfApi().whoami(token=token)["name"]
    repo_id = f"{username}/{dataset_name}"

    df = pd.read_csv(dataset.name)
    if column not in df.columns:
        raise ValueError(
            f"Target column {column!r} not found; available columns: {list(df.columns)}"
        )

    df_clean = edit_types(df)
    fc = detect_training(df_clean, column)
    X = df_clean.drop(column, axis=1)
    y = df_clean[column]

    with tempfile.TemporaryDirectory() as tmpdirname:
        # Anything printed to stdout during fit() is captured into logs.txt,
        # which is uploaded together with the model.
        with open(os.path.join(tmpdirname, "logs.txt"), "w", encoding="utf-8") as f, redirect_stdout(f):
            print('Logging training')
            fc.fit(X, y)

        repo_url = create_repo(repo_id=repo_id, token=token)

        # One Markdown paragraph per line of the best-model summary.
        metrics = "".join(f"{elem}\n\n" for elem in str(fc.current_best_).split("\n"))
        # Collapse newlines+indentation so the HTML renders inline in the card.
        model_plot = re.sub(r"\n\s+", "", str(estimator_html_repr(fc.est_)))
        readme = (
            "---\nlicense: apache-2.0\nlibrary_name: sklearn\n---\n\n"
            f"## Baseline Model trained on {dataset_name} to predict {column}\n\n"
            "**Metrics of the best model:**\n\n"
            f"{metrics}"
            "\n\n**See model plot below:**\n\n"
            f"{model_plot}"
            "\n\nThis model is trained with dabl library as a baseline, for better results, use [AutoTrain](https://huggingface.co/autotrain).\n\n"
        )
        with open(os.path.join(tmpdirname, "README.md"), "w", encoding="utf-8") as f:
            f.write(readme)

        # NOTE: pickle is used for the artifact; consumers should only unpickle
        # clf.pkl from a repository they trust.
        with open(os.path.join(tmpdirname, "clf.pkl"), mode="bw") as f:
            pickle.dump(fc, file=f)

        upload_folder(repo_id=repo_id, folder_path=tmpdirname, repo_type="model", token=token, path_in_repo="./")

    return f"Your model will be ready at {repo_url}"



with gr.Blocks() as demo:
    main_title = gr.Markdown("""# Baseline Trainer 🪄🌟✨""")
    main_desc = gr.Markdown("""This app trains a baseline model for a given dataset and pushes it to your Hugging Face Hub Profile with a model card. For better results, use [AutoTrain](https://huggingface.co/autotrain).""")

    with gr.Tabs():
        # Tab 1: train a dabl baseline on an uploaded CSV and push it as a model repo.
        with gr.TabItem("Baseline Trainer") as baseline_trainer:
            with gr.Row():
                with gr.Column():
                    title = gr.Markdown(""" ## Train a supervised baseline model""")
                    description = gr.Markdown("This app trains a model and pushes it to your Hugging Face Hub Profile.")
                    dataset = gr.File(label="Dataset")
                    column = gr.Text(label="Enter target variable:")
                    pushing_desc = gr.Markdown("This app needs your Hugging Face Hub token and a unique name for your model repository.")
                    dataset_name = gr.Text(label="Enter dataset name:")
                    token = gr.Textbox(label="Your Hugging Face Token")
                    inference_run = gr.Button("Train")
                    inference_progress = gr.StatusTracker(cover_container=True)

                outcome = gr.outputs.Textbox(label="Progress")
                inference_run.click(
                    train_baseline,
                    inputs=[dataset, dataset_name, token, column],
                    outputs=outcome,
                    status_tracker=inference_progress,
                )

        # Tab 2: build a sweetviz report and publish it as a static Space.
        # The component names below are rebound; tab 1's handler is already wired.
        with gr.TabItem("Analyze") as analyze:
            with gr.Row():
                with gr.Column():
                    title = gr.Markdown(""" ## Analyze Dataset """)
                    description = gr.Markdown("Analyze a dataset or predictive variables against a target variable in a dataset (enter a column name to column section if you want to compare against target value). You can also do pairwise analysis, but it has quadratic complexity.")
                    dataset = gr.File(label="Dataset")
                    column = gr.Text(label="Compare dataset against a target variable (Optional)")
                    # Default to "off" so the handler never receives None.
                    pairwise = gr.Radio(["off", "on"], value="off", label="Enable pairwise analysis")
                    token = gr.Textbox(label="Your Hugging Face Token")
                    dataset_name = gr.Textbox(label="Dataset Name")
                    pushing_desc = gr.Markdown("This app needs your Hugging Face Hub token and a unique repository name for your dataset report.")
                    inference_run = gr.Button("Infer")
                    inference_progress = gr.StatusTracker(cover_container=True)
                outcome = gr.outputs.Textbox()
                inference_run.click(
                    analyze_datasets,
                    inputs=[dataset, dataset_name, token, column, pairwise],
                    outputs=outcome,
                    status_tracker=inference_progress,
                )

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch(debug=True)