{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3bae1d7d-a2be-444d-97cc-d1cbf8843bf1",
   "metadata": {},
   "source": [
    "# Invisible RAG Pilot Demo App\n",
    "\n",
    "Gradio app for the mockup A/B evaluation of RAG answers. An operator loads a task by its ID, rates the two candidate answers (groundedness, fluency, utility, notes) and the overall preference, and then saves or resets the ratings. Ratings are stored in `./data/demo_task_{id}.json` and mirrored to a Google spreadsheet."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-demo-app",
   "metadata": {},
   "source": [
    "## Demo app\n",
    "\n",
    "Task files are read from disk on every click, so the data cells below may be run after the app is launched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a8e18f7-cc88-4bbf-a6e1-095237ed7714",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "import gradio as gr\n",
    "import gspread\n",
    "\n",
    "\n",
    "# task files live in ./data/ as demo_task_{id}.json\n",
    "TASK_PATH = './data/demo_task_{}.json'\n",
    "# google spreadsheet that mirrors the ratings: task `id` occupies columns C:J of row 3 + id\n",
    "GSHEET_CREDENTIALS = './gsheet_service_account.json'\n",
    "GSHEET_ID = '1D2sfE9YXKtd7cKlgalo5UnuNKC-GhxlGqHVYUlkQlCY'\n",
    "GSHEET_WORKSHEET = 'demo-app'\n",
    "# per-answer rating fields, in the order the UI passes them around\n",
    "RATING_FIELDS = ('groundedness', 'fluency', 'utility', 'notes')\n",
    "# the UI exchanges ratings as one flat tuple: answer 1 fields, answer 2 fields, overall, notes\n",
    "NUM_RATINGS = 2 * len(RATING_FIELDS) + 2\n",
    "\n",
    "\n",
    "class RAGInterface:\n",
    "    \"\"\"\n",
    "    Setup the gradio app for loading/saving/synchronizing the mockup A/B evaluation RAG tasks.\n",
    "    The app is deployed on Hugging Face spaces at https://huggingface.co/spaces/sukiboo/invisible-rag-demo\n",
    "\n",
    "    Ratings travel between the UI and the handlers as a flat tuple\n",
    "    (groundedness1, fluency1, utility1, notes1, groundedness2, fluency2, utility2, notes2, overall, notes).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.setup_interface()\n",
    "        self.launch_interface()\n",
    "\n",
    "    def setup_interface(self):\n",
    "        \"\"\"Configure the A/B Evaluation RAG task interface.\"\"\"\n",
    "        with gr.Blocks(title='Demo AB Evaluate RAG') as self.interface:\n",
    "\n",
    "            # protected fields: the id of the currently loaded task, used by save/reset\n",
    "            _task_id = gr.Textbox(label='Task ID', interactive=False, visible=False)\n",
    "\n",
    "            # task id and load/save/reset buttons\n",
    "            with gr.Row():\n",
    "                task_id = gr.Textbox(container=False, placeholder='Enter a task ID: 1--11', scale=9)\n",
    "                load_button = gr.Button('Load Task', scale=1)\n",
    "                save_button = gr.Button('Save Task', scale=1, variant='primary')\n",
    "                reset_button = gr.Button('Reset Task', scale=1, variant='stop')\n",
    "\n",
    "            # chat history and search results\n",
    "            chat = gr.Chatbot(height=700, layout='bubble', bubble_full_width=False, label='Chat History')\n",
    "            sources = gr.Markdown()\n",
    "\n",
    "            # model completions for answers 1 and 2\n",
    "            with gr.Row():\n",
    "                with gr.Column():\n",
    "                    answer1 = gr.Textbox(label='Answer 1', max_lines=50)\n",
    "                with gr.Column():\n",
    "                    answer2 = gr.Textbox(label='Answer 2', max_lines=50)\n",
    "\n",
    "            # individual ratings for answers 1 and 2\n",
    "            with gr.Row():\n",
    "                with gr.Column():\n",
    "                    groundedness1 = gr.Radio(label='Groundedness', choices=['Bad', 'Good', 'Perfect'])\n",
    "                    fluency1 = gr.Radio(label='Fluency', choices=['Bad', 'Good', 'Perfect'])\n",
    "                    utility1 = gr.Radio(label='Utility', choices=['Catastrophic', 'Bad', 'Good', 'Perfect'])\n",
    "                    notes1 = gr.Textbox(label='Notes', placeholder='N/A')\n",
    "                with gr.Column():\n",
    "                    groundedness2 = gr.Radio(label='Groundedness', choices=['Bad', 'Good', 'Perfect'])\n",
    "                    fluency2 = gr.Radio(label='Fluency', choices=['Bad', 'Good', 'Perfect'])\n",
    "                    utility2 = gr.Radio(label='Utility', choices=['Catastrophic', 'Bad', 'Good', 'Perfect'])\n",
    "                    notes2 = gr.Textbox(label='Notes', placeholder='N/A')\n",
    "\n",
    "            # overall rating\n",
    "            overall = gr.Radio(label='Overall Rating', choices=['#1 Better', 'Equally Bad', 'Equally Good', '#2 Better'])\n",
    "            notes = gr.Textbox(label='Notes', placeholder='A brief justification for the overall rating')\n",
    "\n",
    "            # input/output fields (per-answer order must match RATING_FIELDS)\n",
    "            answers = (answer1, answer2)\n",
    "            ratings1 = (groundedness1, fluency1, utility1, notes1)\n",
    "            ratings2 = (groundedness2, fluency2, utility2, notes2)\n",
    "            ratings = (*ratings1, *ratings2, overall, notes)\n",
    "\n",
    "            # button clicks\n",
    "            load_button.click(self.load_task, inputs=[task_id], outputs=[_task_id, chat, sources, *answers, *ratings])\n",
    "            save_button.click(self.save_task, inputs=[_task_id, *ratings], outputs=None)\n",
    "            reset_button.click(self.reset_task, inputs=[_task_id], outputs=[*ratings])\n",
    "\n",
    "    def load_task(self, task_id):\n",
    "        \"\"\"\n",
    "        Load the task and parse the info.\n",
    "\n",
    "        Returns the values for the outputs (_task_id, chat, sources, answer1, answer2, *ratings).\n",
    "        Raises gr.Error if the task file is missing, malformed, or lacks the expected fields.\n",
    "        \"\"\"\n",
    "        task_id = (task_id or '').strip()\n",
    "        task = self.read_task(task_id)\n",
    "        try:\n",
    "            loaded_id = task['id']\n",
    "            # the current question is shown as the last chat turn, paired with the search query\n",
    "            chat = task['chat_history'] + [[task['question'], task['search_query']]]\n",
    "            answers = [task['answer_1'], task['answer_2']]\n",
    "            sources = self.load_sources(task)\n",
    "            ratings = self.load_ratings(task)\n",
    "        except (KeyError, TypeError, AttributeError) as error:\n",
    "            raise gr.Error(f'Could not load the task demo_task_{task_id} :(') from error\n",
    "        gr.Info(f'Task demo_task_{task_id} is loaded!')\n",
    "        return loaded_id, chat, sources, *answers, *ratings\n",
    "\n",
    "    def read_task(self, task_id):\n",
    "        \"\"\"Read the json task file; raise gr.Error if it is missing or is not valid json.\"\"\"\n",
    "        try:\n",
    "            with open(TASK_PATH.format(task_id), encoding='utf-8') as task_file:\n",
    "                return json.load(task_file)\n",
    "        except FileNotFoundError:\n",
    "            raise gr.Error(f'Task demo_task_{task_id} is not found :(') from None\n",
    "        except json.JSONDecodeError:\n",
    "            raise gr.Error(f'Task demo_task_{task_id} is not a valid json file :(') from None\n",
    "\n",
    "    def _write_task(self, task_id, task):\n",
    "        \"\"\"Write the task dict back to its json file (utf-8, human-readable).\"\"\"\n",
    "        with open(TASK_PATH.format(task_id), 'w', encoding='utf-8') as task_file:\n",
    "            json.dump(task, task_file, ensure_ascii=False, indent=4)\n",
    "\n",
    "    def load_sources(self, task):\n",
    "        \"\"\"Render the search results as numbered markdown sections; return '' when there are none.\"\"\"\n",
    "        # chr(92) is a backslash: escaping '<' keeps markdown from treating it as an html tag\n",
    "        sources = [f'##### {idx}. ' + str(source).replace('<', chr(92) + '<') + '\\n'\n",
    "                   for idx, source in enumerate(task['search_results'], start=1)]\n",
    "        return '\\n---\\n'.join(['## Search Results'] + sources + ['']) if sources else ''\n",
    "\n",
    "    def load_ratings(self, task):\n",
    "        \"\"\"Flatten the stored ratings into the flat tuple used by the UI.\"\"\"\n",
    "        ratings1 = [task['ratings_1'][field] for field in RATING_FIELDS]\n",
    "        ratings2 = [task['ratings_2'][field] for field in RATING_FIELDS]\n",
    "        return (*ratings1, *ratings2, task['overall'], task['notes'])\n",
    "\n",
    "    def _apply_ratings(self, task, ratings):\n",
    "        \"\"\"Write the flat tuple of ratings into the task dict in place; raise ValueError on a wrong length.\"\"\"\n",
    "        if len(ratings) != NUM_RATINGS:\n",
    "            raise ValueError(f'expected {NUM_RATINGS} ratings, got {len(ratings)}')\n",
    "        n = len(RATING_FIELDS)\n",
    "        task['ratings_1'].update(zip(RATING_FIELDS, ratings[:n]))\n",
    "        task['ratings_2'].update(zip(RATING_FIELDS, ratings[n:2 * n]))\n",
    "        task['overall'], task['notes'] = ratings[2 * n:]\n",
    "\n",
    "    def save_task(self, task_id, *ratings):\n",
    "        \"\"\"\n",
    "        Save the ratings into the task json file and mirror them to the google spreadsheet.\n",
    "\n",
    "        Raises gr.Error if no task is loaded or the json file cannot be updated;\n",
    "        a spreadsheet failure only shows a warning since the json file is already saved.\n",
    "        \"\"\"\n",
    "        if not task_id:\n",
    "            raise gr.Error('Load a task before saving it :(')\n",
    "        task = self.read_task(task_id)\n",
    "        try:\n",
    "            self._apply_ratings(task, ratings)\n",
    "            self._write_task(task_id, task)\n",
    "        except Exception as error:\n",
    "            raise gr.Error(f'Could not save the task demo_task_{task_id} :(') from error\n",
    "        self.save_gsheet(task_id, ratings)\n",
    "        gr.Info(f'Task demo_task_{task_id} is saved!')\n",
    "\n",
    "    def reset_task(self, task_id):\n",
    "        \"\"\"\n",
    "        Reset the task by erasing the ratings and operator notes.\n",
    "\n",
    "        Returns one '' per rating field to clear the UI.\n",
    "        Raises gr.Error if no task is loaded or the json file cannot be updated.\n",
    "        \"\"\"\n",
    "        if not task_id:\n",
    "            raise gr.Error('Load a task before resetting it :(')\n",
    "        empty_ratings = ('',) * NUM_RATINGS\n",
    "        task = self.read_task(task_id)\n",
    "        try:\n",
    "            self._apply_ratings(task, empty_ratings)\n",
    "            self._write_task(task_id, task)\n",
    "        except Exception as error:\n",
    "            raise gr.Error(f'Could not reset the task demo_task_{task_id} :(') from error\n",
    "        self.reset_gsheet(task_id)\n",
    "        gr.Info(f'Task demo_task_{task_id} is reset!')\n",
    "        return empty_ratings\n",
    "\n",
    "    def _open_worksheet(self):\n",
    "        \"\"\"Open the ratings worksheet with the service-account credentials.\"\"\"\n",
    "        gc = gspread.service_account(GSHEET_CREDENTIALS)\n",
    "        return gc.open_by_key(GSHEET_ID).worksheet(GSHEET_WORKSHEET)\n",
    "\n",
    "    def _gsheet_range(self, id):\n",
    "        \"\"\"A1 range holding the ratings of task `id` (columns C:J of row 3 + id).\"\"\"\n",
    "        row = 3 + int(id)\n",
    "        return f'C{row}:J{row}'\n",
    "\n",
    "    def save_gsheet(self, id, ratings):\n",
    "        \"\"\"Save the task ratings (without the per-answer notes) to the google spreadsheet; warn on failure.\"\"\"\n",
    "        groundedness1, fluency1, utility1, _notes1, groundedness2, fluency2, utility2, _notes2, overall, notes = ratings\n",
    "        values = [groundedness1, fluency1, utility1, groundedness2, fluency2, utility2, overall, notes]\n",
    "        # the sheets api skips null cells, so send '' to actually clear a deselected rating\n",
    "        values = ['' if value is None else value for value in values]\n",
    "        try:\n",
    "            self._open_worksheet().update(range_name=self._gsheet_range(id), values=[values])\n",
    "        except Exception:\n",
    "            gr.Warning(f'Could not save the task demo_task_{id} to the spreadsheet :(')\n",
    "\n",
    "    def reset_gsheet(self, id):\n",
    "        \"\"\"Clear the task ratings in the google spreadsheet; warn on failure.\"\"\"\n",
    "        try:\n",
    "            self._open_worksheet().batch_clear([self._gsheet_range(id)])\n",
    "        except Exception:\n",
    "            gr.Warning(f'Could not reset the task demo_task_{id} in the spreadsheet :(')\n",
    "\n",
    "    def launch_interface(self):\n",
    "        \"\"\"Launch the A/B Evaluation RAG task interface.\"\"\"\n",
    "        gr.close_all()\n",
    "        self.interface.queue(default_concurrency_limit=None)\n",
    "        self.interface.launch()\n",
    "\n",
    "\n",
    "rag = RAGInterface()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-placeholder-task",
   "metadata": {},
   "source": [
    "## Placeholder task\n",
    "\n",
    "Creates `./data/demo_task_0.json` with dummy content for checking the app layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6707866e-8f1b-4bda-9b12-0008e289ab77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create placeholder tasks\n",
    "import os\n",
    "import json\n",
    "\n",
    "os.makedirs('./data/', exist_ok=True)\n",
    "for idx in range(1):\n",
    "    task = {\n",
    "        'id': f'{idx}',\n",
    "        'chat_history': [['user message 1', 'bot message 1'], ['user message 2', 'bot message 2']],\n",
    "        'question': 'question',\n",
    "        'search_query': 'search query',\n",
    "        'search_results': ['source 1', 'source 2', 'source 3'],\n",
    "        'answer_1': 'answer 1',\n",
    "        'answer_2': 'answer 2',\n",
    "        'ratings_1': {'groundedness': '', 'utility': '', 'fluency': '', 'notes': ''},\n",
    "        'ratings_2': {'groundedness': '', 'utility': '', 'fluency': '', 'notes': ''},\n",
    "        'overall': '',\n",
    "        'notes': ''\n",
    "    }\n",
    "    with open(f'./data/demo_task_{idx}.json', 'w', encoding='utf-8') as task_file:\n",
    "        json.dump(task, task_file, ensure_ascii=False, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-csv-tasks",
   "metadata": {},
   "source": [
    "## Demo tasks from the spreadsheet export\n",
    "\n",
    "Builds `./data/demo_task_{1..N}.json` from `./dev.csv`, the CSV export of the [task spreadsheet](https://docs.google.com/spreadsheets/d/1kYW0cABv2C-mMmmw2Uc50mQC0MmOuoqKJQaBp7IyCho/edit#gid=1934745276). Used columns: `question`, `answer_1`, `answer_2`, `search_1`, `search_2`, and the numbered `user message i`, `bot message i`, `source i` columns (read until the first missing or empty one)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98f44682-2088-4925-8e2a-563197a50b66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make demo tasks from the csv of the spreadsheet\n",
    "# https://docs.google.com/spreadsheets/d/1kYW0cABv2C-mMmmw2Uc50mQC0MmOuoqKJQaBp7IyCho/edit#gid=1934745276\n",
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def cell_text(row, column):\n",
    "    \"\"\"Return the cell as a string, or '' when the column is missing or the cell is empty (NaN).\"\"\"\n",
    "    if column not in row.index or pd.isna(row[column]):\n",
    "        return ''\n",
    "    return str(row[column])\n",
    "\n",
    "\n",
    "def numbered_cells(row, prefix):\n",
    "    \"\"\"Collect the cells `{prefix} 1`, `{prefix} 2`, ... up to the first missing column or empty cell.\"\"\"\n",
    "    values = []\n",
    "    i = 1\n",
    "    while cell_text(row, f'{prefix} {i}'):\n",
    "        values.append(cell_text(row, f'{prefix} {i}'))\n",
    "        i += 1\n",
    "    return values\n",
    "\n",
    "\n",
    "df = pd.read_csv('./dev.csv')\n",
    "\n",
    "os.makedirs('./data/', exist_ok=True)\n",
    "for task_id, (_, row) in enumerate(df.iterrows(), start=1):\n",
    "    user_messages = numbered_cells(row, 'user message')\n",
    "    search_queries = [cell_text(row, 'search_1'), cell_text(row, 'search_2')]\n",
    "    task = {\n",
    "        'id': f'{task_id}',\n",
    "        # one [user message, bot message] pair per previous turn; a missing bot message becomes ''\n",
    "        'chat_history': [[message, cell_text(row, f'bot message {i}')]\n",
    "                         for i, message in enumerate(user_messages, start=1)],\n",
    "        'question': str(row['question']),\n",
    "        # the non-empty search queries, one per line\n",
    "        'search_query': '\\n'.join(query for query in search_queries if query),\n",
    "        'search_results': numbered_cells(row, 'source'),\n",
    "        'answer_1': str(row['answer_1']),\n",
    "        'answer_2': str(row['answer_2']),\n",
    "        'ratings_1': {'groundedness': '', 'utility': '', 'fluency': '', 'notes': ''},\n",
    "        'ratings_2': {'groundedness': '', 'utility': '', 'fluency': '', 'notes': ''},\n",
    "        'overall': '',\n",
    "        'notes': ''\n",
    "    }\n",
    "\n",
    "    # save the task\n",
    "    with open(f'./data/demo_task_{task_id}.json', 'w', encoding='utf-8') as task_file:\n",
    "        json.dump(task, task_file, ensure_ascii=False, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}