#############################################################################################################################
# Filename   : app.py
# Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
# Author     : Georgios Ioannou
#
# Copyright © 2024 by Georgios Ioannou
#############################################################################################################################
# Import libraries.

import os
from dataclasses import dataclass
from typing import Literal

import requests
import streamlit as st
import streamlit.components.v1 as components
from dotenv import find_dotenv, load_dotenv
from huggingface_hub import InferenceClient
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain
from langchain.llms import OpenAI

from policies import complex_policy, simple_policy

#############################################################################################################################
# Load environment variable(s).

# Pull variables from the nearest .env file (if any) into the process environment.
load_dotenv(dotenv_path=find_dotenv())

# Both keys are None when the variable is not set.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
HUGGINGFACE_API_KEY = os.environ.get("HUGGINGFACE_API_KEY")

#############################################################################################################################


@dataclass
class Message:
    """One entry of the chat history shown on the page."""

    origin: Literal["human", "ai"]  # who authored the entry
    message: str  # the entry's text


#############################################################################################################################
# Initialize Hugging Face clients.
def initialize_hf_clients():
    """Create the Hugging Face objects used for inference.

    Returns:
        A tuple ``(client, gpt2_api_url, headers)``:
        - an ``InferenceClient`` authenticated with ``HUGGINGFACE_API_KEY``,
        - the Inference API URL of the GPT-2 model,
        - the bearer-token headers for raw HTTP calls to that URL.
    """
    client = InferenceClient(api_key=HUGGINGFACE_API_KEY)
    gpt2_api_url = "https://api-inference.huggingface.co/models/openai-community/gpt2"
    headers = {"Authorization": f"Bearer {HUGGINGFACE_API_KEY}"}
    return client, gpt2_api_url, headers


#############################################################################################################################
# Hugging Face model inference functions.


def qwen_inference(prompt):
    """Send *prompt* as a single user message to Qwen2.5-1.5B-Instruct.

    Returns:
        The reply text ("" if the API returned no content), or an
        ``"Error with Qwen inference: ..."`` string if the call failed, so the
        error is shown in the chat instead of crashing the app.
    """
    client, _, _ = initialize_hf_clients()
    messages = [{"role": "user", "content": prompt}]
    try:
        response = client.chat.completions.create(
            model="Qwen/Qwen2.5-1.5B-Instruct", messages=messages, max_tokens=500
        )
        # content may be None; callers call .lower() on the result.
        return response.choices[0].message.content or ""
    except Exception as e:
        return f"Error with Qwen inference: {str(e)}"


def gpt2_inference(prompt, timeout=60):
    """Generate a continuation of *prompt* with GPT-2 via the HTTP Inference API.

    Args:
        prompt: Text to continue.
        timeout: Seconds to wait for the HTTP response before failing.

    Returns:
        The generated text only (without the prompt), or an
        ``"Error with GPT-2 inference: ..."`` string if the request failed.
    """
    _, api_url, headers = initialize_hf_clients()
    payload = {
        "inputs": prompt,
        # By default the endpoint echoes the prompt before the generated text.
        # The echoed safety prompt contains the word "unsafe", which made every
        # safety check on GPT-2 output flag the message.
        "parameters": {"return_full_text": False},
    }
    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()[0]["generated_text"]
    except Exception as e:
        return f"Error with GPT-2 inference: {str(e)}"


#############################################################################################################################
# Initialize session state.


def initialize_session_state():
    """Store a default in st.session_state for every app key that is not set yet."""
    defaults = {
        "history": [],
        "token_count": 0,
        "policy": "No Policy",
        # Same file name that main() maps the "AI 1" option to (was "a1.png").
        "selected_ai_icon": "ai1.png",
        "selected_user_icon": "man.png",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


#############################################################################################################################
# Function to apply local CSS.
def local_css(file_name):
    """Inject the stylesheet stored at *file_name* into the page.

    Args:
        file_name: Path of the CSS file to read.
    """
    with open(file_name, encoding="utf-8") as f:
        css = f.read()
    # CSS only takes effect inside a <style> element. The previous call rendered
    # an empty string, so the file was read and then discarded.
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


#############################################################################################################################
# Chat submission: input safety check -> model answer -> output safety check.


def _is_openai_model(model_name):
    """Return True for OpenAI models served through LangChain (any "gpt" model except GPT-2)."""
    name = model_name.lower()
    return "gpt" in name and "gpt2" not in name


def _build_safety_prompt(text):
    """Build the moderation prompt asking the model to classify *text*.

    With "Simple Policy" or "Complex Policy" selected, *text* is embedded in
    the task description, the chosen policy and the required output format.
    With any other policy ("No Policy"), *text* is returned unchanged.
    """
    if st.session_state.policy == "Simple Policy":
        policy = simple_policy
    elif st.session_state.policy == "Complex Policy":
        policy = complex_policy
    else:
        return text

    role = "User"
    task = (
        f"\nTask: Check if there is unsafe content in '{role}' messages in conversations "
        "according our safety policy with the below categories.\n"
    )
    output_format = (
        f"\nProvide your safety assessment for '{role}' in the above conversation:\n"
        "- First line must read 'safe' or 'unsafe'.\n"
        "- If unsafe, a second line must include a comma-separated list of violated categories.\n"
    )
    conversation = f"\nUser: {text}\n"
    return f"{task}\n{policy}\n{conversation}\n{output_format}"


def _run_model(prompt, *, fresh_chain=False):
    """Send *prompt* to the model selected in st.session_state.model and return its reply.

    For OpenAI models, the safety checks reuse st.session_state.conversation
    (created in main). Answer generation (fresh_chain=True) uses a new
    ConversationChain, as the original code did.
    """
    model = st.session_state.model
    if _is_openai_model(model):
        if fresh_chain:
            chain = ConversationChain(
                llm=OpenAI(
                    temperature=0.2,
                    openai_api_key=OPENAI_API_KEY,
                    model_name=model,
                )
            )
        else:
            chain = st.session_state.conversation
        reply = chain.run(prompt)
    elif "qwen" in model.lower():
        reply = qwen_inference(prompt)
    else:  # gpt2.
        reply = gpt2_inference(prompt)
    # Callers call .lower() on the reply, so never hand back None.
    return reply or ""


def on_click_callback():
    """Handle a chat-form submission.

    Moderates st.session_state.human_prompt, answers it, and moderates the
    answer. The resulting messages are appended to st.session_state.history,
    and the OpenAI tokens used are added to st.session_state.token_count.
    """
    human_prompt = st.session_state.human_prompt
    # An empty or whitespace-only submission has nothing to moderate or answer.
    if not human_prompt or not human_prompt.strip():
        return

    with get_openai_callback() as cb:
        try:
            _answer_with_safety_checks(human_prompt)
        finally:
            # cb.total_tokens is a running total for the whole `with` block, so
            # add it exactly once. The old code added it after every model call,
            # which counted earlier calls several times.
            st.session_state.token_count += cb.total_tokens


def _answer_with_safety_checks(human_prompt):
    """Run one moderated chat turn for *human_prompt* and append the results to the history.

    NOTE(review): with "No Policy" the "checks" send the raw text to the model
    and look for the word "unsafe" in its reply. That is not real moderation,
    but the behavior is kept as is.
    """
    # Safety check 1 - Input check.
    input_verdict = _run_model(_build_safety_prompt(human_prompt))
    st.session_state.history.append(Message("human", human_prompt))
    if "unsafe" in input_verdict.lower():
        # Show the moderation verdict instead of answering.
        st.session_state.history.append(Message("ai", input_verdict))
        return

    # Get model response.
    llm_response = _run_model(human_prompt, fresh_chain=True)

    # Safety check 2 - Output check.
    output_verdict = _run_model(_build_safety_prompt(llm_response))
    if "unsafe" in output_verdict.lower():
        st.session_state.history.append(
            Message(
                "ai",
                "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!",
            )
        )
    else:
        st.session_state.history.append(Message("ai", llm_response))


#############################################################################################################################
# Page layout.


def _render_header():
    """Render the title, the two subtitles and the centered logo."""
    st.markdown("\n\nResponsible AI\n\n", unsafe_allow_html=True)
    st.markdown(
        "\n\nShowcase the importance of Responsible AI in LLMs Using Policies\n\n",
        unsafe_allow_html=True,
    )
    st.markdown("\n\nCUNY Tech Prep Tutorial 6\n\n", unsafe_allow_html=True)

    _, center_col, _ = st.columns(3)
    with center_col:
        st.image(image="./static/ctp.png")


def _render_sidebar():
    """Render the sidebar selectors and store the choices in st.session_state."""
    # Models.
    models = [
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-instruct",
        "gpt-4-turbo",
        "gpt-4",
        "Qwen2.5-1.5B-Instruct",
        "gpt2",
    ]
    selected_model = st.sidebar.selectbox("Select Model:", models)
    st.sidebar.markdown(
        f"Current Model: {selected_model}",
        unsafe_allow_html=True,
    )
    st.session_state.model = selected_model
    if _is_openai_model(selected_model):
        # Chain that on_click_callback uses for the safety checks.
        st.session_state.conversation = ConversationChain(
            llm=OpenAI(
                temperature=0.2,
                openai_api_key=OPENAI_API_KEY,
                model_name=selected_model,
            ),
        )

    # Policies.
    policies = ["No Policy", "Complex Policy", "Simple Policy"]
    selected_policy = st.sidebar.selectbox("Select Policy:", policies)
    st.sidebar.markdown(
        f"Current Policy: {selected_policy}",
        unsafe_allow_html=True,
    )
    st.session_state.policy = selected_policy

    # AI icons (label -> image file name).
    ai_icon_files = {"AI 1": "ai1.png", "AI 2": "ai2.png"}
    selected_ai_icon = st.sidebar.selectbox("AI Icon:", list(ai_icon_files))
    st.sidebar.markdown(
        f"Current AI Icon: {selected_ai_icon}",
        unsafe_allow_html=True,
    )
    st.session_state.selected_ai_icon = ai_icon_files[selected_ai_icon]

    # User icons (label -> image file name).
    user_icon_files = {"Man": "man.png", "Woman": "woman.png"}
    selected_user_icon = st.sidebar.selectbox("User Icon:", list(user_icon_files))
    st.sidebar.markdown(
        f"Current User Icon: {selected_user_icon}",
        unsafe_allow_html=True,
    )
    st.session_state.selected_user_icon = user_icon_files[selected_user_icon]


def _render_chat():
    """Render the chat history, the input form and the token counter."""
    import html  # stdlib; escapes chat text before it is rendered as HTML

    chat_placeholder = st.container()
    prompt_placeholder = st.form("chat-form")
    token_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            # Messages are user input and model output rendered with
            # unsafe_allow_html=True, so HTML-significant characters are escaped.
            # NOTE(review): the chat-bubble markup around the message appears to
            # be missing from this literal; confirm against the original repo.
            div = f"\n\u200b{html.escape(chat.message, quote=False)}\n"
            st.markdown(div, unsafe_allow_html=True)
        for _ in range(3):
            st.markdown("")

    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))
        cols[0].text_input(
            "Chat",
            placeholder="What is your question?",
            label_visibility="collapsed",
            key="human_prompt",
        )
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )

    token_placeholder.caption(f"Used {st.session_state.token_count} tokens\n")


def _render_footer():
    """Render the repository link and the Enter-key handler component."""
    st.markdown(
        "\n\nCheck out our GitHub repository\n\n",
        unsafe_allow_html=True,
    )
    # Enter key handler.
    # NOTE(review): the component's HTML/JS is just " " here; the script seems
    # to have been lost; confirm against the original repo.
    components.html(
        " ",
        height=0,
        width=0,
    )


def main():
    """Build the Responsible AI page."""
    # Page title and favicon; set_page_config must be the first Streamlit command.
    st.set_page_config(page_title="Responsible AI", page_icon="⚖️")
    initialize_session_state()

    # Load CSS.
    local_css("./static/styles/styles.css")

    _render_header()
    _render_sidebar()
    _render_chat()
    _render_footer()


if __name__ == "__main__":
    main()