yaleh commited on
Commit
336eca2
·
1 Parent(s): 40e1bf4

Add with_retry to JSON chains.

Browse files
app/streamlit_sample_generator.py CHANGED
@@ -174,9 +174,6 @@ if 'examples_dataframe' not in st.session_state:
174
  if 'selected_example' not in st.session_state:
175
  st.session_state.selected_example = None
176
 
177
- # if 'input_file' not in st.session_state:
178
- # st.session_state.input_file = None
179
-
180
 
181
  def update_description_output_text():
182
  input_json = package_input_data()
 
174
  if 'selected_example' not in st.session_state:
175
  st.session_state.selected_example = None
176
 
 
 
 
177
 
178
  def update_description_output_text():
179
  input_json = package_input_data()
meta_prompt/sample_generator.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
  import yaml
3
  from langchain.prompts import ChatPromptTemplate
4
  from langchain.schema.output_parser import StrOutputParser
@@ -261,12 +262,28 @@ class TaskDescriptionGenerator:
261
 
262
  self.description_chain = self.description_prompt | model | output_parser
263
  self.description_updating_chain = self.description_updating_prompt | model | output_parser
264
- self.specification_suggestions_chain = (self.specification_suggestions_prompt | json_model | json_parse).with_fallbacks([RunnableLambda(lambda x: {"dimensions": [], "suggestions": []})])
265
- self.generalization_suggestions_chain = (self.generalization_suggestions_prompt | json_model | json_parse).with_fallbacks([RunnableLambda(lambda x: {"dimensions": [], "suggestions": []})])
 
 
 
 
 
 
 
 
266
  self.input_analysis_chain = self.input_analysis_prompt | model | output_parser
267
  self.briefs_chain = self.briefs_prompt | model | output_parser
268
- self.examples_from_briefs_chain = (self.examples_from_briefs_prompt | json_model | json_parse).with_fallbacks([RunnableLambda(lambda x: {"examples": []})])
269
- self.examples_directly_chain = (self.examples_directly_prompt | json_model | json_parse).with_fallbacks([RunnableLambda(lambda x: {"examples": []})])
 
 
 
 
 
 
 
 
270
 
271
  # New sub-chain for loading and validating input
272
  self.input_loader = RunnableLambda(self.load_and_validate_input)
 
1
  import json
2
+ from openai import BadRequestError
3
  import yaml
4
  from langchain.prompts import ChatPromptTemplate
5
  from langchain.schema.output_parser import StrOutputParser
 
262
 
263
  self.description_chain = self.description_prompt | model | output_parser
264
  self.description_updating_chain = self.description_updating_prompt | model | output_parser
265
+ self.specification_suggestions_chain = (self.specification_suggestions_prompt | json_model | json_parse).with_retry(
266
+ retry_if_exception_type=(BadRequestError,), # Retry only on ValueError
267
+ wait_exponential_jitter=True, # Add jitter to the exponential backoff
268
+ stop_after_attempt=2 # Try twice
269
+ ).with_fallbacks([RunnableLambda(lambda x: {"dimensions": [], "suggestions": []})])
270
+ self.generalization_suggestions_chain = (self.generalization_suggestions_prompt | json_model | json_parse).with_retry(
271
+ retry_if_exception_type=(BadRequestError,), # Retry only on ValueError
272
+ wait_exponential_jitter=True, # Add jitter to the exponential backoff
273
+ stop_after_attempt=2 # Try twice
274
+ ).with_fallbacks([RunnableLambda(lambda x: {"dimensions": [], "suggestions": []})])
275
  self.input_analysis_chain = self.input_analysis_prompt | model | output_parser
276
  self.briefs_chain = self.briefs_prompt | model | output_parser
277
+ self.examples_from_briefs_chain = (self.examples_from_briefs_prompt | json_model | json_parse).with_retry(
278
+ retry_if_exception_type=(BadRequestError,), # Retry only on ValueError
279
+ wait_exponential_jitter=True, # Add jitter to the exponential backoff
280
+ stop_after_attempt=2 # Try twice
281
+ ).with_fallbacks([RunnableLambda(lambda x: {"examples": []})])
282
+ self.examples_directly_chain = (self.examples_directly_prompt | json_model | json_parse).with_retry(
283
+ retry_if_exception_type=(BadRequestError,), # Retry only on ValueError
284
+ wait_exponential_jitter=True, # Add jitter to the exponential backoff
285
+ stop_after_attempt=2 # Try twice
286
+ ).with_fallbacks([RunnableLambda(lambda x: {"examples": []})])
287
 
288
  # New sub-chain for loading and validating input
289
  self.input_loader = RunnableLambda(self.load_and_validate_input)