nsourlos committed on
Commit
88cec06
·
1 Parent(s): b2bc208

'final_submit_message_correction'

Browse files
Files changed (3) hide show
  1. app.py +132 -0
  2. data.csv +5 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ from huggingface_hub import HfFileSystem
6
+
7
# Hugging Face Space that hosts this app.
REPO_ID = "nsourlos/draco_streamlit"
# Access token read from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# CSV file the app reads from and writes back to.
data_path = 'data.csv'
11
+
12
def load_data(file):
    """Read a CSV into a DataFrame, using its ``id`` column as the index.

    ``file`` is handed directly to ``pandas.read_csv``.
    """
    return pd.read_csv(file, index_col='id')
16
+
17
def save_data(df, filename):
    """Write ``df`` as CSV to ``filename`` using ``DataFrame.to_csv`` defaults."""
    df.to_csv(path_or_buf=filename)
20
+
21
def calculate_accuracy(df):
    """Map each distinct ``text`` value to the mean of its ``label`` values.

    Returns a dict keyed by ``text``; each value is the mean of the
    ``label`` column over the rows sharing that ``text``.
    """
    labels_by_text = df.groupby('text')['label']
    return {text: labels.mean() for text, labels in labels_by_text}
29
+
30
# Seed each session-state key with its default only when it is missing;
# keys that already exist keep their current value.
_session_defaults = (
    ('data', None),
    ('new_rows', []),
    ('file_path', None),
    ('add_row_clicked', False),
    ('rerun_count', 0),
    ('finished', False),
)
for _key, _default in _session_defaults:
    if _key not in st.session_state:
        st.session_state[_key] = _default
48
+
49
+
50
def add_row(new_text, new_label):
    """Append one row to the session's data, persist it, and rerun the app.

    The new row gets ``checked=False`` and an id one past the largest
    existing index (0 when the frame is empty). The combined frame is
    written to ``st.session_state.file_path`` and read back from there, the
    add-row flag is cleared, the rerun counter is bumped, and finally
    ``st.rerun()`` is called.
    """
    state = st.session_state
    current = state['data']
    if current.empty:
        new_id = 0
    else:
        new_id = current.index.max() + 1

    row = {'id': new_id, 'text': new_text, 'label': new_label, 'checked': False}
    state.new_rows.append(row)

    # Round-trip through the CSV so the in-memory frame matches what is on disk.
    row_frame = pd.DataFrame([row]).set_index('id')
    target = state.file_path
    save_data(pd.concat([state.data, row_frame]), target)
    state.data = load_data(target)

    state.add_row_clicked = False  # Reset the add row state
    state.rerun_count += 1

    st.rerun()
65
+
66
# Streamlit app: edit the CSV in a table, add rows, plot per-text accuracy,
# and push the result back to the Hugging Face Space on "Finish".
st.title("Interactive DataFrame Editor")

# The uploader is disabled; the app always works on the bundled CSV.
# uploaded_file = st.file_uploader("Upload your CSV file", type="csv")
uploaded_file = data_path

if uploaded_file is not None:

    st.session_state.file_path = uploaded_file
    # While rerun_count is still 0 (add_row has not incremented it yet),
    # reload the data from disk on every script run.
    if st.session_state.rerun_count == 0:
        st.session_state.data = load_data(uploaded_file)

    st.subheader("DataFrame")
    if st.session_state.data is not None:

        # Editable table; whatever the editor returns is stored and saved.
        edited_data = st.data_editor(st.session_state.data)

        if edited_data is not None:
            st.session_state.data = edited_data
            save_data(st.session_state.data, st.session_state.file_path)

    if st.button("Add Row"):
        st.session_state.add_row_clicked = True

    if st.session_state.add_row_clicked:
        # Inputs for adding new row
        new_text = st.text_input("Enter model name for new row:")
        new_label = st.selectbox("Select label for new row:", options=[0, 1])

        if st.button("Confirm Add Row"):
            add_row(new_text, new_label)

    # Calculate accuracy
    accuracy_dict = calculate_accuracy(st.session_state.data)

    # Create scatter plot
    texts = list(accuracy_dict.keys())
    accuracies = list(accuracy_dict.values())

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.scatter(texts, accuracies)
    ax.set_xlabel('Text')
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy of Labels for Each Text Attribute')
    # Rotate x-axis labels for readability, set on this Axes explicitly
    # rather than through pyplot's implicit "current axes".
    ax.tick_params(axis='x', labelrotation=90)

    st.subheader("Leaderboard")

    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # script rerun would leave another open figure behind.
    plt.close(fig)

    # Button to finish and upload the data to the Space repository
    if st.button('Finish'):
        if not HF_TOKEN:
            # os.getenv returns None when the variable is unset; calling
            # .replace on it would raise AttributeError.
            st.error('HF_TOKEN is not set; cannot save the data to the Space.')
        else:
            st.success('Saving.... Space will restart soon....')
            st.session_state.finished = True

            # Strip double quotes from the token before using it.
            fs = HfFileSystem(token=HF_TOKEN.replace("\"", ""))

            # Build the remote path from the module constants so it stays
            # in step with REPO_ID and data_path.
            with fs.open(f'spaces/{REPO_ID}/{data_path}', 'w') as f:
                f.write(st.session_state.data.to_csv())

else:
    st.write("Please upload a CSV file to get started.")
data.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ id,text,label,checked
2
+ 1,a,1,TRUE
3
+ 2,a,0,FALSE
4
+ 3,a,1,TRUE
5
+ 4,gpt,1,FALSE
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ pandas==2.2.2
2
+ tensorflow==2.16.2
3
+ tf-keras==2.16.0
4
+ torch==2.3.1
5
+ torchvision==0.18.1
6
+ torchaudio==2.3.1
7
+ transformers==4.42.3
8
+ streamlit==1.36.0