yaleh committed on
Commit
97ecca5
·
1 Parent(s): 6eaed59

Add more UI components to trigger and to show the intermediate steps.

Browse files
Files changed (1) hide show
  1. demo/sample_generator.ipynb +83 -30
demo/sample_generator.ipynb CHANGED
@@ -109,7 +109,7 @@
109
  "\"\"\")\n",
110
  "]\n",
111
  "\n",
112
- "EXAMPLES_PROMPT = [\n",
113
  " (\"system\", \"\"\"Given the task type description, and input/output example(s), generate {generating_batch_size}\n",
114
  "new input/output examples for this task type.\n",
115
  "\n",
@@ -150,7 +150,7 @@
150
  " self.input_analysis_prompt = ChatPromptTemplate.from_messages(INPUT_ANALYSIS_PROMPT)\n",
151
  " self.briefs_prompt = ChatPromptTemplate.from_messages(BRIEFS_PROMPT)\n",
152
  " self.examples_from_briefs_prompt = ChatPromptTemplate.from_messages(EXAMPLES_FROM_BRIEFS_PROMPT)\n",
153
- " self.examples_prompt = ChatPromptTemplate.from_messages(EXAMPLES_PROMPT)\n",
154
  "\n",
155
  " json_model = model.bind(response_format={\"type\": \"json_object\"})\n",
156
  "\n",
@@ -161,27 +161,34 @@
161
  " self.input_analysis_chain = self.input_analysis_prompt | model | output_parser\n",
162
  " self.briefs_chain = self.briefs_prompt | model | output_parser\n",
163
  " self.examples_from_briefs_chain = self.examples_from_briefs_prompt | json_model | json_parse\n",
164
- " self.examples_chain = self.examples_prompt | json_model | json_parse\n",
 
 
 
165
  "\n",
166
  " self.chain = (\n",
167
- " RunnablePassthrough.assign(raw_example = lambda x: json.dumps(x[\"example\"], ensure_ascii=False))\n",
 
168
  " | RunnablePassthrough.assign(description = self.description_chain)\n",
169
  " | {\n",
170
  " \"description\": lambda x: x[\"description\"],\n",
171
  " \"examples_from_briefs\": RunnablePassthrough.assign(input_analysis = self.input_analysis_chain)\n",
172
  " | RunnablePassthrough.assign(new_example_briefs = self.briefs_chain) \n",
173
  " | RunnablePassthrough.assign(examples = self.examples_from_briefs_chain | (lambda x: x[\"examples\"])),\n",
174
- " \"examples\": self.examples_chain\n",
175
  " }\n",
176
  " | RunnablePassthrough.assign(\n",
177
  " additional_examples=lambda x: (\n",
178
  " list(x[\"examples_from_briefs\"][\"examples\"])\n",
179
- " + list(x[\"examples\"][\"examples\"])\n",
180
  " )\n",
181
  " )\n",
182
  " )\n",
183
  "\n",
184
- " def process(self, input_str, generating_batch_size=3):\n",
 
 
 
185
  " try:\n",
186
  " try:\n",
187
  " example_dict = json.loads(input_str)\n",
@@ -202,15 +209,26 @@
202
  " # Move the original content to a key named 'example'\n",
203
  " input_dict = {\"example\": example_dict, \"generating_batch_size\": generating_batch_size}\n",
204
  "\n",
205
- " # Invoke the chain with the parsed input dictionary\n",
206
- " result = self.chain.invoke(input_dict)\n",
207
- " return result\n",
208
  "\n",
209
  " except Exception as e:\n",
210
  " raise RuntimeError(f\"An error occurred during processing: {str(e)}\")\n",
211
  "\n",
212
- " def generate_description(self, input_str):\n",
213
- " return self.description_chain.invoke(input_str)\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  "\n",
215
  " def analyze_input(self, description):\n",
216
  " return self.input_analysis_chain.invoke(description)\n",
@@ -222,16 +240,25 @@
222
  " \"generating_batch_size\": generating_batch_size\n",
223
  " })\n",
224
  "\n",
225
- " def generate_examples_from_briefs(self, description, new_example_briefs, raw_example, generating_batch_size):\n",
226
- " return self.examples_from_briefs_chain.invoke({\n",
 
 
 
 
 
 
 
 
 
227
  " \"description\": description,\n",
228
  " \"new_example_briefs\": new_example_briefs,\n",
229
- " \"raw_example\": raw_example,\n",
230
  " \"generating_batch_size\": generating_batch_size\n",
231
  " })\n",
232
  "\n",
233
- " def generate_examples(self, description, raw_example, generating_batch_size):\n",
234
- " return self.examples_chain.invoke({\n",
235
  " \"description\": description,\n",
236
  " \"raw_example\": raw_example,\n",
237
  " \"generating_batch_size\": generating_batch_size\n",
@@ -252,9 +279,12 @@
252
  " generator = TaskDescriptionGenerator(model)\n",
253
  " result = generator.process(input_json, generating_batch_size)\n",
254
  " description = result[\"description\"]\n",
 
255
  " input_analysis = result[\"examples_from_briefs\"][\"input_analysis\"]\n",
 
 
256
  " examples = [[example[\"input\"], example[\"output\"]] for example in result[\"additional_examples\"]]\n",
257
- " return description, input_analysis, examples\n",
258
  " except Exception as e:\n",
259
  " raise gr.Error(f\"An error occurred: {str(e)}\")\n",
260
  " \n",
@@ -285,20 +315,22 @@
285
  " except Exception as e:\n",
286
  " raise gr.Error(f\"An error occurred: {str(e)}\")\n",
287
  " \n",
288
- "def generate_examples_from_briefs(description, new_example_briefs, raw_example, generating_batch_size, model_name, temperature):\n",
289
  " try:\n",
290
  " model = ChatOpenAI(model=model_name, temperature=temperature, max_retries=3)\n",
291
  " generator = TaskDescriptionGenerator(model)\n",
292
- " examples = generator.generate_examples_from_briefs(description, new_example_briefs, raw_example, generating_batch_size)\n",
 
293
  " return examples\n",
294
  " except Exception as e:\n",
295
  " raise gr.Error(f\"An error occurred: {str(e)}\")\n",
296
  " \n",
297
- "def generate_examples(description, raw_example, generating_batch_size, model_name, temperature):\n",
298
  " try:\n",
299
  " model = ChatOpenAI(model=model_name, temperature=temperature, max_retries=3)\n",
300
  " generator = TaskDescriptionGenerator(model)\n",
301
- " examples = generator.generate_examples(description, raw_example, generating_batch_size)\n",
 
302
  " return examples\n",
303
  " except Exception as e:\n",
304
  " raise gr.Error(f\"An error occurred: {str(e)}\")\n",
@@ -330,7 +362,10 @@
330
  "\n",
331
  " with gr.Column(scale=1): # Outputs column\n",
332
  " description_output = gr.Textbox(label=\"Description\", lines=5, show_copy_button=True)\n",
333
- " analyze_input_button = gr.Button(\"Analyze Input\", variant=\"secondary\")\n",
 
 
 
334
  " input_analysis_output = gr.Textbox(label=\"Input Analysis\", lines=5, show_copy_button=True)\n",
335
  " generate_briefs_button = gr.Button(\"Generate Briefs\", variant=\"secondary\")\n",
336
  " example_briefs_output = gr.Textbox(label=\"Example Briefs\", lines=5, show_copy_button=True)\n",
@@ -346,13 +381,7 @@
346
  " submit_button.click(\n",
347
  " fn=process_json,\n",
348
  " inputs=[input_json, model_name, generating_batch_size, temperature],\n",
349
- " outputs=[description_output, input_analysis_output, examples_output]\n",
350
- " )\n",
351
- "\n",
352
- " analyze_input_button.click(\n",
353
- " fn=analyze_input,\n",
354
- " inputs=[description_output, model_name, temperature],\n",
355
- " outputs=[input_analysis_output]\n",
356
  " )\n",
357
  "\n",
358
  " generate_description_button.click(\n",
@@ -361,6 +390,18 @@
361
  " outputs=[description_output]\n",
362
  " )\n",
363
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
364
  " generate_briefs_button.click(\n",
365
  " fn=generate_briefs,\n",
366
  " inputs=[description_output, input_analysis_output, generating_batch_size, model_name, temperature],\n",
@@ -373,6 +414,18 @@
373
  " outputs=[examples_from_briefs_output]\n",
374
  " )\n",
375
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
376
  " examples_output.select(\n",
377
  " fn=format_selected_example,\n",
378
  " inputs=[examples_output],\n",
 
109
  "\"\"\")\n",
110
  "]\n",
111
  "\n",
112
+ "EXAMPLES_DIRECTLY_PROMPT = [\n",
113
  " (\"system\", \"\"\"Given the task type description, and input/output example(s), generate {generating_batch_size}\n",
114
  "new input/output examples for this task type.\n",
115
  "\n",
 
150
  " self.input_analysis_prompt = ChatPromptTemplate.from_messages(INPUT_ANALYSIS_PROMPT)\n",
151
  " self.briefs_prompt = ChatPromptTemplate.from_messages(BRIEFS_PROMPT)\n",
152
  " self.examples_from_briefs_prompt = ChatPromptTemplate.from_messages(EXAMPLES_FROM_BRIEFS_PROMPT)\n",
153
+ " self.examples_directly_prompt = ChatPromptTemplate.from_messages(EXAMPLES_DIRECTLY_PROMPT)\n",
154
  "\n",
155
  " json_model = model.bind(response_format={\"type\": \"json_object\"})\n",
156
  "\n",
 
161
  " self.input_analysis_chain = self.input_analysis_prompt | model | output_parser\n",
162
  " self.briefs_chain = self.briefs_prompt | model | output_parser\n",
163
  " self.examples_from_briefs_chain = self.examples_from_briefs_prompt | json_model | json_parse\n",
164
+ " self.examples_directly_chain = self.examples_directly_prompt | json_model | json_parse\n",
165
+ "\n",
166
+ " # New sub-chain for loading and validating input\n",
167
+ " self.input_loader = RunnableLambda(self.load_and_validate_input)\n",
168
  "\n",
169
  " self.chain = (\n",
170
+ " self.input_loader\n",
171
+ " | RunnablePassthrough.assign(raw_example = lambda x: json.dumps(x[\"example\"], ensure_ascii=False))\n",
172
  " | RunnablePassthrough.assign(description = self.description_chain)\n",
173
  " | {\n",
174
  " \"description\": lambda x: x[\"description\"],\n",
175
  " \"examples_from_briefs\": RunnablePassthrough.assign(input_analysis = self.input_analysis_chain)\n",
176
  " | RunnablePassthrough.assign(new_example_briefs = self.briefs_chain) \n",
177
  " | RunnablePassthrough.assign(examples = self.examples_from_briefs_chain | (lambda x: x[\"examples\"])),\n",
178
+ " \"examples_directly\": self.examples_directly_chain\n",
179
  " }\n",
180
  " | RunnablePassthrough.assign(\n",
181
  " additional_examples=lambda x: (\n",
182
  " list(x[\"examples_from_briefs\"][\"examples\"])\n",
183
+ " + list(x[\"examples_directly\"][\"examples\"])\n",
184
  " )\n",
185
  " )\n",
186
  " )\n",
187
  "\n",
188
+ " def load_and_validate_input(self, input_dict):\n",
189
+ " input_str = input_dict[\"input_str\"]\n",
190
+ " generating_batch_size = input_dict[\"generating_batch_size\"]\n",
191
+ "\n",
192
  " try:\n",
193
  " try:\n",
194
  " example_dict = json.loads(input_str)\n",
 
209
  " # Move the original content to a key named 'example'\n",
210
  " input_dict = {\"example\": example_dict, \"generating_batch_size\": generating_batch_size}\n",
211
  "\n",
212
+ " return input_dict\n",
 
 
213
  "\n",
214
  " except Exception as e:\n",
215
  " raise RuntimeError(f\"An error occurred during processing: {str(e)}\")\n",
216
  "\n",
217
+ " def process(self, input_str, generating_batch_size=3):\n",
218
+ " input_dict = {\"input_str\": input_str, \"generating_batch_size\": generating_batch_size}\n",
219
+ " result = self.chain.invoke(input_dict)\n",
220
+ " return result\n",
221
+ "\n",
222
+ " def generate_description(self, input_str, generating_batch_size=3):\n",
223
+ " chain = (\n",
224
+ " self.input_loader \n",
225
+ " | RunnablePassthrough.assign(raw_example = lambda x: json.dumps(x[\"example\"], ensure_ascii=False))\n",
226
+ " | self.description_chain\n",
227
+ " )\n",
228
+ " return chain.invoke({\n",
229
+ " \"input_str\": input_str,\n",
230
+ " \"generating_batch_size\": generating_batch_size\n",
231
+ " })\n",
232
  "\n",
233
  " def analyze_input(self, description):\n",
234
  " return self.input_analysis_chain.invoke(description)\n",
 
240
  " \"generating_batch_size\": generating_batch_size\n",
241
  " })\n",
242
  "\n",
243
+ " def generate_examples_from_briefs(self, description, new_example_briefs, input_str, generating_batch_size=3):\n",
244
+ " chain = (\n",
245
+ " self.input_loader\n",
246
+ " | RunnablePassthrough.assign(\n",
247
+ " raw_example = lambda x: json.dumps(x[\"example\"], ensure_ascii=False),\n",
248
+ " description = lambda x: description,\n",
249
+ " new_example_briefs = lambda x: new_example_briefs\n",
250
+ " )\n",
251
+ " | self.examples_from_briefs_chain\n",
252
+ " )\n",
253
+ " return chain.invoke({\n",
254
  " \"description\": description,\n",
255
  " \"new_example_briefs\": new_example_briefs,\n",
256
+ " \"input_str\": input_str,\n",
257
  " \"generating_batch_size\": generating_batch_size\n",
258
  " })\n",
259
  "\n",
260
+ " def generate_examples_directly(self, description, raw_example, generating_batch_size):\n",
261
+ " return self.examples_directly_chain.invoke({\n",
262
  " \"description\": description,\n",
263
  " \"raw_example\": raw_example,\n",
264
  " \"generating_batch_size\": generating_batch_size\n",
 
279
  " generator = TaskDescriptionGenerator(model)\n",
280
  " result = generator.process(input_json, generating_batch_size)\n",
281
  " description = result[\"description\"]\n",
282
+ " examples_directly = [[example[\"input\"], example[\"output\"]] for example in result[\"examples_directly\"][\"examples\"]]\n",
283
  " input_analysis = result[\"examples_from_briefs\"][\"input_analysis\"]\n",
284
+ " new_example_briefs = result[\"examples_from_briefs\"][\"new_example_briefs\"]\n",
285
+ " examples_from_briefs = [[example[\"input\"], example[\"output\"]] for example in result[\"examples_from_briefs\"][\"examples\"]]\n",
286
  " examples = [[example[\"input\"], example[\"output\"]] for example in result[\"additional_examples\"]]\n",
287
+ " return description, examples_directly, input_analysis, new_example_briefs, examples_from_briefs, examples\n",
288
  " except Exception as e:\n",
289
  " raise gr.Error(f\"An error occurred: {str(e)}\")\n",
290
  " \n",
 
315
  " except Exception as e:\n",
316
  " raise gr.Error(f\"An error occurred: {str(e)}\")\n",
317
  " \n",
318
+ "def generate_examples_from_briefs(description, new_example_briefs, input_str, generating_batch_size, model_name, temperature):\n",
319
  " try:\n",
320
  " model = ChatOpenAI(model=model_name, temperature=temperature, max_retries=3)\n",
321
  " generator = TaskDescriptionGenerator(model)\n",
322
+ " result = generator.generate_examples_from_briefs(description, new_example_briefs, input_str, generating_batch_size)\n",
323
+ " examples = [[example[\"input\"], example[\"output\"]] for example in result[\"examples\"]]\n",
324
  " return examples\n",
325
  " except Exception as e:\n",
326
  " raise gr.Error(f\"An error occurred: {str(e)}\")\n",
327
  " \n",
328
+ "def generate_examples_directly(description, raw_example, generating_batch_size, model_name, temperature):\n",
329
  " try:\n",
330
  " model = ChatOpenAI(model=model_name, temperature=temperature, max_retries=3)\n",
331
  " generator = TaskDescriptionGenerator(model)\n",
332
+ " result = generator.generate_examples_directly(description, raw_example, generating_batch_size)\n",
333
+ " examples = [[example[\"input\"], example[\"output\"]] for example in result[\"examples\"]]\n",
334
  " return examples\n",
335
  " except Exception as e:\n",
336
  " raise gr.Error(f\"An error occurred: {str(e)}\")\n",
 
362
  "\n",
363
  " with gr.Column(scale=1): # Outputs column\n",
364
  " description_output = gr.Textbox(label=\"Description\", lines=5, show_copy_button=True)\n",
365
+ " with gr.Row():\n",
366
+ " generate_examples_directly_button = gr.Button(\"Generate Examples Directly\", variant=\"secondary\")\n",
367
+ " analyze_input_button = gr.Button(\"Analyze Input\", variant=\"secondary\")\n",
368
+ " examples_directly_output = gr.DataFrame(label=\"Examples Directly\", headers=[\"Input\", \"Output\"], interactive=False)\n",
369
  " input_analysis_output = gr.Textbox(label=\"Input Analysis\", lines=5, show_copy_button=True)\n",
370
  " generate_briefs_button = gr.Button(\"Generate Briefs\", variant=\"secondary\")\n",
371
  " example_briefs_output = gr.Textbox(label=\"Example Briefs\", lines=5, show_copy_button=True)\n",
 
381
  " submit_button.click(\n",
382
  " fn=process_json,\n",
383
  " inputs=[input_json, model_name, generating_batch_size, temperature],\n",
384
+ " outputs=[description_output, examples_directly_output, input_analysis_output, example_briefs_output, examples_from_briefs_output, examples_output]\n",
 
 
 
 
 
 
385
  " )\n",
386
  "\n",
387
  " generate_description_button.click(\n",
 
390
  " outputs=[description_output]\n",
391
  " )\n",
392
  "\n",
393
+ " generate_examples_directly_button.click(\n",
394
+ " fn=generate_examples_directly,\n",
395
+ " inputs=[description_output, input_json, generating_batch_size, model_name, temperature],\n",
396
+ " outputs=[examples_directly_output]\n",
397
+ " )\n",
398
+ "\n",
399
+ " analyze_input_button.click(\n",
400
+ " fn=analyze_input,\n",
401
+ " inputs=[description_output, model_name, temperature],\n",
402
+ " outputs=[input_analysis_output]\n",
403
+ " )\n",
404
+ "\n",
405
  " generate_briefs_button.click(\n",
406
  " fn=generate_briefs,\n",
407
  " inputs=[description_output, input_analysis_output, generating_batch_size, model_name, temperature],\n",
 
414
  " outputs=[examples_from_briefs_output]\n",
415
  " )\n",
416
  "\n",
417
+ " examples_directly_output.select(\n",
418
+ " fn=format_selected_example,\n",
419
+ " inputs=[examples_directly_output],\n",
420
+ " outputs=[new_example_json]\n",
421
+ " )\n",
422
+ "\n",
423
+ " examples_from_briefs_output.select(\n",
424
+ " fn=format_selected_example,\n",
425
+ " inputs=[examples_from_briefs_output],\n",
426
+ " outputs=[new_example_json]\n",
427
+ " )\n",
428
+ "\n",
429
  " examples_output.select(\n",
430
  " fn=format_selected_example,\n",
431
  " inputs=[examples_output],\n",