santhosh97 commited on
Commit
781b338
·
1 Parent(s): ada7876

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -4
app.py CHANGED
@@ -48,6 +48,17 @@ class UxState(str, Enum):
48
  # Command-line arguments to control some stuff for easier local testing.
49
  # Eventually may want to move everything into functions and have a
50
  # if __name__ == "main" setup instead of everything inline.
 
 
 
 
 
 
 
 
 
 
 
51
  def setup_session_state():
52
  if "key" not in st.session_state:
53
  st.session_state["key"] = uuid.uuid4().hex
@@ -190,10 +201,22 @@ def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_
190
  return s3_path
191
 
192
  def train_model(model_inputs):
193
- api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
194
- model_key = "bd2c55f5-84bb-40f9-82fb-196ca68b1c1d"
195
- st.markdown(str(model_inputs))
196
- _ = banana.run(api_key, model_key, model_inputs)
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  def switch_ux_state(new_state: UxState):
199
  st.session_state['ux_state'] = new_state
 
48
  # Command-line arguments to control some stuff for easier local testing.
49
  # Eventually may want to move everything into functions and have a
50
  # if __name__ == "main" setup instead of everything inline.
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument(
53
+ "--dry-run", action="store_true",
54
+ help="Skip sending train request to backend server.",
55
+ )
56
+ parser.add_argument(
57
+ "--train-endpoint-url", default=None,
58
+ help="URL of backend server to send train request to. If None, use hardcoded banana setup.",
59
+ )
60
+ cli_args = parser.parse_args()
61
+
62
  def setup_session_state():
63
  if "key" not in st.session_state:
64
  st.session_state["key"] = uuid.uuid4().hex
 
201
  return s3_path
202
 
203
  def train_model(model_inputs):
204
+ if cli_args.dry_run:
205
+ logger.info("Skipping model training since --dry-run is enabled.")
206
+ logger.info(f"model_inputs: {model_inputs}")
207
+ return
208
+
209
+ if cli_args.train_endpoint_url is None:
210
+ # Use banana backend
211
+ api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
212
+ model_key = "1a3b4ce5-164f-4efb-9f4a-c2ad3a930d0b"
213
+ st.markdown(str(model_inputs))
214
+ print(model_inputs)
215
+ _ = banana.run(api_key, model_key, model_inputs)
216
+ else:
217
+ # Send request directly to specified url
218
+ _ = requests.post(cli_args.train_endpoint_url, json=model_inputs)
219
+
220
 
221
  def switch_ux_state(new_state: UxState):
222
  st.session_state['ux_state'] = new_state