|
""" Simple Chatbot |
|
@author: Nigel Gebodh |
|
@email: [email protected] |
|
|
|
""" |
|
import numpy as np

import streamlit as st

from openai import OpenAI

import os

from dotenv import load_dotenv



# Pull variables from a local .env file (e.g. HUGGINGFACEHUB_API_TOKEN)
# into the process environment before the API client below is built.
load_dotenv()
|
|
|
|
|
# OpenAI-compatible client pointed at the Hugging Face Inference API.
# NOTE(review): if HUGGINGFACEHUB_API_TOKEN is unset, api_key is None and
# requests will fail at call time rather than here — confirm that's intended.
client = OpenAI(

    base_url="https://api-inference.huggingface.co/v1",

    api_key=os.environ.get('HUGGINGFACEHUB_API_TOKEN')

)
|
|
|
|
|
def reset_conversation():
    """Wipe the chat history kept in Streamlit session state.

    Clears both the ``conversation`` and ``messages`` lists so the next
    rerun starts from an empty chat.
    """
    st.session_state.conversation = []
    st.session_state.messages = []
|
|
|
|
|
# Classification flavors offered in the sidebar when "Data Labeling" is chosen.
classification_types = ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"]




# Top-level mode switch; "Data Generation" is currently a placeholder branch.
st.sidebar.write("Choose Task:")

task = st.sidebar.radio("Do you want to generate data or label data?", ("Data Generation", "Data Labeling"))
|
|
|
|
|
if task == "Data Labeling":

    # ---- Sidebar: choose what kind of labels to produce --------------------
    st.sidebar.write("Choose Classification Type:")
    classification_type = st.sidebar.radio("Select a classification type:", classification_types)

    if classification_type == "Sentiment Analysis":
        # Fixed three-way sentiment scheme; not user-editable.
        st.sidebar.write("Classes: Positive, Negative, Neutral (fixed)")
        class_labels = ["Positive", "Negative", "Neutral"]

    elif classification_type == "Binary Classification":
        # Two user-defined labels; these are empty strings until typed in.
        class_1 = st.sidebar.text_input("Enter Class 1:")
        class_2 = st.sidebar.text_input("Enter Class 2:")
        class_labels = [class_1, class_2]

    elif classification_type == "Multi-Class Classification":
        # Up to 10 user-defined labels; stop at the first blank input.
        class_labels = []
        for i in range(1, 11):
            label = st.sidebar.text_input(f"Enter Class {i} (leave blank to stop):")
            if label:
                class_labels.append(label)
            else:
                break

    # ---- Sidebar: domain and example-length settings -----------------------
    st.sidebar.write("Specify the Domain:")
    domain = st.sidebar.radio("Choose a domain:", ("Restaurant Reviews", "E-commerce Reviews", "Custom"))
    if domain == "Custom":
        domain = st.sidebar.text_input("Enter Custom Domain:")

    st.sidebar.write("Specify the Length of Examples:")
    min_words = st.sidebar.number_input("Minimum word count (10 to 90):", 10, 90, 10)
    # Lower bound tracks min_words so the range can never be inverted.
    # NOTE(review): min_words/max_words are collected but never injected into
    # the prompt — confirm whether they should constrain the model output.
    max_words = st.sidebar.number_input("Maximum word count (10 to 90):", min_words, 90, 50)

    # ---- Sidebar/main: optional few-shot examples --------------------------
    use_few_shot = st.sidebar.radio("Do you want to use few-shot examples?", ("Yes", "No"))
    few_shot_examples = []
    if use_few_shot == "Yes":
        num_examples = st.sidebar.number_input("How many few-shot examples? (1 to 5)", 1, 5, 1)
        for i in range(num_examples):
            example_text = st.text_area(f"Enter example {i+1}:")
            example_label = st.selectbox(f"Select the label for example {i+1}:", class_labels)
            # NOTE(review): few_shot_examples is built but not yet sent to the
            # model — confirm whether it should be folded into the prompt.
            few_shot_examples.append({"text": example_text, "label": example_label})

    # ---- Build the system prompt for the selected task ---------------------
    if classification_type == "Sentiment Analysis":
        system_prompt = f"You are a propositional sentiment analysis expert. Your role is to generate sentiment analysis reviews based on the data entered and few-shot examples provided, if any, for the domain '{domain}'."
    elif classification_type == "Binary Classification":
        system_prompt = f"You are an expert in binary classification. Your task is to label examples for the domain '{domain}' with either '{class_1}' or '{class_2}', based on the data provided."
    else:
        system_prompt = f"You are an expert in multi-class classification. Your role is to label examples for the domain '{domain}' using the provided class labels."

    st.sidebar.write("System Prompt:")
    st.sidebar.write(system_prompt)

    st.sidebar.write("Generated Data:")
    st.sidebar.write("Think step by step to ensure accuracy in classification.")

    # FIX: initialize the chat history once per session. Previously
    # st.session_state.messages was appended to without ever being created
    # (reset_conversation is never called), which raised AttributeError on
    # the very first user message.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # FIX: replay the stored conversation so the chat history stays visible
    # across Streamlit's full-script rerun on every interaction.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input(f"Hi, I'm ready to help with {classification_type} for {domain}. Ask me a question or provide data to classify."):

        # Echo the user's turn and persist it.
        with st.chat_message("user"):
            st.markdown(prompt)

        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant"):

            try:
                # FIX: actually send the system prompt to the model — it was
                # built and shown in the sidebar but never included in the
                # request payload.
                stream = client.chat.completions.create(
                    model="meta-llama/Meta-Llama-3-8B-Instruct",
                    messages=[{"role": "system", "content": system_prompt}]
                    + [
                        {"role": m["role"], "content": m["content"]}
                        for m in st.session_state.messages
                    ],
                    temperature=0.5,
                    stream=True,
                    max_tokens=3000,
                )

                # Stream tokens into the UI; returns the concatenated text.
                response = st.write_stream(stream)

            except Exception:
                # Best-effort fallback: show a friendly message instead of a
                # stack trace (error details are intentionally not surfaced).
                response = "😵💫 Something went wrong. Try again later."
                st.write(response)

        # Persist the assistant's turn (fallback text included, matching the
        # original behavior).
        st.session_state.messages.append({"role": "assistant", "content": response})


else:
    # "Data Generation" mode is a placeholder for now.
    st.sidebar.write("This feature will allow you to generate new data. Coming soon!")
|
|
|
|