""" Simple Chatbot | |
@author: Nigel Gebodh | |
@email: [email protected] | |
""" | |
import numpy as np
import streamlit as st
from openai import OpenAI
import os
from dotenv import load_dotenv
# Launch this app from the terminal with: streamlit run <this_file>.py
load_dotenv()

# Initialize the client
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get('HUGGINGFACEHUB_API_TOKEN')  # Replace with your token
)
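# The token is read from the environment by python-dotenv. An example .env entry
# (placeholder value, not a real token):
#   HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx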
# Function to reset conversation
def reset_conversation():
    st.session_state.conversation = []
    st.session_state.messages = []
    return None
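# reset_conversation is defined but never attached to a widget in this file.
# A minimal sketch (an assumption, not part of the original) would wire it to a button:
# st.sidebar.button("Reset Chat", on_click=reset_conversation)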
# Initialize session state for 'messages' if it doesn't exist
if 'messages' not in st.session_state:
    st.session_state.messages = []
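# The stored history is not replayed anywhere below, so earlier turns are not shown
# after a rerun. A common Streamlit pattern (an assumption, not in the original) is:
# for message in st.session_state.messages:
#     with st.chat_message(message["role"]):
#         st.markdown(message["content"])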
# Define classification options
classification_types = ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"]

# Start with a selection between data generation or labeling
st.sidebar.write("Choose Task:")
task = st.sidebar.radio("Do you want to generate data or label data?", ("Data Generation", "Data Labeling"))

# If the user selects Data Labeling
if task == "Data Labeling":
    st.sidebar.write("Choose Classification Type:")
    classification_type = st.sidebar.radio("Select a classification type:", classification_types)
    # Handle Sentiment Analysis
    if classification_type == "Sentiment Analysis":
        st.sidebar.write("Classes: Positive, Negative, Neutral (fixed)")
        class_labels = ["Positive", "Negative", "Neutral"]
    # Handle Binary Classification
    elif classification_type == "Binary Classification":
        class_1 = st.sidebar.text_input("Enter Class 1:")
        class_2 = st.sidebar.text_input("Enter Class 2:")
        class_labels = [class_1, class_2]
    # Handle Multi-Class Classification
    elif classification_type == "Multi-Class Classification":
        class_labels = []
        for i in range(1, 11):  # Allow up to 10 classes
            label = st.sidebar.text_input(f"Enter Class {i} (leave blank to stop):")
            if label:
                class_labels.append(label)
            else:
                break
    # Domain selection
    st.sidebar.write("Specify the Domain:")
    domain = st.sidebar.radio("Choose a domain:", ("Restaurant Reviews", "E-commerce Reviews", "Custom"))
    if domain == "Custom":
        domain = st.sidebar.text_input("Enter Custom Domain:")

    # Specify example length
    st.sidebar.write("Specify the Length of Examples:")
    min_words = st.sidebar.number_input("Minimum word count (10 to 90):", 10, 90, 10)
    max_words = st.sidebar.number_input("Maximum word count (10 to 90):", min_words, 90, 50)

    # Few-shot examples option
    use_few_shot = st.sidebar.radio("Do you want to use few-shot examples?", ("Yes", "No"))
    few_shot_examples = []
    if use_few_shot == "Yes":
        num_examples = st.sidebar.number_input("How many few-shot examples? (1 to 5)", 1, 5, 1)
        for i in range(num_examples):
            example_text = st.text_area(f"Enter example {i+1}:")
            example_label = st.selectbox(f"Select the label for example {i+1}:", class_labels)
            few_shot_examples.append({"text": example_text, "label": example_label})
    # Generate the system prompt based on classification type
    if classification_type == "Sentiment Analysis":
        system_prompt = f"You are a professional sentiment analysis expert. Your role is to generate sentiment analysis reviews based on the data entered and few-shot examples provided, if any, for the domain '{domain}'."
    elif classification_type == "Binary Classification":
        system_prompt = f"You are an expert in binary classification. Your task is to label examples for the domain '{domain}' with either '{class_1}' or '{class_2}', based on the data provided."
    else:  # Multi-Class Classification
        system_prompt = f"You are an expert in multi-class classification. Your role is to label examples for the domain '{domain}' using the provided class labels."

    st.sidebar.write("System Prompt:")
    st.sidebar.write(system_prompt)

    # Step-by-step thinking
    st.sidebar.write("Generated Data:")
    st.sidebar.write("Think step by step to ensure accuracy in classification.")
    # Accept user input for generating or labeling data
    if prompt := st.chat_input(f"Hi, I'm ready to help with {classification_type} for {domain}. Ask me a question or provide data to classify."):
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            try:
                # Stream the response from the model
                stream = client.chat.completions.create(
                    model="meta-llama/Meta-Llama-3-8B-Instruct",
                    messages=[
                        {"role": m["role"], "content": m["content"]}
                        for m in st.session_state.messages
                    ],
                    temperature=0.5,
                    stream=True,
                    max_tokens=3000,
                )
                response = st.write_stream(stream)
            except Exception:
                response = "😵💫 Something went wrong. Try again later."
                st.write(response)

        st.session_state.messages.append({"role": "assistant", "content": response})
# If the user selects Data Generation
else:
    st.sidebar.write("This feature will allow you to generate new data. Coming soon!")