Spaces:
Runtime error
Runtime error
Commit
·
781b338
1
Parent(s):
ada7876
Update app.py
Browse files
app.py
CHANGED
@@ -48,6 +48,17 @@ class UxState(str, Enum):
|
|
48 |
# Command-line arguments to control some stuff for easier local testing.
|
49 |
# Eventually may want to move everything into functions and have a
|
50 |
# if __name__ == "main" setup instead of everything inline.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def setup_session_state():
|
52 |
if "key" not in st.session_state:
|
53 |
st.session_state["key"] = uuid.uuid4().hex
|
@@ -190,10 +201,22 @@ def zip_and_upload_images(identifier: str, uploaded_files: List[BytesIO], image_
|
|
190 |
return s3_path
|
191 |
|
192 |
def train_model(model_inputs):
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
def switch_ux_state(new_state: UxState):
|
199 |
st.session_state['ux_state'] = new_state
|
|
|
48 |
# Command-line arguments to control some stuff for easier local testing.
|
49 |
# Eventually may want to move everything into functions and have a
|
50 |
# if __name__ == "main" setup instead of everything inline.
|
51 |
+
parser = argparse.ArgumentParser()
|
52 |
+
parser.add_argument(
|
53 |
+
"--dry-run", action="store_true",
|
54 |
+
help="Skip sending train request to backend server.",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--train-endpoint-url", default=None,
|
58 |
+
help="URL of backend server to send train request to. If None, use hardcoded banana setup.",
|
59 |
+
)
|
60 |
+
cli_args = parser.parse_args()
|
61 |
+
|
62 |
def setup_session_state():
|
63 |
if "key" not in st.session_state:
|
64 |
st.session_state["key"] = uuid.uuid4().hex
|
|
|
201 |
return s3_path
|
202 |
|
203 |
def train_model(model_inputs):
|
204 |
+
if cli_args.dry_run:
|
205 |
+
logger.info("Skipping model training since --dry-run is enabled.")
|
206 |
+
logger.info(f"model_inputs: {model_inputs}")
|
207 |
+
return
|
208 |
+
|
209 |
+
if cli_args.train_endpoint_url is None:
|
210 |
+
# Use banana backend
|
211 |
+
api_key = "03cdd72e-5c04-4207-bd6a-fd5712c1740e"
|
212 |
+
model_key = "1a3b4ce5-164f-4efb-9f4a-c2ad3a930d0b"
|
213 |
+
st.markdown(str(model_inputs))
|
214 |
+
print(model_inputs)
|
215 |
+
_ = banana.run(api_key, model_key, model_inputs)
|
216 |
+
else:
|
217 |
+
# Send request directly to specified url
|
218 |
+
_ = requests.post(cli_args.train_endpoint_url, json=model_inputs)
|
219 |
+
|
220 |
|
221 |
def switch_ux_state(new_state: UxState):
|
222 |
st.session_state['ux_state'] = new_state
|