File size: 5,833 Bytes
9f54a3b
 
c2231bb
dc141e7
 
142827c
b7f505b
dc141e7
 
 
 
 
 
 
b7f505b
dc141e7
 
b7f505b
dc141e7
 
b7f505b
dc141e7
 
b7f505b
dc141e7
 
 
 
9f57726
 
dc141e7
 
9f57726
 
 
dc141e7
 
 
9f57726
dc141e7
 
 
 
 
 
 
 
9f57726
dc141e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97edc95
9f57726
dc808bb
 
 
9f57726
dc808bb
 
 
dc141e7
dc808bb
 
dc141e7
 
 
 
 
 
dc808bb
dc141e7
dc808bb
 
 
 
dc141e7
dc808bb
 
 
 
dc141e7
 
 
 
 
97edc95
dc141e7
97edc95
dc808bb
dc141e7
 
 
dc808bb
dc141e7
dc808bb
 
 
 
 
 
 
dc141e7
dc808bb
97edc95
dc141e7
 
 
9f57726
 
9f54a3b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import streamlit as st
from openai import OpenAI

# Create the chat-history container on first run.  Streamlit re-executes the
# whole script on every interaction, so guard against clobbering prior state.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Function to generate system prompt based on user inputs
def create_system_prompt(classification_type, num_to_generate, domain, min_words, max_words, labels):
    """Assemble the system prompt sent to the chat model.

    The prompt tells the model to act as an expert for *classification_type*,
    produce exactly *num_to_generate* examples for *domain*, keep each example
    between *min_words* and *max_words* words, restrict itself to *labels*,
    and emit lines in the fixed "Example: <text>, Label: <label>" format the
    response parser splits on.
    """
    label_list = ", ".join(labels)
    parts = (
        f"You are a professional {classification_type.lower()} expert. Your role is to generate exactly {num_to_generate} data examples for {domain}. ",
        f"Each example should consist of between {min_words} and {max_words} words. ",
        f"Use the following labels: {label_list}. Please do not add any extra commentary or explanation. ",
        "Format each example like this: \nExample: <text>, Label: <label>\n",
    )
    return "".join(parts)

# OpenAI client setup (replace with your OpenAI API credentials)
# SECURITY NOTE(review): do not ship a hard-coded key in source; load it from
# an environment variable or st.secrets instead.
client = OpenAI(api_key='YOUR_API_KEY')

# App title
st.title("Data Generation for Classification")

# Choice between Data Generation or Data Labeling
# NOTE(review): only the "Data Generation" branch is implemented below;
# selecting "Data Labeling" currently renders nothing — confirm intended.
mode = st.radio("Choose Task:", ["Data Generation", "Data Labeling"])

if mode == "Data Generation":
    # Step 1: Choose Classification Type
    classification_type = st.radio(
        "Select Classification Type:",
        ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"]
    )
    
    # Step 2: Choose labels based on classification type.
    # NOTE(review): text_input returns "" until the user types, so binary /
    # multi-class labels can be empty strings when generation is triggered —
    # consider validating before calling the API.
    if classification_type == "Sentiment Analysis":
        labels = ["Positive", "Negative", "Neutral"]
    elif classification_type == "Binary Classification":
        class1 = st.text_input("Enter First Class for Binary Classification")
        class2 = st.text_input("Enter Second Class for Binary Classification")
        labels = [class1, class2]
    elif classification_type == "Multi-Class Classification":
        num_classes = st.slider("Number of Classes (Max 10):", 2, 10, 3)
        labels = [st.text_input(f"Enter Class {i+1}") for i in range(num_classes)]
    
    # Step 3: Choose the domain (fixed presets or a free-text custom domain)
    domain = st.radio(
        "Select Domain:",
        ["Restaurant reviews", "E-commerce reviews", "Custom"]
    )
    if domain == "Custom":
        domain = st.text_input("Enter Custom Domain")
    
    # Step 4: Specify example length (min and max words).
    # NOTE(review): the two sliders are independent, so min_words can exceed
    # max_words — confirm whether the prompt should guard against that.
    min_words = st.slider("Minimum Words per Example", 10, 90, 20)
    max_words = st.slider("Maximum Words per Example", 10, 90, 40)
    
    # Step 5: Ask if user wants few-shot examples
    use_few_shot = st.checkbox("Use Few-Shot Examples?")
    
    # Few-shot examples are pre-formatted in the same "Example: ..., Label: ..."
    # shape the system prompt asks the model to emit.
    few_shot_examples = []
    if use_few_shot:
        num_few_shots = st.slider("Number of Few-Shot Examples (Max 5):", 1, 5, 2)
        for i in range(num_few_shots):
            example_text = st.text_area(f"Enter Example {i+1} Text")
            example_label = st.selectbox(f"Select Label for Example {i+1}", labels)
            few_shot_examples.append(f"Example: {example_text}, Label: {example_label}")
    
    # Step 6: Specify the number of examples to generate
    num_to_generate = st.number_input("Number of Examples to Generate", min_value=1, max_value=50, value=10)
    
    # Step 7: Generate system prompt based on the inputs
    system_prompt = create_system_prompt(classification_type, num_to_generate, domain, min_words, max_words, labels)
    
    if st.button("Generate Examples"):
        # Example texts collected across all chunked API calls.
        all_generated_examples = []
        remaining_examples = num_to_generate

        # Build the conversation ONCE, outside the loop.  The previous code
        # appended the system prompt and few-shot examples to
        # st.session_state.messages on every iteration, so each subsequent
        # chunk request carried a growing pile of duplicated prompts.
        request_messages = [{"role": "system", "content": system_prompt}]
        for example in few_shot_examples:
            request_messages.append({"role": "user", "content": example})

        with st.spinner("Generating..."):
            while remaining_examples > 0:
                # Request at most 5 examples per call to stay inside the
                # max_tokens budget; the parsed response is capped below.
                chunk_size = min(remaining_examples, 5)
                try:
                    # Stream the chat completion for this chunk.
                    stream = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=request_messages,
                        temperature=0.7,
                        stream=True,
                        max_tokens=3000,
                    )

                    # Accumulate the streamed completion.  The v1 OpenAI
                    # client yields ChatCompletionChunk objects, so the delta
                    # must be read via attributes — the old dict-style
                    # chunk['choices'][0]['delta'] access raises TypeError.
                    response = ""
                    for chunk in stream:
                        delta = chunk.choices[0].delta
                        if delta.content:
                            response += delta.content

                    # Split on the "Example: " marker mandated by the system
                    # prompt and keep at most chunk_size examples.
                    generated_examples = response.split("Example: ")[1:chunk_size + 1]

                    # Store bare example texts; numbering is added once at
                    # display time.  (Previously "Example N:" was prefixed
                    # here AND in the display loop, printing it twice, with
                    # numbering restarting at 1 for every chunk.)
                    all_generated_examples.extend(ex.strip() for ex in generated_examples)
                    remaining_examples -= chunk_size

                except Exception as e:
                    # Surface the failure and stop rather than looping forever.
                    st.error("Error during generation.")
                    st.write(e)
                    break

        # Display all generated examples, numbered sequentially across chunks.
        for idx, example in enumerate(all_generated_examples):
            st.write(f"Example {idx+1}: {example.strip()}")

        # Clear any stale chat history so reruns start fresh.
        st.session_state.messages = []  # Reset after each generation