Wedyan2023 committed on
Commit
f45bbe4
·
verified ·
1 Parent(s): 07c4ad3

Create app8.py

Browse files
Files changed (1) hide show
  1. app8.py +394 -0
app8.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import os
4
+ from datetime import datetime
5
+ import random
6
+ from pathlib import Path
7
+ from openai import OpenAI
8
+ from dotenv import load_dotenv
9
+ from langchain_core.prompts import PromptTemplate
10
+
11
# Load environment variables from a local .env file (if one exists)
load_dotenv()

# Chat client pointed at the Hugging Face inference endpoint through the
# OpenAI-compatible API.
# NOTE(review): os.environ.get returns None when TOKEN2 is unset — confirm
# the deployment always provides it.
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1",
    api_key=os.environ.get('TOKEN2')  # Add your Huggingface token here
)

# Custom CSS for better appearance (buttons, select boxes, output/status
# containers and the sidebar info card)
st.markdown("""
<style>
.stButton > button {
    width: 100%;
    margin-bottom: 10px;
    background-color: #4CAF50;
    color: white;
    border: none;
    padding: 10px;
    border-radius: 5px;
}
.task-button {
    background-color: #2196F3 !important;
}
.stSelectbox {
    margin-bottom: 20px;
}
.output-container {
    padding: 20px;
    border-radius: 5px;
    border: 1px solid #ddd;
    margin: 10px 0;
}
.status-container {
    padding: 10px;
    border-radius: 5px;
    margin: 10px 0;
}
.sidebar-info {
    padding: 10px;
    background-color: #f0f2f6;
    border-radius: 5px;
    margin: 10px 0;
}
</style>
""", unsafe_allow_html=True)

# Create the data directory if it doesn't exist. exist_ok avoids the
# check-then-create race when two sessions start at the same time.
os.makedirs('data', exist_ok=True)
66
+
67
def read_csv_with_encoding(file):
    """Read a CSV, trying several text encodings until one decodes.

    Args:
        file: A path or a file-like object accepted by ``pandas.read_csv``.

    Returns:
        The parsed ``pandas.DataFrame`` from the first encoding that works.

    Raises:
        UnicodeDecodeError: If no candidate encoding can decode the data
            (the error from the last attempt is re-raised).
    """
    # latin1 maps every byte value, so in practice the entries after it are
    # never reached; the list is kept as-is to preserve decoding results.
    encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']

    # A failed attempt on a file-like object leaves the stream partly (or
    # fully) consumed, so remember where we started and rewind before retrying.
    rewindable = hasattr(file, 'seek') and hasattr(file, 'tell')
    start = file.tell() if rewindable else None

    last_error = None
    for encoding in encodings:
        try:
            return pd.read_csv(file, encoding=encoding)
        except UnicodeDecodeError as err:
            last_error = err
            if rewindable:
                file.seek(start)

    # UnicodeDecodeError cannot be built from a single message string (its
    # constructor takes five arguments), so surface the real decoding error.
    raise last_error
75
+
76
def save_to_csv(data, filename):
    """Write ``data`` to ``data/<filename>`` as CSV and return the frame.

    Args:
        data: Anything ``pandas.DataFrame`` can be constructed from.
        filename: File name inside the ``data`` directory.

    Returns:
        The ``pandas.DataFrame`` that was written (without the index column).
    """
    frame = pd.DataFrame(data)
    target = f'data/{filename}'
    frame.to_csv(target, index=False)
    return frame
80
+
81
def load_from_csv(filename):
    """Load ``data/<filename>`` as a DataFrame, or an empty one on failure.

    Args:
        filename: File name inside the ``data`` directory.

    Returns:
        The parsed ``pandas.DataFrame``; an empty ``DataFrame`` when the file
        is missing, unreadable, empty or malformed.
    """
    try:
        return pd.read_csv(f'data/{filename}')
    except (OSError, ValueError):
        # OSError covers missing/unreadable files; pandas' EmptyDataError and
        # ParserError as well as UnicodeDecodeError are ValueError subclasses.
        # A bare ``except`` would also swallow KeyboardInterrupt/SystemExit.
        return pd.DataFrame()
86
+
87
+ # Define reset function
88
def reset_conversation():
    """Clear the stored chat state; used as the sidebar reset button callback."""
    # Both keys are emptied; only ``messages`` is read elsewhere in this script.
    for state_key in ("conversation", "messages"):
        st.session_state[state_key] = []
91
+
92
# Generated replies are kept in session state so they survive Streamlit reruns
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Main app title
st.title("🤖 Text Data Generation & Labeling App")

# Sidebar: model choice, sampling temperature and the reset button
with st.sidebar:
    st.title("⚙️ Settings")

    selected_model = st.selectbox(
        "Select Model",
        ["meta-llama/Meta-Llama-3-8B-Instruct"],
        key='model_select'
    )

    temperature = st.slider(
        "Temperature",
        0.0, 1.0, 0.5,
        help="Controls randomness in generation"
    )

    st.button("🔄 Reset Conversation", on_click=reset_conversation)

    # Info card styled by the .sidebar-info CSS class
    with st.container():
        st.markdown(f"""
        <div class="sidebar-info">
        <h4>Current Model: {selected_model}</h4>
        <p><em>Note: Generated content may be inaccurate or false.</em></p>
        </div>
        """, unsafe_allow_html=True)

# Main content: one button per task; a click stores the chosen task in
# session state so the choice persists across reruns.
col1, col2 = st.columns(2)
task_buttons = (
    (col1, "📝 Data Generation", "gen_button", "Generate new data", "Data Generation"),
    (col2, "🏷️ Data Labeling", "label_button", "Label existing data", "Data Labeling"),
)
for column, caption, button_key, tooltip, task_name in task_buttons:
    with column:
        if st.button(caption, key=button_key, help=tooltip):
            st.session_state.task_choice = task_name
135
+
136
def _widget_key(prefix, suffix):
    """Return '<prefix>_<suffix>', or None to let Streamlit pick the widget key."""
    return f"{prefix}_{suffix}" if prefix else None


def _choose_labels(classification_type, key_prefix=None):
    """Render the label inputs for a classification type and return the labels."""
    if classification_type == "Sentiment Analysis":
        return ["Positive", "Negative", "Neutral"]

    if classification_type == "Binary Classification":
        left, right = st.columns(2)
        with left:
            label_1 = st.text_input("First class", "Positive", key=_widget_key(key_prefix, "first"))
        with right:
            label_2 = st.text_input("Second class", "Negative", key=_widget_key(key_prefix, "second"))
        # Fall back to the defaults if either box was cleared
        return [label_1, label_2] if label_1 and label_2 else ["Positive", "Negative"]

    # Multi-class: one text box per class, laid out in three columns
    num_classes = st.slider("Number of classes", 3, 10, 3, key=_widget_key(key_prefix, "num_classes"))
    chosen = []
    cols = st.columns(3)
    for i in range(num_classes):
        with cols[i % 3]:
            chosen.append(
                st.text_input(f"Class {i+1}", f"Class_{i+1}", key=_widget_key(key_prefix, f"class_{i}"))
            )
    return chosen


def _drop_last_message():
    """Discard the most recent stored reply (no-op when nothing is stored)."""
    st.session_state.messages = st.session_state.messages[:-1]


def _stream_completion(system_prompt, spinner_text, action):
    """Stream one chat completion to the page and store the reply.

    ``action`` is the noun used in the error message (e.g. "generation").
    Any failure from the API call is reported in the UI instead of raised.
    """
    with st.spinner(spinner_text):
        try:
            stream = client.chat.completions.create(
                model=selected_model,
                messages=[{"role": "system", "content": system_prompt}],
                temperature=temperature,
                stream=True,
                max_tokens=3000,
            )
            response = st.write_stream(stream)
            st.session_state.messages.append({"role": "assistant", "content": response})
        except Exception as e:
            st.error(f"An error occurred during {action}.")
            st.error(f"Details: {e}")


if "task_choice" in st.session_state:
    if st.session_state.task_choice == "Data Generation":
        st.header("📝 Data Generation")

        classification_type = st.selectbox(
            "Classification Type",
            ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"]
        )
        labels = _choose_labels(classification_type)

        domain = st.selectbox("Domain", ["Restaurant reviews", "E-commerce reviews", "Custom"])
        if domain == "Custom":
            domain = st.text_input("Specify custom domain")

        col1, col2 = st.columns(2)
        with col1:
            min_words = st.number_input("Min words", 10, 90, 20)
        with col2:
            # The default must not fall below min_value, otherwise Streamlit
            # rejects the widget once "Min words" is raised above 50.
            max_words = st.number_input("Max words", min_words, 90, max(min_words, 50))

        use_few_shot = st.toggle("Use few-shot examples")
        few_shot_examples = []
        if use_few_shot:
            num_examples = st.slider("Number of few-shot examples", 1, 5, 1)
            for i in range(num_examples):
                with st.expander(f"Example {i+1}"):
                    content = st.text_area("Content", key=f"few_shot_content_{i}")
                    label = st.selectbox("Label", labels, key=f"few_shot_label_{i}")
                    if content and label:
                        few_shot_examples.append({"content": content, "label": label})

        num_to_generate = st.number_input("Number of examples", 1, 100, 10)
        user_prompt = st.text_area("Additional instructions (optional)")

        # The few-shot examples entered above were previously collected but
        # never sent to the model; pass them along in the requested format.
        few_shot_section = ""
        if few_shot_examples:
            shots = "\n".join(f"{ex['content']} Label: {ex['label']}" for ex in few_shot_examples)
            few_shot_section = f"Here are reference examples to imitate:\n{shots}\n"

        # Prompt template with word length constraints
        prompt_template = PromptTemplate(
            input_variables=[
                "classification_type", "domain", "num_examples", "min_words",
                "max_words", "labels", "few_shot_section", "user_prompt",
            ],
            template=(
                "You are a professional {classification_type} expert tasked with generating examples for {domain}.\n"
                "Use the following parameters:\n"
                "- Generate exactly {num_examples} examples\n"
                "- Each example MUST be between {min_words} and {max_words} words long\n"
                "- Use these labels: {labels}\n"
                "- Generate the examples in this format: 'Example text. Label: [label]'\n"
                "- Do not include word counts or any additional information\n"
                "{few_shot_section}"
                "Additional instructions: {user_prompt}\n\n"
                "Generate numbered examples:"
            )
        )

        col1, col2 = st.columns(2)
        generation_actions = (
            (col1, "🎯 Generate Examples", "Generating examples...", "generation", False),
            (col2, "🔄 Regenerate", "Regenerating examples...", "regeneration", True),
        )
        for column, caption, spinner_text, action, replace_last in generation_actions:
            with column:
                if not st.button(caption):
                    continue
                if replace_last:
                    # Regenerating replaces the previous reply instead of stacking
                    _drop_last_message()
                system_prompt = prompt_template.format(
                    classification_type=classification_type,
                    domain=domain,
                    num_examples=num_to_generate,
                    min_words=min_words,
                    max_words=max_words,
                    labels=", ".join(labels),
                    few_shot_section=few_shot_section,
                    user_prompt=user_prompt
                )
                _stream_completion(system_prompt, spinner_text, action)

    elif st.session_state.task_choice == "Data Labeling":
        st.header("🏷️ Data Labeling")

        classification_type = st.selectbox(
            "Classification Type",
            ["Sentiment Analysis", "Binary Classification", "Multi-Class Classification"],
            key="label_class_type"
        )
        labels = _choose_labels(classification_type, key_prefix="label")

        use_few_shot = st.toggle("Use few-shot examples for labeling")
        few_shot_examples = []
        if use_few_shot:
            num_few_shot = st.slider("Number of few-shot examples", 1, 5, 1)
            for i in range(num_few_shot):
                with st.expander(f"Few-shot Example {i+1}"):
                    content = st.text_area("Content", key=f"label_few_shot_content_{i}")
                    label = st.selectbox("Label", labels, key=f"label_few_shot_label_{i}")
                    if content and label:
                        few_shot_examples.append(f"{content}\nLabel: {label}")

        num_examples = st.number_input("Number of examples to classify", 1, 100, 1)

        examples_to_classify = []
        if num_examples <= 20:
            # Small batches: one text box per example; empty boxes are skipped
            for i in range(num_examples):
                example = st.text_area(f"Example {i+1}", key=f"example_{i}")
                if example:
                    examples_to_classify.append(example)
        else:
            # Large batches: a single box with one example per line
            bulk_text = st.text_area(
                "Enter examples (one per line)",
                height=300,
                help="Enter each example on a new line"
            )
            if bulk_text:
                non_empty = [ex.strip() for ex in bulk_text.split('\n') if ex.strip()]
                # Keep at most the requested number of examples
                examples_to_classify = non_empty[:num_examples]

        user_prompt = st.text_area("Additional instructions (optional)", key="label_instructions")

        # Prompt pieces for labeling
        few_shot_text = "\n\n".join(few_shot_examples) if few_shot_examples else ""
        numbered_examples = "\n".join(f"{i+1}. {ex}" for i, ex in enumerate(examples_to_classify))

        label_prompt_template = PromptTemplate(
            input_variables=["classification_type", "labels", "few_shot_examples", "examples", "user_prompt"],
            template=(
                "You are a professional {classification_type} expert. Classify the following examples using these labels: {labels}.\n"
                "Instructions:\n"
                "- Return ONLY the numbered example followed by its classification\n"
                "- Use the format: 'Example text. Label: [label]'\n"
                "- Do not provide explanations or justifications\n"
                "{user_prompt}\n\n"
                "Few-shot examples:\n{few_shot_examples}\n\n"
                "Examples to classify:\n{examples}\n\n"
                "Output:\n"
            )
        )

        col1, col2 = st.columns(2)
        labeling_actions = (
            (col1, "🏷️ Label Data", "Labeling data...", "labeling", False),
            (col2, "🔄 Relabel", "Relabeling data...", "relabeling", True),
        )
        for column, caption, spinner_text, action, replace_last in labeling_actions:
            with column:
                if not st.button(caption):
                    continue
                if not examples_to_classify:
                    st.warning("Please enter at least one example to classify.")
                    continue
                if replace_last:
                    # Relabeling replaces the previous reply instead of stacking
                    _drop_last_message()
                system_prompt = label_prompt_template.format(
                    classification_type=classification_type,
                    labels=", ".join(labels),
                    few_shot_examples=few_shot_text,
                    examples=numbered_examples,
                    user_prompt=user_prompt
                )
                _stream_completion(system_prompt, spinner_text, action)

# Show the most recent reply so it stays visible after the next rerun
if st.session_state.messages:
    st.markdown("### Output:")
    for message in st.session_state.messages[-1:]:
        st.markdown(message["content"])