{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6bf1dd4a-b7c5-496d-9a7a-b4a392e185e2",
   "metadata": {},
   "source": [
    "# Mental Health Nudging with Generative AI Demo\n",
    "\n",
    "A Gradio app that turns a few facts about a person (gender, age, race, symptoms and interests) into a short motivational text message from an OpenAI chat model and a matching illustration from DALL-E 3.\n",
    "\n",
    "The OpenAI client reads its API key from the `OPENAI_API_KEY` environment variable, which must be set before running the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b57ced9-62ee-44a0-a895-6ed288f970ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import requests\n",
    "from openai import OpenAI\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83da6055-6ca8-44ee-a64f-be3b29fc400a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Models used for the message and image generation steps.\n",
    "CHAT_MODEL = 'gpt-3.5-turbo'\n",
    "CHAT_TEMPERATURE = 0.5\n",
    "IMAGE_MODEL = 'dall-e-3'\n",
    "\n",
    "# Seconds to wait for the generated image to download before giving up.\n",
    "IMAGE_DOWNLOAD_TIMEOUT = 60\n",
    "\n",
    "# Prompts contain user-entered health information, so only print them while debugging.\n",
    "DEBUG_PROMPTS = False\n",
    "\n",
    "# With no arguments the client reads OPENAI_API_KEY from the environment.\n",
    "client = OpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24bbd8a2-600e-4b46-aa58-95caa76c2c88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Noun used for each gender choice in the UI; anything else (e.g. 'Unspecified') maps to 'person'.\n",
    "GENDER_NOUNS = {'Male': 'man', 'Female': 'woman', 'Non-Binary': 'non-binary person'}\n",
    "\n",
    "\n",
    "def describe_gender(gender):\n",
    "    \"\"\"Return the noun describing a person of the given gender choice.\"\"\"\n",
    "    return GENDER_NOUNS.get(gender, 'person')\n",
    "\n",
    "\n",
    "def generate_nudge_message(gender, age, interests, symptoms):\n",
    "    \"\"\"Generate a short motivational text message for a given person.\n",
    "\n",
    "    Args:\n",
    "        gender: One of 'Male', 'Female', 'Non-Binary' or 'Unspecified'.\n",
    "        age: Age in years.\n",
    "        interests: Free-text, comma-separated list of interests (may be empty).\n",
    "        symptoms: Selected symptom labels (may be empty).\n",
    "\n",
    "    Returns:\n",
    "        The message text produced by the chat model.\n",
    "    \"\"\"\n",
    "    # construct description of the person; int() in case the slider delivers a float such as 25.0\n",
    "    desc = f'A {int(age)} year old {describe_gender(gender)}.'\n",
    "    interests = (interests or '').strip()\n",
    "    if interests:\n",
    "        desc += f' They like {interests}.'\n",
    "    if symptoms:\n",
    "        desc += f' They have the following mental health symptoms: {\", \".join(map(str.lower, symptoms))}.'\n",
    "    else:\n",
    "        desc += ' They do not have any mental health symptoms.'\n",
    "\n",
    "    # generate nudge message\n",
    "    system_prompt = '''You are writing motivational text messages to help people with their mental health. \\\n",
    "Messages should be friendly and positive, but also professional and super short. \\\n",
    "You are limited on space. \\\n",
    "Messages should be written at the reading level of an eighth grader. \\\n",
    "Word choice should be short and simple so everyone can understand. \\n\\n\n",
    "You will be given some basic information about the person you are addressing. \\\n",
    "DO NOT reference all of their likes if there are more than two. Be discerning. \\\n",
    "You should try to use the person's information to give them relevant and actionable tips for improving their mental health symptoms.'''\n",
    "    user_prompt = f'Write a short inspirational message for the person with the following description:\\n\\n{desc}'\n",
    "    if DEBUG_PROMPTS:\n",
    "        print(f'{desc=}')\n",
    "        print(f'\\n\\n{system_prompt}')\n",
    "        print(f'\\n\\n{user_prompt}')\n",
    "\n",
    "    messages = [{'role': 'system', 'content': system_prompt},\n",
    "                {'role': 'user', 'content': user_prompt}]\n",
    "    completion = client.chat.completions.create(messages=messages, model=CHAT_MODEL, temperature=CHAT_TEMPERATURE)\n",
    "    return completion.choices[0].message.content\n",
    "\n",
    "\n",
    "def generate_nudge_image(gender, age, race, nudge_message, random=False):\n",
    "    \"\"\"Generate an illustration for a given person and message.\n",
    "\n",
    "    Args:\n",
    "        gender: One of 'Male', 'Female', 'Non-Binary' or 'Unspecified'.\n",
    "        age: Age in years.\n",
    "        race: One of the race choices in the UI; 'Unspecified' (or empty) omits it from the prompt.\n",
    "        nudge_message: Message text the image should align with.\n",
    "        random: If True, skip the API calls and return a 100x100 random-noise placeholder.\n",
    "\n",
    "    Returns:\n",
    "        A PIL image.\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: If downloading the generated image fails.\n",
    "    \"\"\"\n",
    "    if random:\n",
    "        # randint's upper bound is exclusive, so 256 covers the full 0-255 uint8 range\n",
    "        return Image.fromarray(np.random.randint(0, 256, (100, 100, 3), dtype='uint8'), 'RGB')\n",
    "\n",
    "    # construct description of the person\n",
    "    race_desc = f'{race.lower()} ' if race and race != 'Unspecified' else ''\n",
    "    desc = f'The person is a {int(age)} year old {race_desc}{describe_gender(gender)}.'\n",
    "    prompt = ('Illustrate one simple, inspirational, fun image to help a person with their mental health. Do not include text. '\n",
    "              f'The style is cute and illustrative. {desc} '\n",
    "              f'The image should align with the following message:\\n\\n{nudge_message}')\n",
    "    if DEBUG_PROMPTS:\n",
    "        print(f'\\n{prompt}')\n",
    "\n",
    "    response = client.images.generate(prompt=prompt, model=IMAGE_MODEL)\n",
    "    # download the whole image with a timeout and a status check rather than decoding an unchecked raw stream\n",
    "    image_response = requests.get(response.data[0].url, timeout=IMAGE_DOWNLOAD_TIMEOUT)\n",
    "    image_response.raise_for_status()\n",
    "    return Image.open(io.BytesIO(image_response.content))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7205584e-9526-42b0-8f8f-a7c3901831d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(gender, age, race, interests, symptoms):\n",
    "    \"\"\"Generate the nudge message and then a matching image for the given person.\n",
    "\n",
    "    Returns:\n",
    "        (nudge_image, nudge_message), in the same order as the app's output components.\n",
    "    \"\"\"\n",
    "    nudge_message = generate_nudge_message(gender, age, interests, symptoms)\n",
    "    nudge_image = generate_nudge_image(gender, age, race, nudge_message, random=False)\n",
    "    return nudge_image, nudge_message\n",
    "\n",
    "\n",
    "def run_demo():\n",
    "    \"\"\"Set up the app interface and launch it.\"\"\"\n",
    "    with gr.Blocks() as app:\n",
    "\n",
    "        gr.Markdown('# Mental Health Nudging with Generative AI Demo')\n",
    "        with gr.Row():\n",
    "\n",
    "            # input features\n",
    "            with gr.Column(scale=1):\n",
    "\n",
    "                # demographics\n",
    "                gender = gr.Radio(label='Gender', value='Unspecified',\n",
    "                                  choices=['Male', 'Female', 'Non-Binary', 'Unspecified'])\n",
    "                age = gr.Slider(label='Age', minimum=18, maximum=80, step=1)\n",
    "                race = gr.Radio(label='Race', value='Unspecified',\n",
    "                                choices=['White', 'Hispanic', 'Black', 'Asian', 'Unspecified'])\n",
    "\n",
    "                # symptoms\n",
    "                disorders = ['Sadness', 'Inability to concentrate', 'Excessive worrying', 'Extreme mood changes',\n",
    "                             'Withdrawal from friends/activities', 'Tiredness', 'Hallucinations', 'Addiction',\n",
    "                             'Lack of appetite', 'Increased appetite']\n",
    "                symptoms = gr.CheckboxGroup(label='Symptoms', choices=disorders)\n",
    "\n",
    "                # interests\n",
    "                interests = gr.Textbox(label='Interests', placeholder='Comma-separated list of interests...')\n",
    "\n",
    "                # submit button\n",
    "                submit_button = gr.Button('Generate Nudge')\n",
    "\n",
    "            # resulting nudge\n",
    "            with gr.Column(scale=1):\n",
    "                nudge_image = gr.Image(label='Nudge Image')\n",
    "                nudge_message = gr.Textbox(label='Nudge Message')\n",
    "\n",
    "        # submit parameters for nudge generation; order must match generate()'s parameters and return value\n",
    "        inputs = [gender, age, race, interests, symptoms]\n",
    "        outputs = [nudge_image, nudge_message]\n",
    "        submit_button.click(fn=generate, inputs=inputs, outputs=outputs)\n",
    "\n",
    "    # close any Gradio apps still running from an earlier run in this kernel before launching a new one\n",
    "    gr.close_all()\n",
    "    # None lifts the per-event concurrency limit so several users can generate at once\n",
    "    app.queue(default_concurrency_limit=None)\n",
    "    app.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7cfcc42-b369-465c-902f-e01b4e017fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == '__main__':\n",
    "    run_demo()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}