File size: 3,936 Bytes
84b07f8
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: bokeh_plot\n",
    "\n",
    "Choose a plot type with the radio button to render one of three Bokeh figures in a `gr.Plot` component: a penguin scatter plot, a car-class whisker plot, or an OpenStreetMap tile map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quote the version spec so the shell does not treat \">\" as an output redirect.\n",
    "%pip install -q gradio \"bokeh>=3.0\" xyzservices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import xyzservices.providers as xyz\n",
    "from bokeh.models import ColumnDataSource, Whisker\n",
    "from bokeh.plotting import figure\n",
    "from bokeh.sampledata.autompg2 import autompg2 as autompg\n",
    "from bokeh.sampledata.penguins import data as penguins\n",
    "from bokeh.transform import factor_cmap, factor_mark, jitter\n",
    "\n",
    "\n",
    "def _map_plot():\n",
    "    \"\"\"Return a mercator-axis figure with an OpenStreetMap (Mapnik) tile layer.\"\"\"\n",
    "    plot = figure(\n",
    "        x_range=(-2000000, 6000000),\n",
    "        y_range=(-1000000, 7000000),\n",
    "        x_axis_type=\"mercator\",\n",
    "        y_axis_type=\"mercator\",\n",
    "    )\n",
    "    # Bokeh >= 3.0 accepts xyzservices providers directly; the\n",
    "    # bokeh.tile_providers.get_provider helper is deprecated.\n",
    "    plot.add_tile(xyz.OpenStreetMap.Mapnik)\n",
    "    return plot\n",
    "\n",
    "\n",
    "def _whisker_plot():\n",
    "    \"\"\"Return jittered highway mpg per car class with 20th-80th percentile whiskers.\"\"\"\n",
    "    classes = sorted(autompg[\"class\"].unique())\n",
    "\n",
    "    p = figure(\n",
    "        height=400,\n",
    "        x_range=classes,\n",
    "        background_fill_color=\"#efefef\",\n",
    "        title=\"Car class vs HWY mpg with quintile ranges\",\n",
    "    )\n",
    "    p.xgrid.grid_line_color = None\n",
    "\n",
    "    # groupby sorts its keys, so these per-class Series line up with `classes`.\n",
    "    by_class = autompg.groupby(\"class\")\n",
    "    upper = by_class.hwy.quantile(0.80)\n",
    "    lower = by_class.hwy.quantile(0.20)\n",
    "    source = ColumnDataSource(data=dict(base=classes, upper=upper, lower=lower))\n",
    "\n",
    "    error = Whisker(\n",
    "        base=\"base\",\n",
    "        upper=\"upper\",\n",
    "        lower=\"lower\",\n",
    "        source=source,\n",
    "        level=\"annotation\",\n",
    "        line_width=2,\n",
    "    )\n",
    "    error.upper_head.size = 20\n",
    "    error.lower_head.size = 20\n",
    "    p.add_layout(error)\n",
    "\n",
    "    # scatter() draws circles by default; circle(size=...) is deprecated since Bokeh 3.4.\n",
    "    p.scatter(\n",
    "        jitter(\"class\", 0.3, range=p.x_range),\n",
    "        \"hwy\",\n",
    "        source=autompg,\n",
    "        alpha=0.5,\n",
    "        size=13,\n",
    "        line_color=\"white\",\n",
    "        color=factor_cmap(\"class\", \"Light6\", classes),\n",
    "    )\n",
    "    return p\n",
    "\n",
    "\n",
    "def _scatter_plot():\n",
    "    \"\"\"Return penguin flipper length vs body mass, marked and colored by species.\"\"\"\n",
    "    species = sorted(penguins.species.unique())\n",
    "    markers = [\"hex\", \"circle_x\", \"triangle\"]\n",
    "\n",
    "    p = figure(title=\"Penguin size\", background_fill_color=\"#fafafa\")\n",
    "    p.xaxis.axis_label = \"Flipper Length (mm)\"\n",
    "    p.yaxis.axis_label = \"Body Mass (g)\"\n",
    "\n",
    "    p.scatter(\n",
    "        \"flipper_length_mm\",\n",
    "        \"body_mass_g\",\n",
    "        source=penguins,\n",
    "        legend_group=\"species\",\n",
    "        fill_alpha=0.4,\n",
    "        size=12,\n",
    "        marker=factor_mark(\"species\", markers, species),\n",
    "        color=factor_cmap(\"species\", \"Category10_3\", species),\n",
    "    )\n",
    "\n",
    "    p.legend.location = \"top_left\"\n",
    "    p.legend.title = \"Species\"\n",
    "    return p\n",
    "\n",
    "\n",
    "# Radio-button choice -> function that builds that figure (insertion order = UI order).\n",
    "_PLOT_BUILDERS = {\n",
    "    \"scatter\": _scatter_plot,\n",
    "    \"whisker\": _whisker_plot,\n",
    "    \"map\": _map_plot,\n",
    "}\n",
    "\n",
    "\n",
    "def get_plot(plot_type):\n",
    "    \"\"\"Build a fresh Bokeh figure for the selected plot type.\n",
    "\n",
    "    Args:\n",
    "        plot_type: One of \"scatter\", \"whisker\" or \"map\".\n",
    "\n",
    "    Returns:\n",
    "        A new Bokeh figure, or None for any other value.\n",
    "    \"\"\"\n",
    "    builder = _PLOT_BUILDERS.get(plot_type)\n",
    "    return builder() if builder is not None else None\n",
    "\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    with gr.Row():\n",
    "        plot_type = gr.Radio(value=\"scatter\", choices=list(_PLOT_BUILDERS))\n",
    "        plot = gr.Plot()\n",
    "    plot_type.change(get_plot, inputs=[plot_type], outputs=[plot])\n",
    "    demo.load(get_plot, inputs=[plot_type], outputs=[plot])\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}