{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3c9e1f20",
   "metadata": {},
   "source": [
    "# TabPFN Gradio demo\n",
    "\n",
    "Launches an interactive Gradio app in which `TabPFNClassifier` predicts the empty cells of a table. The table can be edited by hand or filled from an uploaded `.csv`, `.data` or `.arff` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e0cd6a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make `scripts.transformer_prediction_interface` importable; assumes the notebook\n",
    "# is run from a directory directly below the one containing `scripts/`.\n",
    "sys.path.insert(0, '..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba81c2ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import gradio as gr\n",
    "import openml\n",
    "\n",
    "from scripts.transformer_prediction_interface import TabPFNClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d41b6e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute(table: np.ndarray):\n",
    "    \"\"\"Predict the empty cells of `table` with TabPFN.\n",
    "\n",
    "    Rows whose cells are all empty are ignored. All other empty cells must lie in a\n",
    "    single column, which is treated as the label column: rows with a filled label\n",
    "    are used for training, rows with an empty label are predicted.\n",
    "\n",
    "    Args:\n",
    "        table: 2-D array of cell values, as delivered by the input `gr.DataFrame`.\n",
    "\n",
    "    Returns:\n",
    "        `(message, out_table)`. On success `message` is None and `out_table` is the\n",
    "        table (as strings) with each predicted cell replaced by\n",
    "        \"<label> (p=<winning probability>)\". On invalid input `message` describes\n",
    "        the problem and `out_table` is None.\n",
    "    \"\"\"\n",
    "    # Compare as strings so that empty-cell detection also works when the table\n",
    "    # holds numbers rather than strings.\n",
    "    table = np.asarray(table).astype(str)\n",
    "    if table.ndim == 2:\n",
    "        # Drop rows in which every cell is empty.\n",
    "        table = table[(table != '').any(axis=1)]\n",
    "    if table.ndim != 2 or table.size == 0:\n",
    "        return \"Please provide a non-empty table.\", None\n",
    "\n",
    "    eval_lines, empty_cols = np.where(table == '')\n",
    "    if eval_lines.size == 0:\n",
    "        return \"Please leave the cells you want to predict empty.\", None\n",
    "    if not np.all(empty_cols == empty_cols[0]):\n",
    "        return \"Only one column can have empty entries that are predicted.\", None\n",
    "    y_column = empty_cols[0]\n",
    "\n",
    "    train_table = np.delete(table, eval_lines, axis=0)\n",
    "    eval_table = table[eval_lines]\n",
    "    if len(train_table) == 0:\n",
    "        return \"Please fill in the label of at least one row.\", None\n",
    "\n",
    "    try:\n",
    "        x_train = torch.tensor(np.delete(train_table, y_column, axis=1).astype(np.float32))\n",
    "        x_eval = torch.tensor(np.delete(eval_table, y_column, axis=1).astype(np.float32))\n",
    "    except ValueError:\n",
    "        return \"Please only add numbers (to the inputs) or leave fields empty.\", None\n",
    "    y_train = train_table[:, y_column]\n",
    "\n",
    "    classifier = TabPFNClassifier(base_path='..', device='cpu')\n",
    "    classifier.fit(x_train, y_train)\n",
    "    y_eval, p_eval = classifier.predict(x_eval, return_winning_probability=True)\n",
    "\n",
    "    # Use an object array: a fixed-width unicode array would silently truncate the\n",
    "    # longer \"<label> (p=...)\" strings on assignment.\n",
    "    out_table = table.astype(object)\n",
    "    out_table[eval_lines, y_column] = [f\"{y_e} (p={p_e:.2f})\" for y_e, p_e in zip(y_eval, p_eval)]\n",
    "    return None, out_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f0a9c13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def upload_file(file):\n",
    "    \"\"\"Load an uploaded `.arff`, `.csv` or `.data` file into a 2-D numpy array.\n",
    "\n",
    "    Args:\n",
    "        file: Value of the `gr.File` component; only `file.name` (the path of the\n",
    "            uploaded file) is used. Presumably None when the upload is cleared.\n",
    "\n",
    "    Returns:\n",
    "        The file contents as a numpy array, or None if there is no file or its\n",
    "        extension is not supported.\n",
    "    \"\"\"\n",
    "    if file is None:\n",
    "        return None\n",
    "    path = file.name\n",
    "    if path.endswith('.arff'):\n",
    "        dataset = openml.datasets.OpenMLDataset('t', 'test', data_file=path)\n",
    "        X, _, _, _ = dataset.get_data(dataset_format=\"array\")\n",
    "        return X\n",
    "    if path.endswith(('.csv', '.data')):\n",
    "        # The upload widget asks for CSV files without a header row; with the default\n",
    "        # header=0, pandas would consume the first data row as column names.\n",
    "        # NOTE(review): confirm the bundled ./iris.csv example has no header row.\n",
    "        return pd.read_csv(path, header=None).to_numpy()\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd08a53d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy dataset: rows whose two features sum to at most 3 have label 1, the others\n",
    "# label 2. The label of the last row is left empty so the demo has something to predict.\n",
    "example = [\n",
    "    [1, 2, 1],\n",
    "    [2, 1, 1],\n",
    "    [1, 1, 1],\n",
    "    [2, 2, 2],\n",
    "    [3, 4, 2],\n",
    "    [3, 2, 2],\n",
    "    [2, 3, ''],\n",
    "]\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"\"\"This demo allows you to play with the **TabPFN**.\n",
    "    You can either edit the table manually (we have filled it with a toy benchmark: rows whose features sum up to 3 have label 1, all others label 2) or upload a file below.\n",
    "    The network predicts fields you leave empty. Only one column can have empty entries that are predicted.\n",
    "    Please provide everything but the label column as numeric values. It is ok to encode classes as integers.\n",
    "    \"\"\")\n",
    "    inp_table = gr.DataFrame(type='numpy', value=example, headers=[''] * len(example[0]))\n",
    "    inp_file = gr.File(\n",
    "        label='Drop either a .csv (without header, only numeric values for all but the labels) or a .arff file.')\n",
    "    btn = gr.Button(\"Predict Empty Table Cells\")\n",
    "\n",
    "    # Loading a file replaces the contents of the input table.\n",
    "    inp_file.change(fn=upload_file, inputs=inp_file, outputs=inp_table)\n",
    "\n",
    "    out_text = gr.Textbox()\n",
    "    out_table = gr.DataFrame()\n",
    "\n",
    "    btn.click(fn=compute, inputs=inp_table, outputs=[out_text, out_table])\n",
    "    examples = gr.Examples(examples=['./iris.csv'],\n",
    "                           inputs=[inp_file],\n",
    "                           outputs=[inp_table],\n",
    "                           fn=upload_file,\n",
    "                           cache_examples=True)\n",
    "\n",
    "demo.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}