{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Baseline human labels for our method vs. other methods. Topic labels are aggregated per row by majority vote over 3 labels; fluency ratings are pooled individually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "from collections import defaultdict\n",
    "\n",
    "MAX_FILES = 2  # number of method categories; decode() reduces method ids modulo this"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(filename):\n",
    "    \"\"\"Read a CSV file and split it into its header row and data rows.\n",
    "\n",
    "    Args:\n",
    "        filename: path to the CSV file.\n",
    "\n",
    "    Returns:\n",
    "        (headers, data): headers is the first row as a list of strings\n",
    "        (an empty list for an empty file); data is a list of the remaining\n",
    "        rows, each a list of strings.\n",
    "    \"\"\"\n",
    "    # newline='' is what the csv module expects (quoted fields may contain\n",
    "    # line breaks); the with-block closes the file even if parsing fails.\n",
    "    with open(filename, newline='') as csvfile:\n",
    "        reader = csv.reader(csvfile)\n",
    "        headers = next(reader, [])\n",
    "        data = list(reader)\n",
    "    return headers, data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get stats\n",
    "\n",
    "Run these cells in order to:\n",
    "* print on-topic and fluency statistics as comma-separated lines to copy/paste\n",
    "* save the on-topic percentages for each file (`percs_ordered`) for plotting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Topic relevance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# for topics\n",
    "def decode(st):\n",
    "    \"\"\"Decode an order string 'ii_j1_j2' into (seed index, method_a, method_b).\n",
    "\n",
    "    Every '_'-separated field must parse as an integer; only the first three\n",
    "    are returned. The two method ids are reduced modulo MAX_FILES (encoding\n",
    "    version 2) so they can index per-method score/count arrays directly.\n",
    "    Example with MAX_FILES == 2: '7_3_2' -> (7, 1, 0).\n",
    "    \"\"\"\n",
    "    fields = list(map(int, st.split('_')))\n",
    "    seed_idx = fields[0]\n",
    "    method_a = np.mod(fields[1], MAX_FILES)\n",
    "    method_b = np.mod(fields[2], MAX_FILES)\n",
    "    return seed_idx, method_a, method_b\n",
    "\n",
    "# p-value of two binomial distributions\n",
    "# one sided tail\n",
    "def two_samp(x1, x2, n1, n2):\n",
    "    \"\"\"One-sided p-value of a pooled two-proportion z-test.\n",
    "\n",
    "    Compares x1 successes out of n1 trials with x2 successes out of n2,\n",
    "    using the pooled proportion phat = (x1 + x2) / (n1 + n2). The value\n",
    "    returned is the upper-tail probability of |z|, i.e. the one-sided\n",
    "    p-value in the direction of the observed difference (half of the\n",
    "    two-sided p-value).\n",
    "\n",
    "    Edge cases:\n",
    "      * pooled standard error of zero (both samples all successes or all\n",
    "        failures, so p1 == p2): treated as z = 0 and returns 0.5 instead\n",
    "        of nan with a RuntimeWarning;\n",
    "      * an empty sample (n1 == 0 or n2 == 0): returns nan.\n",
    "    \"\"\"\n",
    "    if n1 == 0 or n2 == 0:\n",
    "        return float('nan')\n",
    "    p1 = x1 / n1\n",
    "    p2 = x2 / n2\n",
    "    phat = (x1 + x2) / (n1 + n2)\n",
    "    se = np.sqrt(phat * (1 - phat) * (1 / n1 + 1 / n2))\n",
    "    if se == 0:\n",
    "        # identical degenerate proportions: z would be 0/0\n",
    "        return stats.norm.sf(0.0)\n",
    "    z = (p1 - p2) / se\n",
    "    return stats.norm.sf(np.abs(z))\n",
    "\n",
    "def print_info_t(scores, counts, single_pvalue=True):\n",
    "    \"\"\"Print per-method topic counts, percentages and pairwise p-values as CSV.\n",
    "\n",
    "    Args:\n",
    "        scores: integer array of on-topic counts, one entry per method.\n",
    "        counts: integer array of total judgements, one entry per method.\n",
    "        single_pvalue: with exactly two methods, append the one p-value\n",
    "            (from two_samp) to the first row instead of printing the full\n",
    "            p-value matrix.\n",
    "    \"\"\"\n",
    "    num_methods = len(scores)\n",
    "    # Symmetric matrix of pairwise two_samp p-values. The diagonal stays nan:\n",
    "    # comparing a method with itself carries no information and, for a method\n",
    "    # that is 0% or 100% on topic, hits a zero standard error in two_samp.\n",
    "    pvalues = np.full((num_methods, num_methods), np.nan)\n",
    "    for i in range(num_methods):\n",
    "        for j in range(i + 1, num_methods):\n",
    "            pvalue = two_samp(scores[i], scores[j], counts[i], counts[j])\n",
    "            pvalues[i, j] = pvalue\n",
    "            pvalues[j, i] = pvalue\n",
    "    percs = scores / counts\n",
    "\n",
    "    print('total counts, on topic counts, percentages:')\n",
    "    show_single = single_pvalue and num_methods == 2\n",
    "    for i in range(num_methods):\n",
    "        if i == 0 and show_single:\n",
    "            print('{},{},{},{}'.format(counts[i], scores[i], percs[i], pvalues[0][1]))\n",
    "        else:\n",
    "            print('{},{},{}'.format(counts[i], scores[i], percs[i]))\n",
    "\n",
    "    if not show_single:\n",
    "        # one line per method, every column of the matrix\n",
    "        for row in pvalues:\n",
    "            print(','.join(str(p) for p in row))\n",
    "\n",
    "def get_counts_indices(data, order_index, label_indices):\n",
    "    \"\"\"Tally on-topic judgements with every individual label as one vote.\n",
    "\n",
    "    Each non-empty label in a row with a non-empty order field adds one\n",
    "    count to both methods of the decoded pair and one score to each method\n",
    "    it marks on topic ('a', 'b' or 'both'). Labels outside\n",
    "    {'a', 'b', 'both', 'neither'} are reported but still counted.\n",
    "\n",
    "    Returns:\n",
    "        (scores, counts): integer arrays of length MAX_FILES.\n",
    "    \"\"\"\n",
    "    scores = np.zeros(MAX_FILES, dtype=int)\n",
    "    counts = np.zeros(MAX_FILES, dtype=int)\n",
    "    skipped = 0\n",
    "    for row in data:\n",
    "        order = row[order_index]\n",
    "        for label in (row[col].lower() for col in label_indices):\n",
    "            if not (order and label):\n",
    "                # empty order or empty label: nothing to tally\n",
    "                skipped += 1\n",
    "                continue\n",
    "            a_cat, b_cat = decode(order)[1:]\n",
    "            if label in ('a', 'both'):\n",
    "                scores[a_cat] += 1\n",
    "            if label in ('b', 'both'):\n",
    "                scores[b_cat] += 1\n",
    "            counts[a_cat] += 1\n",
    "            counts[b_cat] += 1\n",
    "            if label not in ('a', 'b', 'both', 'neither'):\n",
    "                print('******invalid label: {}'.format(label))\n",
    "    print('skipped {}'.format(skipped))\n",
    "    print_info_t(scores, counts)\n",
    "    return scores, counts\n",
    "\n",
    "# vote by row. each row contributes to one count (and 0 or 1 score based on majority vote)\n",
    "def get_counts_vote_row(data, order_index, label_indices):\n",
    "    \"\"\"Tally on-topic judgements with one majority-voted result per row.\n",
    "\n",
    "    For every row whose order field is non-empty, both methods of the\n",
    "    decoded pair get one count, and each gets one score when a strict\n",
    "    majority of the row's labels mark it on topic ('a', 'b' or 'both';\n",
    "    'neither' marks neither). With 3 label columns that means 2 or more\n",
    "    votes. Unrecognized labels are reported and count as votes for neither\n",
    "    method. Rows with any empty label are reported and not counted.\n",
    "\n",
    "    Args:\n",
    "        data: list of CSV rows (lists of strings).\n",
    "        order_index: column of the encoded order string passed to decode().\n",
    "        label_indices: columns holding the labels; each is one vote.\n",
    "\n",
    "    Returns:\n",
    "        (scores, counts): integer arrays of length MAX_FILES.\n",
    "    \"\"\"\n",
    "    num_votes = len(label_indices)\n",
    "    scores = np.zeros(MAX_FILES, dtype=int)\n",
    "    counts = np.zeros(MAX_FILES, dtype=int)\n",
    "    skipped = 0\n",
    "    for rownum, row in enumerate(data):\n",
    "        order = row[order_index]\n",
    "        if len(order) == 0:\n",
    "            skipped += 1\n",
    "            continue\n",
    "        a_cat, b_cat = decode(order)[1:]\n",
    "        row_score_a, row_score_b, row_counts = 0, 0, 0\n",
    "        for label_index in label_indices:\n",
    "            label = row[label_index].lower()\n",
    "            if len(label) == 0:\n",
    "                print('empty label for nonempty prompt', rownum)\n",
    "                continue\n",
    "            if label == 'a' or label == 'both':\n",
    "                row_score_a += 1\n",
    "            if label == 'b' or label == 'both':\n",
    "                row_score_b += 1\n",
    "            row_counts += 1\n",
    "            if label not in ['a', 'b', 'both', 'neither']:\n",
    "                print('******invalid label: {}'.format(label))\n",
    "        # Only fully labeled rows count; the majority threshold is derived\n",
    "        # from the number of label columns (strictly more than half).\n",
    "        if num_votes > 0 and row_counts == num_votes:\n",
    "            scores[a_cat] += int(2 * row_score_a > num_votes)\n",
    "            scores[b_cat] += int(2 * row_score_b > num_votes)\n",
    "            counts[a_cat] += 1\n",
    "            counts[b_cat] += 1\n",
    "        else:\n",
    "            print('incomplete row...')\n",
    "    print('skipped {}'.format(skipped))\n",
    "    print_info_t(scores, counts)\n",
    "    return scores, counts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fluency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def print_info_f_lists(scorelist, single_pvalue=True):\n",
    "    \"\"\"Print fluency summary statistics per method plus pairwise t-test p-values.\n",
    "\n",
    "    Args:\n",
    "        scorelist: one list of integer fluency ratings per method.\n",
    "        single_pvalue: with exactly two methods, append the one p-value\n",
    "            (scipy.stats.ttest_ind) to the first row instead of printing the\n",
    "            full p-value matrix.\n",
    "\n",
    "    Prints 'skipping; no data' and returns early if any method has no ratings.\n",
    "    \"\"\"\n",
    "    num_methods = len(scorelist)\n",
    "    if any(len(method_scores) == 0 for method_scores in scorelist):\n",
    "        print('skipping; no data')\n",
    "        return\n",
    "\n",
    "    # Symmetric matrix of independent t-test p-values. The diagonal stays nan:\n",
    "    # a method is not compared with itself.\n",
    "    pvalues = np.full((num_methods, num_methods), np.nan)\n",
    "    for i in range(num_methods):\n",
    "        for j in range(i + 1, num_methods):\n",
    "            pvalue = stats.ttest_ind(scorelist[i], scorelist[j]).pvalue\n",
    "            pvalues[i, j] = pvalue\n",
    "            pvalues[j, i] = pvalue\n",
    "\n",
    "    print('mean, stdev, min, max, counts:')\n",
    "    show_single = single_pvalue and num_methods == 2\n",
    "    for i, method_scores in enumerate(scorelist):\n",
    "        summary = '{},{},{},{},{}'.format(np.mean(method_scores), np.std(method_scores),\n",
    "            np.min(method_scores), np.max(method_scores), len(method_scores))\n",
    "        if i == 0 and show_single:\n",
    "            summary += ',{}'.format(pvalues[0][1])\n",
    "        print(summary)\n",
    "    if not show_single:\n",
    "        print('p-values')\n",
    "        # one line per method, every column of the matrix\n",
    "        for row in pvalues:\n",
    "            print(','.join(str(p) for p in row))\n",
    "\n",
    "def get_fluencies_indices(data, order_index, label_indices):\n",
    "    \"\"\"Collect fluency ratings per method from labeled CSV rows.\n",
    "\n",
    "    Every non-empty rating is kept individually (no voting); empty cells are\n",
    "    counted as skipped. Rows with an empty order field are ignored.\n",
    "\n",
    "    Args:\n",
    "        data: list of CSV rows (lists of strings).\n",
    "        order_index: column of the encoded order string passed to decode().\n",
    "        label_indices: sequence of (col_a, col_b) pairs giving the columns of\n",
    "            integer fluency ratings for method A and method B of the row.\n",
    "\n",
    "    Returns:\n",
    "        A list of MAX_FILES lists of integer ratings, one list per method.\n",
    "    \"\"\"\n",
    "    scorelist = [[] for _ in range(MAX_FILES)]\n",
    "    skipped = 0\n",
    "    for row in data:\n",
    "        order = row[order_index]\n",
    "        if len(order) == 0:\n",
    "            continue\n",
    "        # The method pair depends only on the row, so decode it once.\n",
    "        cats = decode(order)[1:]\n",
    "        for label_ind_pair in label_indices:\n",
    "            for cat, ind in zip(cats, label_ind_pair):\n",
    "                label = row[ind].strip()\n",
    "                if label:\n",
    "                    scorelist[cat].append(int(label))\n",
    "                else:\n",
    "                    skipped += 1\n",
    "    print('skipped {}'.format(skipped))\n",
    "    print_info_f_lists(scorelist)\n",
    "    return scorelist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run on all files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# aggregated human labeled everything\n",
    "# Directory of labeled CSVs; file names are appended to it directly, so keep the trailing slash.\n",
    "dirname = 'ctrl_wd_openai_csvs/'\n",
    "# comment out any of the below if you don't want to include them in \"all\"\n",
    "file_info = [\n",
    "    # ctrl_* files\n",
    "    'ctrl_legal.csv',\n",
    "    'ctrl_politics.csv',\n",
    "    'ctrl_religion.csv',\n",
    "    'ctrl_science.csv',\n",
    "    'ctrl_technologies.csv',\n",
    "    'ctrl_positive.csv',\n",
    "    'ctrl_negative.csv',\n",
    "    # openai_* files\n",
    "    'openai_positive.csv',\n",
    "    # greedy_* files\n",
    "    'greedy_legal.csv',\n",
    "    'greedy_military.csv',\n",
    "    'greedy_politics.csv',\n",
    "    'greedy_religion.csv',\n",
    "    'greedy_science.csv',\n",
    "    'greedy_space.csv',\n",
    "    'greedy_technologies.csv',\n",
    "    'greedy_positive.csv',\n",
    "    'greedy_negative.csv',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ctrl_legal.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "20,7,0.35,0.24507648020791256\n",
      "20,5,0.25\n",
      "\n",
      "ctrl_politics.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "20,7,0.35,0.16864350736717681\n",
      "20,10,0.5\n",
      "\n",
      "ctrl_religion.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "20,12,0.6,0.000782701129001274\n",
      "20,20,1.0\n",
      "\n",
      "ctrl_science.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "20,15,0.75,0.012580379600204389\n",
      "20,8,0.4\n",
      "\n",
      "ctrl_technologies.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "20,15,0.75,0.005502076588434386\n",
      "20,7,0.35\n",
      "\n",
      "ctrl_positive.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "15,13,0.8666666666666667,0.312103057383203\n",
      "15,12,0.8\n",
      "\n",
      "ctrl_negative.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "15,8,0.5333333333333333,0.12785217497142026\n",
      "15,11,0.7333333333333333\n",
      "\n",
      "openai_positive.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "45,38,0.8444444444444444,7.502148606340828e-12\n",
      "45,6,0.13333333333333333\n",
      "\n",
      "greedy_legal.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "60,26,0.43333333333333335,0.014054020073575932\n",
      "60,38,0.6333333333333333\n",
      "\n",
      "greedy_military.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "60,21,0.35,0.423683196354148\n",
      "60,20,0.3333333333333333\n",
      "\n",
      "greedy_politics.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "60,20,0.3333333333333333,0.423683196354148\n",
      "60,21,0.35\n",
      "\n",
      "greedy_religion.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "60,31,0.5166666666666667,0.004543733726219588\n",
      "60,17,0.2833333333333333\n",
      "\n",
      "greedy_science.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "60,33,0.55,0.04996165925796605\n",
      "60,24,0.4\n",
      "\n",
      "greedy_space.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "60,34,0.5666666666666667,2.9438821372586324e-08\n",
      "60,6,0.1\n",
      "\n",
      "greedy_technologies.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "60,36,0.6,0.014229868458155282\n",
      "60,24,0.4\n",
      "\n",
      "greedy_positive.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "45,37,0.8222222222222222,6.07065790526639e-09\n",
      "45,10,0.2222222222222222\n",
      "\n",
      "greedy_negative.csv\n",
      "skipped 0\n",
      "total counts, on topic counts, percentages:\n",
      "45,18,0.4,0.0048164878862943334\n",
      "45,7,0.15555555555555556\n",
      "\n",
      "all:\n",
      "total counts, on topic counts, percentages:\n",
      "685,371,0.5416058394160584,5.6920836882984375e-12\n",
      "685,246,0.35912408759124087\n",
      "\n",
      "------------\n",
      "\n",
      "ctrl_legal.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.35,0.6538348415311009,2,5,60,0.21268659490448816\n",
      "3.183333333333333,0.7851043808875918,2,5,60\n",
      "\n",
      "ctrl_politics.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.6333333333333333,0.682316316348624,2,5,60,0.5620319695586566\n",
      "3.7,0.5567764362830021,2,5,60\n",
      "\n",
      "ctrl_religion.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.5833333333333335,0.7369230323144715,2,5,60,0.025496401986981814\n",
      "3.8666666666666667,0.6182412330330469,2,5,60\n",
      "\n",
      "ctrl_science.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.9166666666666665,0.7139483330201298,2,5,60,0.11926537531844811\n",
      "3.7333333333333334,0.5436502143433364,3,5,60\n",
      "\n",
      "ctrl_technologies.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.566666666666667,0.8239471396205517,2,5,60,0.41405751072305697\n",
      "3.683333333333333,0.7186020379103366,1,5,60\n",
      "\n",
      "ctrl_positive.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.7777777777777777,0.5921294486432991,2,5,45,0.2770324945551848\n",
      "3.911111111111111,0.5506449641495051,3,5,45\n",
      "\n",
      "ctrl_negative.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "2.933333333333333,0.7999999999999999,1,4,45,0.15456038547144507\n",
      "3.1777777777777776,0.7969076034240491,1,4,45\n",
      "\n",
      "openai_positive.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.6814814814814816,0.83134742794656,2,5,135,0.00044715078341087973\n",
      "3.3185185185185184,0.8402103074636584,1,5,135\n",
      "\n",
      "greedy_legal.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.861111111111111,0.6033599339337534,2,5,180,2.0680624898872873e-09\n",
      "3.3722222222222222,0.8757888859990781,1,5,180\n",
      "\n",
      "greedy_military.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.988888888888889,0.7148340047318592,1,5,180,2.9929380752302575e-05\n",
      "3.6222222222222222,0.9138171197756484,1,5,180\n",
      "\n",
      "greedy_politics.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.8222222222222224,0.684393937530422,2,5,180,0.0007209758186600587\n",
      "3.522222222222222,0.957169180190301,1,5,180\n",
      "\n",
      "greedy_religion.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.8333333333333335,0.8975274678557507,1,5,180,4.066885786996924e-09\n",
      "3.2111111111111112,1.0487499908032782,1,5,180\n",
      "\n",
      "greedy_science.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.8777777777777778,0.5836242660741733,2,5,180,0.0006552437647639663\n",
      "3.6166666666666667,0.8318319808978519,1,5,180\n",
      "\n",
      "greedy_space.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.716666666666667,0.8251262529657709,1,5,180,0.0954991700009854\n",
      "3.577777777777778,0.7450246495217872,1,5,180\n",
      "\n",
      "greedy_technologies.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "4.011111111111111,0.5476098457934048,2,5,180,6.183501355237109e-11\n",
      "3.4555555555555557,0.9563949801335698,1,5,180\n",
      "\n",
      "greedy_positive.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.740740740740741,0.7299209796192601,1,5,135,0.7085851710838819\n",
      "3.7777777777777777,0.8833158628600795,1,5,135\n",
      "\n",
      "greedy_negative.csv\n",
      "skipped 0\n",
      "mean, stdev, min, max, counts:\n",
      "3.762962962962963,0.5863560719159496,2,5,135,0.02527504830979518\n",
      "3.5555555555555554,0.8916623398995057,1,5,135\n",
      "\n",
      "all:\n",
      "mean, stdev, min, max, counts:\n",
      "3.78345498783455,0.735818880968324,1,5,2055,6.013335996010963e-25\n",
      "3.5206812652068127,0.8798559510028829,1,5,2055\n",
      "total counts\n",
      "2055\n",
      "2055\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/rosanne/anaconda3/envs/py36/lib/python3.6/site-packages/ipykernel_launcher.py:14: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  \n"
     ]
    }
   ],
   "source": [
    "# hardcoded indices\n",
    "category_index = -1 # index of encoded seed and methods\n",
    "topic_indices = [2, 6, 10]  # topic label columns; each is one vote\n",
    "fluency_indices = [(3,4), (7,8), (11,12)]  # (method A, method B) fluency rating columns\n",
    "\n",
    "# Topic relevance: majority vote per row, aggregated over all files.\n",
    "all_scores = np.zeros(MAX_FILES, dtype=int)\n",
    "all_counts = np.zeros(MAX_FILES, dtype=int)\n",
    "percs_ordered = np.zeros((len(file_info), MAX_FILES)) # percents saved in same order as file names\n",
    "for i, fname in enumerate(file_info):\n",
    "    headers, data = get_data(dirname + fname)\n",
    "    print(fname)\n",
    "    scores, counts = get_counts_vote_row(data, category_index, topic_indices)\n",
    "    all_scores += scores\n",
    "    all_counts += counts\n",
    "    percs_ordered[i] = 100 * scores / counts\n",
    "    print()\n",
    "print('all:')\n",
    "print_info_t(all_scores, all_counts)\n",
    "print('\\n------------\\n')\n",
    "\n",
    "# Fluency: every individual rating, pooled over all files (one list per method).\n",
    "all_fluencies = [[] for _ in range(MAX_FILES)]\n",
    "for fname in file_info:\n",
    "    headers, data = get_data(dirname + fname)\n",
    "    print(fname)\n",
    "    new_scores = get_fluencies_indices(data, category_index, fluency_indices)\n",
    "    for method_scores, new_method_scores in zip(all_fluencies, new_scores):\n",
    "        method_scores.extend(new_method_scores)\n",
    "    print()\n",
    "print('all:')\n",
    "print_info_f_lists(all_fluencies)\n",
    "print('total counts')\n",
    "for method_scores in all_fluencies:\n",
    "    print(len(method_scores))\n",
    "\n",
    "# Alias of all_fluencies (same list object), kept in case other code refers to it.\n",
    "all_scores_hist = all_fluencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}