{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Structured Responses\n", "\n", "Here's a public service announcement. There's no law that says you have to ask LLMs for essays, poems or relationship advice.\n", "\n", "Yes, they're great at drumming up long blocks of text. An LLM can spit out a long answer to almost any question. It's how they've been tuned and marketed by companies selling chatbots and more conversational forms of search.\n", "\n", "But they're also great at answering simple questions, a skill that has been overlooked in much of the hoopla that followed the introduction of ChatGPT.\n", "\n", "Here's an example that simply prompts the LLM to answer a straightforward question." ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "import os\n", "from rich import print\n", "\n", "# Reuse the Hugging Face client setup from the previous chapter\n", "from huggingface_hub import InferenceClient\n", "api_key = os.getenv(\"HF_TOKEN\")\n", "client = InferenceClient(\n", "    token=api_key,\n", ")" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "prompt = \"\"\"\n", "You are an AI model trained to classify text.\n", "\n", "I will provide the name of a professional sports team.\n", "\n", "You will reply with the sports league in which they compete.\n", "\"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lace that into our request." ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "response = client.chat.completions.create(\n", "    messages=[\n", "        {\n", "            \"role\": \"system\",\n", "            \"content\": prompt  # NEW\n", "        },\n", "    ],\n", "    model=\"meta-llama/Llama-3.3-70B-Instruct\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And now add a user message that provides the name of a professional sports team."
] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "response = client.chat.completions.create(\n", "    messages=[\n", "        {\n", "            \"role\": \"system\",\n", "            \"content\": prompt\n", "        },\n", "        ### <-- NEW\n", "        {\n", "            \"role\": \"user\",\n", "            \"content\": \"Minnesota Twins\",\n", "        }\n", "        ### -->\n", "    ],\n", "    model=\"meta-llama/Llama-3.3-70B-Instruct\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the response." ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Major League Baseball (MLB)\n",
       "
\n" ], "text/plain": [ "Major League Baseball \u001b[1m(\u001b[0mMLB\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(response.choices[0].message.content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And we'll bet you get the right answer.\n", "\n", "```\n", "Major League Baseball (MLB)\n", "```\n", "\n", "Try another one." ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "response = client.chat.completions.create(\n", "    messages=[\n", "        {\n", "            \"role\": \"system\",\n", "            \"content\": prompt\n", "        },\n", "        {\n", "            \"role\": \"user\",\n", "            \"content\": \"Minnesota Vikings\",  # NEW\n", "        }\n", "    ],\n", "    model=\"meta-llama/Llama-3.3-70B-Instruct\",\n", ")" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
National Football League (NFL)\n",
       "
\n" ], "text/plain": [ "National Football League \u001b[1m(\u001b[0mNFL\u001b[1m)\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(response.choices[0].message.content)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See what we mean?\n", "\n", "```\n", "National Football League (NFL)\n", "```\n", "\n", "This approach can be used to classify large datasets, adding a new column of data that categorizes text in a way that makes it easier to analyze.\n", "\n", "Let's try it by making a function that will classify whatever team you provide." ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "def classify_team(name):\n", "    prompt = \"\"\"\n", "You are an AI model trained to classify text.\n", "\n", "I will provide the name of a professional sports team.\n", "\n", "You will reply with the sports league in which they compete.\n", "\"\"\"\n", "\n", "    response = client.chat.completions.create(\n", "        messages=[\n", "            {\n", "                \"role\": \"system\",\n", "                \"content\": prompt,\n", "            },\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": name,\n", "            }\n", "        ],\n", "        model=\"meta-llama/Llama-3.3-70B-Instruct\",\n", "    )\n", "\n", "    return response.choices[0].message.content" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, make a list of teams." ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "team_list = [\"Minnesota Twins\", \"Minnesota Vikings\", \"Minnesota Timberwolves\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, loop through the list and ask the LLM to code them one by one." ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
['Minnesota Twins', 'Major League Baseball (MLB)']\n",
       "
\n" ], "text/plain": [ "\u001b[1m[\u001b[0m\u001b[32m'Minnesota Twins'\u001b[0m, \u001b[32m'Major League Baseball \u001b[0m\u001b[32m(\u001b[0m\u001b[32mMLB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m'\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
['Minnesota Vikings', 'National Football League (NFL)']\n",
       "
\n" ], "text/plain": [ "\u001b[1m[\u001b[0m\u001b[32m'Minnesota Vikings'\u001b[0m, \u001b[32m'National Football League \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNFL\u001b[0m\u001b[32m)\u001b[0m\u001b[32m'\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
['Minnesota Timberwolves', 'National Basketball Association (NBA)']\n",
       "
\n" ], "text/plain": [ "\u001b[1m[\u001b[0m\u001b[32m'Minnesota Timberwolves'\u001b[0m, \u001b[32m'National Basketball Association \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNBA\u001b[0m\u001b[32m)\u001b[0m\u001b[32m'\u001b[0m\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for team in team_list:\n", "    league = classify_team(team)\n", "    print([team, league])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Due to its probabilistic nature, the LLM can sometimes return slight variations on the same answer. You can guard against this by adding a validation step that only accepts responses from a predefined list." ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "def classify_team(name):\n", "    prompt = \"\"\"\n", "You are an AI model trained to classify text.\n", "\n", "I will provide the name of a professional sports team.\n", "\n", "You will reply with the sports league in which they compete.\n", "\n", "Your responses must come from the following list:\n", "- Major League Baseball (MLB)\n", "- National Football League (NFL)\n", "- National Basketball Association (NBA)\n", "\"\"\"\n", "\n", "    response = client.chat.completions.create(\n", "        messages=[\n", "            {\n", "                \"role\": \"system\",\n", "                \"content\": prompt,\n", "            },\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": name,\n", "            }\n", "        ],\n", "        model=\"meta-llama/Llama-3.3-70B-Instruct\",\n", "    )\n", "\n", "    answer = response.choices[0].message.content\n", "    ### <-- NEW\n", "    acceptable_answers = [\n", "        \"Major League Baseball (MLB)\",\n", "        \"National Football League (NFL)\",\n", "        \"National Basketball Association (NBA)\",\n", "    ]\n", "    if answer not in acceptable_answers:\n", "        raise ValueError(f\"{answer} not in list of acceptable answers\")\n", "    ### -->\n", "    return answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, ask it for a team that's not in one of those leagues. 
You should get an error." ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "ename": "ValueError", "evalue": "National Hockey League (NHL) \n\nNote: The provided team doesn't fit into the specified leagues (MLB, NFL, NBA), as the Minnesota Wild is a part of the National Hockey League. not in list of acceptable answers", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[51], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mclassify_team\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mMinnesota Wild\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", "Cell \u001b[0;32mIn[50], line 37\u001b[0m, in \u001b[0;36mclassify_team\u001b[0;34m(name)\u001b[0m\n\u001b[1;32m 31\u001b[0m acceptable_answers \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 32\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMajor League Baseball (MLB)\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 33\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNational Football League (NFL)\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 34\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNational Basketball Association (NBA)\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 35\u001b[0m ]\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m answer \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m acceptable_answers:\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00manswer\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m not in list of acceptable answers\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 39\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m answer\n", "\u001b[0;31mValueError\u001b[0m: National Hockey League (NHL) \n\nNote: The provided team doesn't fit into the specified leagues (MLB, NFL, NBA), as the Minnesota Wild is a part of the National Hockey League. not in list of acceptable answers" ] } ], "source": [ "classify_team(\"Minnesota Wild\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The LLM answered with a league that isn't on the list, so the validation step raised an error. To handle teams from other leagues, add an \"Other\" option to both the prompt and the list of acceptable answers." ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "def classify_team(name):\n", "    # The last sentence in the prompt is new\n", "    prompt = \"\"\"\n", "You are an AI model trained to classify text.\n", "\n", "I will provide the name of a professional sports team.\n", "\n", "You will reply with the sports league in which they compete.\n", "\n", "Your responses must come from the following list:\n", "- Major League Baseball (MLB)\n", "- National Football League (NFL)\n", "- National Basketball Association (NBA)\n", "\n", "If the team's league is not on the list, you should label them as \"Other\".\n", "\"\"\"\n", "    response = client.chat.completions.create(\n", "        messages=[\n", "            {\n", "                \"role\": \"system\",\n", "                \"content\": prompt,\n", "            },\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": name,\n", "            }\n", "        ],\n", "        model=\"meta-llama/Llama-3.3-70B-Instruct\",\n", "    )\n", "\n", "    answer = response.choices[0].message.content\n", "\n", "    acceptable_answers = [\n", "        \"Major League Baseball (MLB)\",\n", "        \"National Football League (NFL)\",\n", "        \"National Basketball Association (NBA)\",\n", "        \"Other\",  # NEW\n", "    ]\n", "    if answer not in acceptable_answers:\n", "        raise ValueError(f\"{answer} not in list of acceptable answers\")\n", "\n", "    return answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now try the Minnesota Wild again."
] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Other'" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classify_team(\"Minnesota Wild\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And you'll get the answer you expect.\n", "\n", "```\n", "'Other'\n", "```\n", "\n", "Most LLMs are set up to be creative and generate a range of responses to the same prompt. For structured responses like this, we don't want that. We want consistency. So it's a good idea to ask the LLM to be more straightforward by reducing a creativity setting known as `temperature` to zero." ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "def classify_team(name):\n", "    prompt = \"\"\"\n", "You are an AI model trained to classify text.\n", "\n", "I will provide the name of a professional sports team.\n", "\n", "You will reply with the sports league in which they compete.\n", "\n", "Your responses must come from the following list:\n", "- Major League Baseball (MLB)\n", "- National Football League (NFL)\n", "- National Basketball Association (NBA)\n", "\n", "If the team's league is not on the list, you should label them as \"Other\".\n", "\"\"\"\n", "\n", "    response = client.chat.completions.create(\n", "        messages=[\n", "            {\n", "                \"role\": \"system\",\n", "                \"content\": prompt,\n", "            },\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": name,\n", "            }\n", "        ],\n", "        model=\"meta-llama/Llama-3.3-70B-Instruct\",\n", "        temperature=0,  # NEW\n", "    )\n", "\n", "    answer = response.choices[0].message.content\n", "\n", "    acceptable_answers = [\n", "        \"Major League Baseball (MLB)\",\n", "        \"National Football League (NFL)\",\n", "        \"National Basketball Association (NBA)\",\n", "        \"Other\",\n", "    ]\n", "    if answer not in acceptable_answers:\n", "        raise ValueError(f\"{answer} not in list of acceptable answers\")\n", "\n", "    return answer" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also increase reliability by priming the LLM with examples of the type of response you want. This technique is called [\"few-shot prompting\"](https://www.ibm.com/think/topics/few-shot-prompting). In this style of prompting, which can feel like a strange form of roleplaying, you provide both the \"user\" input as well as the \"assistant\" response you want the LLM to mimic.\n", "\n", "Here's how it's done:" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "def classify_team(name):\n", "    prompt = \"\"\"\n", "You are an AI model trained to classify text.\n", "\n", "I will provide the name of a professional sports team.\n", "\n", "You will reply with the sports league in which they compete.\n", "\n", "Your responses must come from the following list:\n", "- Major League Baseball (MLB)\n", "- National Football League (NFL)\n", "- National Basketball Association (NBA)\n", "\n", "If the team's league is not on the list, you should label them as \"Other\".\n", "\"\"\"\n", "\n", "    response = client.chat.completions.create(\n", "        messages=[\n", "            {\n", "                \"role\": \"system\",\n", "                \"content\": prompt,\n", "            },\n", "            ### <-- NEW\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": \"Los Angeles Rams\",\n", "            },\n", "            {\n", "                \"role\": \"assistant\",\n", "                \"content\": \"National Football League (NFL)\",\n", "            },\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": \"Los Angeles Dodgers\",\n", "            },\n", "            {\n", "                \"role\": \"assistant\",\n", "                \"content\": \"Major League Baseball (MLB)\",\n", "            },\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": \"Los Angeles Lakers\",\n", "            },\n", "            {\n", "                \"role\": \"assistant\",\n", "                \"content\": \"National Basketball Association (NBA)\",\n", "            },\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": \"Los Angeles Kings\",\n", "            },\n", "            {\n", "                \"role\": \"assistant\",\n", "                \"content\": \"Other\",\n", "            },\n", "            ### -->\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": name,\n", "            }\n", "        ],\n", "        model=\"meta-llama/Llama-3.3-70B-Instruct\",\n", "        temperature=0,\n", "    )\n", "\n", "    answer = response.choices[0].message.content\n", "\n", "    acceptable_answers = [\n", "        \"Major League Baseball (MLB)\",\n", "        \"National Football League (NFL)\",\n", "        \"National Basketball Association (NBA)\",\n", "        \"Other\",\n", "    ]\n", "    if answer not in acceptable_answers:\n", "        raise ValueError(f\"{answer} not in list of acceptable answers\")\n", "\n", "    return answer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also ask the function to automatically retry if it doesn't get a valid response. This will give the LLM a second chance to get it right in cases where it gets too creative.\n", "\n", "To do that, we'll return to the installation step and add the `retry` package." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "%pip install rich ipywidgets retry" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now import the `retry` package." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "from rich import print\n", "import requests\n", "from huggingface_hub import InferenceClient\n", "from retry import retry  # NEW" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then add the `retry` decorator to the function. It will catch the `ValueError` exception and try again, up to the number of tries you specify." ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "@retry(ValueError, tries=2, delay=2)  # NEW\n", "def classify_team(name):\n", "    prompt = \"\"\"\n", "You are an AI model trained to classify text.\n", "...\n", "\"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**[7. 
Bulk prompts →](ch7-bulk-prompts.ipynb)**" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 4 }