Commit f9ab7f7 · 1 Parent(s): 2b661e2
Committed by yaleh

`Apply Changes` works.

app/streamlit_sample_generator.py CHANGED
@@ -12,6 +12,7 @@ def process_json(input_json, model_name, generating_batch_size, temperature):
         generator = TaskDescriptionGenerator(model)
         result = generator.process(input_json, generating_batch_size)
         description = result["description"]
+        suggestions = result["suggestions"]
         examples_directly = [[example["input"], example["output"]]
                              for example in result["examples_directly"]["examples"]]
         input_analysis = result["examples_from_briefs"]["input_analysis"]
@@ -20,10 +21,10 @@ def process_json(input_json, model_name, generating_batch_size, temperature):
                                   for example in result["examples_from_briefs"]["examples"]]
         examples = [[example["input"], example["output"]]
                     for example in result["additional_examples"]]
-        return description, examples_directly, input_analysis, new_example_briefs, examples_from_briefs, examples
+        return description, suggestions, examples_directly, input_analysis, new_example_briefs, examples_from_briefs, examples
     except Exception as e:
         st.warning(f"An error occurred: {str(e)}. Returning default values.")
-        return "", [], "", [], [], []
+        return "", [], [], "", [], [], []


 def generate_description_only(input_json, model_name, temperature):
@@ -31,10 +32,13 @@ def generate_description_only(input_json, model_name, temperature):
         model = ChatOpenAI(
             model=model_name, temperature=temperature, max_retries=3)
         generator = TaskDescriptionGenerator(model)
-        description = generator.generate_description(input_json)
-        return description
+        result = generator.generate_description(input_json)
+        description = result["description"]
+        suggestions = result["suggestions"]
+        return description, suggestions
     except Exception as e:
-        st.error(f"An error occurred: {str(e)}")
+        st.warning(f"An error occurred: {str(e)}")
+        return "", []


 def analyze_input(description, model_name, temperature):
@@ -45,7 +49,8 @@ def analyze_input(description, model_name, temperature):
         input_analysis = generator.analyze_input(description)
         return input_analysis
     except Exception as e:
-        st.error(f"An error occurred: {str(e)}")
+        st.warning(f"An error occurred: {str(e)}")
+        return ""


 def generate_briefs(description, input_analysis, generating_batch_size, model_name, temperature):
@@ -57,7 +62,8 @@ def generate_briefs(description, input_analysis, generating_batch_size, model_na
             description, input_analysis, generating_batch_size)
         return briefs
     except Exception as e:
-        st.error(f"An error occurred: {str(e)}")
+        st.warning(f"An error occurred: {str(e)}")
+        return ""


 def generate_examples_from_briefs(description, new_example_briefs, input_str, generating_batch_size, model_name, temperature):
@@ -71,7 +77,8 @@ def generate_examples_from_briefs(description, new_example_briefs, input_str, ge
                     for example in result["examples"]]
         return examples
     except Exception as e:
-        st.error(f"An error occurred: {str(e)}")
+        st.warning(f"An error occurred: {str(e)}")
+        return []


 def generate_examples_directly(description, raw_example, generating_batch_size, model_name, temperature):
@@ -85,7 +92,8 @@ def generate_examples_directly(description, raw_example, generating_batch_size,
                     for example in result["examples"]]
         return examples
     except Exception as e:
-        st.error(f"An error occurred: {str(e)}")
+        st.warning(f"An error occurred: {str(e)}")
+        return []


 def example_directly_selected():
@@ -142,6 +150,9 @@ if 'input_data' not in st.session_state:
 if 'description_output_text' not in st.session_state:
     st.session_state.description_output_text = ''

+if 'suggestions' not in st.session_state:
+    st.session_state.suggestions = []
+
 if 'input_analysis_output_text' not in st.session_state:
     st.session_state.input_analysis_output_text = ''

@@ -169,8 +180,9 @@ if 'selected_example' not in st.session_state:

 def update_description_output_text():
     input_json = package_input_data()
-    st.session_state.description_output_text = generate_description_only(
-        input_json, model_name, temperature)
+    result = generate_description_only(input_json, model_name, temperature)
+    st.session_state.description_output_text = result[0]
+    st.session_state.suggestions = result[1]


 def update_input_analysis_output_text():
@@ -203,8 +215,9 @@ def generate_examples_dataframe():
     input_json = package_input_data()
     result = process_json(input_json, model_name,
                           generating_batch_size, temperature)
-    description, examples_directly, input_analysis, new_example_briefs, examples_from_briefs, examples = result
+    description, suggestions, examples_directly, input_analysis, new_example_briefs, examples_from_briefs, examples = result
     st.session_state.description_output_text = description
+    st.session_state.suggestions = suggestions  # Ensure suggestions are stored in session state
     st.session_state.examples_directly_dataframe = pd.DataFrame(
         examples_directly, columns=["Input", "Output"])
     st.session_state.input_analysis_output_text = input_analysis
@@ -239,6 +252,12 @@ def import_input_data_from_json():
     except Exception as e:
         st.error(f"Failed to import JSON: {str(e)}")

+def apply_suggestions():
+    result = TaskDescriptionGenerator(
+        ChatOpenAI(model=model_name, temperature=temperature, max_retries=3)).update_description(
+        package_input_data(), st.session_state.description_output_text, st.session_state.selected_suggestions)
+    st.session_state.description_output_text = result["description"]
+    st.session_state.suggestions = result["suggestions"]

 # Streamlit UI
 st.title("LLM Task Example Generator")
@@ -288,6 +307,13 @@ with st.expander("Description and Analysis"):
     description_output = st.text_area(
         "Description", value=st.session_state.description_output_text, height=100)

+    # Add multiselect for suggestions
+    selected_suggestions = st.multiselect(
+        "Suggestions", options=st.session_state.suggestions, key="selected_suggestions")
+
+    # Add button to apply suggestions
+    apply_suggestions_button = st.button("Apply Suggestions", on_click=apply_suggestions)
+
     col3, col4 = st.columns(2)
     with col3:
         generate_examples_directly_button = st.button(
@@ -327,3 +353,4 @@ def show_sidebar():
         st.button("Append to Input Data", on_click=append_selected_to_input_data)

 show_sidebar()
+
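Taken together, the UI changes treat the description step as a `(description, suggestions)` pair, surface the suggestions in a multiselect, and feed the selected ones back through `TaskDescriptionGenerator.update_description`. Below is a minimal sketch of that round trip outside Streamlit; the model name, the example JSON, and the `ChatOpenAI` import path are placeholder assumptions, not part of the commit.

```python
# Illustrative sketch only: the "generate description" / "Apply Suggestions"
# round trip performed by the new callbacks, without Streamlit session state.
from langchain_openai import ChatOpenAI  # assumed import path
from meta_prompt.sample_generator import TaskDescriptionGenerator

model = ChatOpenAI(model="gpt-4o-mini", temperature=0.7, max_retries=3)  # placeholder model
generator = TaskDescriptionGenerator(model)

input_json = '{"input": "2 + 3", "output": "5"}'  # placeholder example

# Equivalent of update_description_output_text(): a description plus a list
# of suggestion strings for the multiselect.
result = generator.generate_description(input_json)
description, suggestions = result["description"], result["suggestions"]

# Equivalent of apply_suggestions(): apply the picked suggestions and get a
# refreshed description together with a fresh batch of suggestions.
picked = suggestions[:2]
updated = generator.update_description(input_json, description, picked)
print(updated["description"])
print(updated["suggestions"])
```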
meta_prompt/sample_generator.py CHANGED
@@ -21,6 +21,173 @@ Task Description: [Your description here]
 """)
 ]

+DESCRIPTION_UPDATING_PROMPT = [
+    ("system", """Given the task type description and suggestions, update the task type description according to the suggestions.
+
+1. Input Information:
+   - You will receive a task type description and suggestions for updating the description.
+   - Carefully read and understand the provided information.
+
+2. Task Analysis:
+   - Identify the core elements and characteristics of the task.
+   - Consider possible generalization dimensions such as task domain, complexity, input/output format, application scenarios, etc.
+
+3. Update Task Description:
+   - Apply the suggestions to update the task description. Don't change anything that is not suggested.
+   - Ensure the updated description is clear, specific, and directly related to the task.
+
+4. Output Format:
+   - Format your response as follows:
+
+Task Description: [Your updated description here]
+
+   - Output the updated `Task Description` only. Don't output anything else.
+
+5. Completeness Check:
+   - Ensure all important aspects of the task description are covered.
+   - Check for any missing key information or dimensions.
+
+6. Quantity Requirement:
+   - Provide at least 5 specification suggestions across different dimensions.
+"""),
+    ("user", """***Task Description:***
+
+{description}
+
+***Suggestions:***
+
+{suggestions}
+
+""")
+]
+
+
+SPECIFICATION_SUGGESTIONS_PROMPT = [
+    ("system", """Based on the given task type description and corresponding input/output examples, list suggestions to specify the task type description in multiple dimensions using JSON format.
+
+Please complete this task according to the following requirements:
+
+1. Analyze the given task type description and input/output examples.
+
+2. Identify multiple dimensions to specify the task description, such as:
+   - Task purpose
+   - Input format requirements
+   - Output format requirements
+   - Processing steps
+   - Evaluation criteria
+   - Constraints
+   - Special case handling
+   - Other relevant dimensions
+
+3. Output Format:
+   - Use JSON format to list the suggestions.
+   - The JSON structure should contain a top-level array, with each object representing a suggestion for a specific dimension.
+
+4. Suggestion Content:
+   - Each suggestion should be clear, specific, and directly related to the task description or input/output examples.
+   - Start each suggestion with a verb, such as "Limit the scope of supported tasks to..." Make sure it is an actionable, self-contained, and complete suggestion.
+   - Ensure suggestions are compatible with the provided input/output examples.
+
+5. Output Example:
+
+```json
+{{
+  "suggestions": [
+    {{
+      "suggestion": "..."
+    }},
+    {{
+      "suggestion": "..."
+    }},
+    ...
+  ]
+}}
+```
+
+6. Completeness Check:
+   - Ensure all important aspects of the task description are covered.
+   - Check for any missing key information or dimensions.
+
+7. Quantity Requirement:
+   - Provide at least 5 specification suggestions across different dimensions.
+
+Please begin the task directly. After completion, verify that your JSON output meets all requirements and directly addresses the task needs.
+
+***Task Description:***
+
+{description}
+
+***Example(s):***
+
+{raw_example}
+
+""")
+]
+
+GENERALIZATION_SUGGESTIONS_PROMPT = [
+    ("system", """Based on a given task type description and corresponding input/output examples, list suggestions for generalizing the task type description across multiple dimensions in JSON format.
+
+Please complete this task according to the following requirements:
+
+1. Input Information:
+   - You will receive a task type description and corresponding input/output examples.
+   - Carefully read and understand the provided information.
+
+2. Task Analysis:
+   - Identify the core elements and characteristics of the task.
+   - Consider possible generalization dimensions such as task domain, complexity, input/output format, application scenarios, etc.
+
+3. Generate Generalization Suggestions:
+   - Based on your analysis, propose generalization suggestions across multiple dimensions.
+   - Each suggestion should be a reasonable extension or variation of the original task type.
+   - Start each suggestion with a verb, such as "Expand the scope of support to..." Make sure it is an actionable, self-contained, and complete suggestion.
+
+4. Output Format:
+   - Use JSON format to list the suggestions.
+   - The JSON structure should contain a top-level array, with each object representing a suggestion for a specific dimension.
+
+5. Output Example:
+
+```json
+{{
+  "suggestions": [
+    {{
+      "suggestion": "..."
+    }},
+    {{
+      "suggestion": "..."
+    }},
+    ...
+  ]
+}}
+```
+
+6. Quality Requirements:
+   - Ensure each generalization suggestion is meaningful and feasible.
+   - Provide diverse generalization dimensions to cover different aspects of possible extensions.
+   - Maintain conciseness and clarity in suggestions.
+
+7. Quantity Requirement:
+   - Provide at least 5 generalization suggestions across different dimensions.
+
+8. Notes:
+   - Avoid proposing generalizations completely unrelated to the original task.
+   - Ensure the JSON format is correct and can be parsed.
+
+After completing the task, please check if your output meets all requirements, especially the correctness of the JSON format and the quality of generalization suggestions.
+
+***Task Description:***
+
+{description}
+
+***Example(s):***
+
+{raw_example}
+
+
+""")
+]
+
 INPUT_ANALYSIS_PROMPT = [
     ("system", """For the specific task type, analyze the possible task inputs across multiple dimensions.

@@ -78,16 +245,16 @@ not present in the briefs.

 Format your response as a valid JSON object with a single key 'examples'
 containing a JSON array of {generating_batch_size} objects, each with 'input' and 'output' fields.
-"""),
-    ("user", """Task Description:
+
+***Task Description:***

 {description}

-New Example Briefs:
+***New Example Briefs:***

 {new_example_briefs}

-Example(s):
+***Example(s):***

 {raw_example}

@@ -100,12 +267,12 @@ new input/output examples for this task type.

 Format your response as a valid JSON object with a single key 'examples'
 containing a JSON array of {generating_batch_size} objects, each with 'input' and 'output' fields.
-"""),
-    ("user", """Task Description:
+
+***Task Description:***

 {description}

-Example(s):
+***Example(s):***

 {raw_example}

@@ -116,6 +283,9 @@ Example(s):
 class TaskDescriptionGenerator:
     def __init__(self, model):
         self.description_prompt = ChatPromptTemplate.from_messages(DESCRIPTION_PROMPT)
+        self.description_updating_prompt = ChatPromptTemplate.from_messages(DESCRIPTION_UPDATING_PROMPT)
+        self.specification_suggestions_prompt = ChatPromptTemplate.from_messages(SPECIFICATION_SUGGESTIONS_PROMPT)
+        self.generalization_suggestions_prompt = ChatPromptTemplate.from_messages(GENERALIZATION_SUGGESTIONS_PROMPT)
         self.input_analysis_prompt = ChatPromptTemplate.from_messages(INPUT_ANALYSIS_PROMPT)
         self.briefs_prompt = ChatPromptTemplate.from_messages(BRIEFS_PROMPT)
         self.examples_from_briefs_prompt = ChatPromptTemplate.from_messages(EXAMPLES_FROM_BRIEFS_PROMPT)
@@ -127,6 +297,9 @@ class TaskDescriptionGenerator:
         json_parse = JsonOutputParser()

         self.description_chain = self.description_prompt | model | output_parser
+        self.description_updating_chain = self.description_updating_prompt | model | output_parser
+        self.specification_suggestions_chain = self.specification_suggestions_prompt | json_model | json_parse
+        self.generalization_suggestions_chain = self.generalization_suggestions_prompt | json_model | json_parse
         self.input_analysis_chain = self.input_analysis_prompt | model | output_parser
         self.briefs_chain = self.briefs_prompt | model | output_parser
         self.examples_from_briefs_chain = self.examples_from_briefs_prompt | json_model | json_parse
@@ -137,14 +310,18 @@ class TaskDescriptionGenerator:

         self.chain = (
             self.input_loader
-            | RunnablePassthrough.assign(raw_example = lambda x: json.dumps(x["example"], ensure_ascii=False))
-            | RunnablePassthrough.assign(description = self.description_chain)
+            | RunnablePassthrough.assign(raw_example=lambda x: json.dumps(x["example"], ensure_ascii=False))
+            | RunnablePassthrough.assign(description=self.description_chain)
             | {
                 "description": lambda x: x["description"],
-                "examples_from_briefs": RunnablePassthrough.assign(input_analysis = self.input_analysis_chain)
-                | RunnablePassthrough.assign(new_example_briefs = self.briefs_chain)
-                | RunnablePassthrough.assign(examples = self.examples_from_briefs_chain | (lambda x: x["examples"])),
-                "examples_directly": self.examples_directly_chain
+                "examples_from_briefs": RunnablePassthrough.assign(input_analysis=self.input_analysis_chain)
+                | RunnablePassthrough.assign(new_example_briefs=self.briefs_chain)
+                | RunnablePassthrough.assign(examples=self.examples_from_briefs_chain | (lambda x: x["examples"])),
+                "examples_directly": self.examples_directly_chain,
+                "suggestions": {
+                    "specification": self.specification_suggestions_chain,
+                    "generalization": self.generalization_suggestions_chain
+                } | RunnableLambda(lambda x: [item['suggestion'] for sublist in [v['suggestions'] for v in x.values()] for item in sublist])
             }
             | RunnablePassthrough.assign(
                 additional_examples=lambda x: (
@@ -154,29 +331,35 @@ class TaskDescriptionGenerator:
             )
         )

-    def load_and_validate_input(self, input_dict):
-        input_str = input_dict["input_str"]
-        generating_batch_size = input_dict["generating_batch_size"]
-
+    def parse_input_str(self, input_str):
         try:
+            example_dict = json.loads(input_str)
+        except ValueError:
             try:
-                example_dict = json.loads(input_str)
-            except ValueError:
-                try:
-                    example_dict = yaml.safe_load(input_str)
-                except yaml.YAMLError as e:
-                    raise ValueError("Invalid input format. Expected a JSON or YAML object.") from e
-
-            # If example_dict is a list, filter out invalid items
-            if isinstance(example_dict, list):
-                example_dict = [item for item in example_dict if isinstance(item, dict) and 'input' in item and 'output' in item]
+                example_dict = yaml.safe_load(input_str)
+            except yaml.YAMLError as e:
+                raise ValueError("Invalid input format. Expected a JSON or YAML object.") from e
+
+        # If example_dict is a list, filter out invalid items
+        if isinstance(example_dict, list):
+            example_dict = [item for item in example_dict if isinstance(item, dict) and 'input' in item and 'output' in item]
+
+        # If example_dict is not a list, check if it's a valid dict
+        elif not isinstance(example_dict, dict) or 'input' not in example_dict or 'output' not in example_dict:
+            raise ValueError("Invalid input format. Expected an object with 'input' and 'output' fields.")
+
+        return example_dict

-            # If example_dict is not a list, check if it's a valid dict
-            elif not isinstance(example_dict, dict) or 'input' not in example_dict or 'output' not in example_dict:
-                raise ValueError("Invalid input format. Expected an object with 'input' and 'output' fields.")
+    def load_and_validate_input(self, input_dict):
+        input_str = input_dict["input_str"]
+        generating_batch_size = input_dict.get("generating_batch_size")

+        try:
+            example_dict = self.parse_input_str(input_str)
             # Move the original content to a key named 'example'
-            input_dict = {"example": example_dict, "generating_batch_size": generating_batch_size}
+            input_dict = {"example": example_dict}
+            if generating_batch_size is not None:
+                input_dict["generating_batch_size"] = generating_batch_size

             return input_dict

@@ -191,13 +374,43 @@ class TaskDescriptionGenerator:
     def generate_description(self, input_str, generating_batch_size=3):
         chain = (
             self.input_loader
-            | RunnablePassthrough.assign(raw_example = lambda x: json.dumps(x["example"], ensure_ascii=False))
-            | self.description_chain
+            | RunnablePassthrough.assign(raw_example=lambda x: json.dumps(x["example"], ensure_ascii=False))
+            | RunnablePassthrough.assign(description=self.description_chain)
+            | {
+                "description": lambda x: x["description"],
+                "suggestions": {
+                    "specification": self.specification_suggestions_chain,
+                    "generalization": self.generalization_suggestions_chain
+                } | RunnableLambda(lambda x: [item['suggestion'] for sublist in [v['suggestions'] for v in x.values()] for item in sublist])
+            }
         )
         return chain.invoke({
             "input_str": input_str,
             "generating_batch_size": generating_batch_size
         })
+
+    def update_description(self, input_str, description, suggestions):
+        # package array suggestions into a JSON array
+        suggestions_str = json.dumps(suggestions, ensure_ascii=False)
+
+        # return the updated description with new suggestions
+        chain = (
+            RunnablePassthrough.assign(
+                description=self.description_updating_chain
+            )
+            | {
+                "description": lambda x: x["description"],
+                "suggestions": {
+                    "specification": self.specification_suggestions_chain,
+                    "generalization": self.generalization_suggestions_chain
+                } | RunnableLambda(lambda x: [item['suggestion'] for sublist in [v['suggestions'] for v in x.values()] for item in sublist])
+            }
+        )
+        return chain.invoke({
+            "raw_example": input_str,
+            "description": description,
+            "suggestions": suggestions_str
+        })

     def analyze_input(self, description):
         return self.input_analysis_chain.invoke(description)
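The `suggestions` branch added to the main chain (invoked by `process`) and to `generate_description` and `update_description` fans out to the specification and generalization chains, then flattens their parsed JSON into a single list of strings with the `RunnableLambda` shown above. A self-contained sketch of just that flattening step, with hand-written dicts standing in for the chains' parsed output (no model call involved; the suggestion strings are illustrative placeholders):

```python
# The flattening step from the "suggestions" branch, isolated for clarity.
# The input dict mimics the parsed JSON produced by the two suggestion chains.
flatten = lambda x: [item['suggestion']
                     for sublist in [v['suggestions'] for v in x.values()]
                     for item in sublist]

chain_output = {
    "specification": {"suggestions": [{"suggestion": "Limit the scope of supported tasks to arithmetic."}]},
    "generalization": {"suggestions": [{"suggestion": "Expand the scope of support to multi-step word problems."}]},
}

print(flatten(chain_output))
# ['Limit the scope of supported tasks to arithmetic.',
#  'Expand the scope of support to multi-step word problems.']
```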