joshuadunlop commited on
Commit
f90d0d6
·
1 Parent(s): f3eae7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -57
app.py CHANGED
@@ -28,63 +28,6 @@ followup_message = st.sidebar.text_area("Edit Message:")
28
  row_count = st.session_state.get("row_count", 1)
29
  generate_all_edits = st.sidebar.button("Generate All Edits")
30
 
31
- class WorkerThread(threading.Thread):
32
- def __init__(self, jobs, results):
33
- super().__init__()
34
- self.jobs = jobs
35
- self.results = results
36
-
37
- def run(self):
38
- while True:
39
- job = self.jobs.get()
40
- if job is None:
41
- break
42
- i, messages = job
43
- result = generate_response(i, messages)
44
- self.results.put(result)
45
-
46
- if generate_all_edits:
47
- # Creating a list of messages
48
- messages = [st.session_state.get(f"message{i}", "") for i in range(row_count)]
49
-
50
- jobs = Queue()
51
- results = Queue()
52
-
53
- # Create workers
54
- workers = [WorkerThread(jobs, results) for _ in range(num_concurrent_calls)]
55
-
56
- # Start all workers
57
- for worker in workers:
58
- worker.start()
59
-
60
- # Put all the jobs into the queue
61
- for i, message in enumerate(messages):
62
- jobs.put((i, [
63
- {"role": "system", "content": system_message},
64
- {"role": "user", "content": message},
65
- {"role": "user", "content": followup_message}
66
- ]))
67
-
68
- # Put a None for each worker to indicate the end of jobs
69
- for _ in range(num_concurrent_calls):
70
- jobs.put(None)
71
-
72
- # Wait for all of the tasks to finish
73
- for worker in workers:
74
- worker.join()
75
-
76
- # Collect all the results
77
- while not results.empty():
78
- i, response, prompt_tokens, response_tokens, word_count, error_message = results.get()
79
- if error_message is not None:
80
- st.write(f"Error on row {i}: {error_message}")
81
- st.session_state[f"followup_response{i}"] = response
82
- st.session_state[f"prompt_tokens{i}"] = prompt_tokens
83
- st.session_state[f"response_tokens{i}"] = response_tokens
84
- st.session_state[f"word_count{i}"] = word_count
85
-
86
- # ... (rest of the code remains unchanged)
87
-
88
  if add_row:
89
  row_count += 1
90
  st.session_state.row_count = row_count
@@ -190,6 +133,48 @@ if generate_all:
190
  st.session_state[f"response_tokens{i}"] = response_tokens
191
  st.session_state[f"word_count{i}"] = word_count
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  for i in range(row_count):
194
  if show_notes:
195
  st.text_input(f"Note {i + 1}:", key=f"note{i}", value=st.session_state.get(f"note{i}", ""))
 
28
  row_count = st.session_state.get("row_count", 1)
29
  generate_all_edits = st.sidebar.button("Generate All Edits")
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  if add_row:
32
  row_count += 1
33
  st.session_state.row_count = row_count
 
133
  st.session_state[f"response_tokens{i}"] = response_tokens
134
  st.session_state[f"word_count{i}"] = word_count
135
 
136
+ if generate_all_edits:
137
+ # Creating a list of messages
138
+ messages = [st.session_state.get(f"message{i}", "") for i in range(row_count)]
139
+
140
+ jobs = Queue()
141
+ results = Queue()
142
+
143
+ # Create workers
144
+ workers = [WorkerThread(jobs, results) for _ in range(num_concurrent_calls)]
145
+
146
+ # Start all workers
147
+ for worker in workers:
148
+ worker.start()
149
+
150
+ # Put all the jobs into the queue
151
+ for i, message in enumerate(messages):
152
+ jobs.put((i, [
153
+ {"role": "system", "content": system_message},
154
+ {"role": "user", "content": message},
155
+ {"role": "user", "content": followup_message}
156
+ ]))
157
+
158
+ # Put a None for each worker to indicate the end of jobs
159
+ for _ in range(num_concurrent_calls):
160
+ jobs.put(None)
161
+
162
+ # Wait for all of the tasks to finish
163
+ for worker in workers:
164
+ worker.join()
165
+
166
+ # Collect all the results
167
+ while not results.empty():
168
+ i, response, prompt_tokens, response_tokens, word_count, error_message = results.get()
169
+ if error_message is not None:
170
+ st.write(f"Error on row {i}: {error_message}")
171
+ st.session_state[f"followup_response{i}"] = response
172
+ st.session_state[f"prompt_tokens{i}"] = prompt_tokens
173
+ st.session_state[f"response_tokens{i}"] = response_tokens
174
+ st.session_state[f"word_count{i}"] = word_count
175
+
176
+ # ... (rest of the code remains unchanged)
177
+
178
  for i in range(row_count):
179
  if show_notes:
180
  st.text_input(f"Note {i + 1}:", key=f"note{i}", value=st.session_state.get(f"note{i}", ""))