sagawa committed
Commit 64b6831 · verified · 1 Parent(s): fa55e0e

Update app.py

Files changed (1):
  app.py  +84 -62
app.py CHANGED
@@ -1,6 +1,5 @@
 import gc
 import os
-import sys
 import warnings
 from types import SimpleNamespace
 
@@ -14,6 +13,7 @@ from generation_utils import (
     decode_output,
     save_multiple_predictions,
 )
+from models import ReactionT5Yield2
 from torch.utils.data import DataLoader
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from utils import seed_everything
@@ -111,6 +111,8 @@ with st.sidebar:
     model_options = ["sagawa/ReactionT5v2-yield"]  # default as requested
     model_help = "Default model for yield prediction."
     input_max_length_default = 400
+    from task_yield.train import preprocess_df
+    from task_yield.prediction import inference_fn
 
     model_name_or_path = st.selectbox(
         "Model",
@@ -118,15 +120,15 @@ with st.sidebar:
         index=0,
         help=model_help,
     )
-
-    num_beams = st.slider(
-        "Beam size",
-        min_value=1,
-        max_value=10,
-        value=5,
-        step=1,
-        help="Number of beams for beam search.",
-    )
+    if task != "yield prediction":
+        num_beams = st.slider(
+            "Beam size",
+            min_value=1,
+            max_value=10,
+            value=5,
+            step=1,
+            help="Number of beams for beam search.",
+        )
 
     seed = st.number_input(
         "Random seed",
@@ -187,9 +189,12 @@ def load_tokenizer(model_ref: str):
 
 
 @st.cache_resource(show_spinner=True)
-def load_model(model_ref: str, device_str: str):
+def load_model(model_ref: str, device_str: str, task: str):
     resolved = os.path.abspath(model_ref) if os.path.exists(model_ref) else model_ref
-    model = AutoModelForSeq2SeqLM.from_pretrained(resolved)
+    if task != "yield prediction":
+        model = AutoModelForSeq2SeqLM.from_pretrained(resolved)
+    else:
+        model = ReactionT5Yield2.from_pretrained(resolved)
     model.to(torch.device(device_str))
     model.eval()
     return model
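
st.cache_resource keys its cache on the function's argument values, so threading task through the signature means a seq2seq model cached for one task is never handed back to the yield path. A self-contained sketch of that keying behavior (names are illustrative):

import streamlit as st

@st.cache_resource
def load_resource(name: str, task: str):
    # Each distinct (name, task) pair is loaded once; repeat calls with
    # the same arguments return the cached object without re-running this body.
    print(f"loading {name} for task {task!r}")  # executes only on a cache miss
    return {"name": name, "task": task}  # stand-in for a real model object
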
@@ -253,14 +258,22 @@ if run:
     else:
         # Build config object expected by your dataset/utils
        CFG = SimpleNamespace(
-            num_beams=int(num_beams),
-            num_return_sequences=int(num_beams),  # tie to beams by default
+            task=task,
+            num_beams=int(num_beams) if task != "yield prediction" else None,
+            num_return_sequences=int(num_beams)
+            if task != "yield prediction"
+            else None,  # tie to beams by default
             model_name_or_path=model_name_or_path,
             input_column="input",
-            input_max_length=int(input_max_length),
-            output_max_length=int(output_max_length),
-            output_min_length=int(output_min_length),
-            model="t5",
+            input_max_length=int(input_max_length)
+            if task != "yield prediction"
+            else None,
+            output_max_length=int(output_max_length)
+            if task != "yield prediction"
+            else None,
+            output_min_length=int(output_min_length)
+            if task != "yield prediction"
+            else None,
             seed=int(seed),
             batch_size=int(batch_size),
         )
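
On the yield path the conditionals above all collapse to None, so the config that reaches the yield branch is effectively the following (the seed and batch-size values are illustrative):

from types import SimpleNamespace

CFG = SimpleNamespace(
    task="yield prediction",
    num_beams=None,  # generation-only settings are disabled
    num_return_sequences=None,
    model_name_or_path="sagawa/ReactionT5v2-yield",
    input_column="input",
    input_max_length=None,
    output_max_length=None,
    output_min_length=None,
    seed=42,  # illustrative
    batch_size=16,  # illustrative
)
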
@@ -272,7 +285,7 @@ if run:
     try:
         tokenizer = load_tokenizer(CFG.model_name_or_path)
         CFG.tokenizer = tokenizer
-        model = load_model(CFG.model_name_or_path, device.type)
+        model = load_model(CFG.model_name_or_path, device.type, task)
         status.update(label="Model ready.", state="complete")
     except Exception as e:
         st.session_state["last_error"] = f"Failed to load model: {e}"
@@ -296,51 +309,60 @@ if run:
             drop_last=False,
         )
 
-        # Generation loop with progress
-        all_sequences, all_scores = [], []
-        total = len(dataloader)
-        progress = st.progress(0, text="Generating predictions...")
-        info_placeholder = st.empty()
-
-        for i, inputs in enumerate(dataloader, start=1):
-            inputs = {k: v.to(device) for k, v in inputs.items()}
-            with torch.no_grad():
-                output = model.generate(
-                    **inputs,
-                    min_length=CFG.output_min_length,
-                    max_length=CFG.output_max_length,
-                    num_beams=CFG.num_beams,
-                    num_return_sequences=CFG.num_return_sequences,
-                    return_dict_in_generate=True,
-                    output_scores=True,
-                )
-            sequences, scores = decode_output(output, CFG)
-            all_sequences.extend(sequences)
-            if scores:
-                all_scores.extend(scores)
-
-            del output
-            if device.type == "cuda":
-                torch.cuda.empty_cache()
-            gc.collect()
-
-            progress.progress(i / total, text=f"Generating predictions... {i}/{total}")
-            info_placeholder.caption(f"Processed batch {i} of {total}")
-
-        progress.empty()
-        info_placeholder.empty()
-
-        # Save predictions
-        try:
-            output_df = save_multiple_predictions(
-                input_df, all_sequences, all_scores, CFG
-            )
+        if task == "yield prediction":
+            # Use custom inference function for yield prediction
+            prediction = inference_fn(dataloader, model, CFG)
+            output_df = input_df.copy()
+            output_df["prediction"] = prediction
+            output_df["prediction"] = output_df["prediction"].clip(lower=0.0, upper=100.0)
             st.session_state["results_df"] = output_df
             st.success("Prediction complete.")
-        except Exception as e:
-            st.session_state["last_error"] = f"Failed to assemble output: {e}"
-            st.error(st.session_state["last_error"])
-            st.stop()
+        else:
+            # Generation loop with progress
+            all_sequences, all_scores = [], []
+            total = len(dataloader)
+            progress = st.progress(0, text="Generating predictions...")
+            info_placeholder = st.empty()
+
+            for i, inputs in enumerate(dataloader, start=1):
+                inputs = {k: v.to(device) for k, v in inputs.items()}
+                with torch.no_grad():
+                    output = model.generate(
+                        **inputs,
+                        min_length=CFG.output_min_length,
+                        max_length=CFG.output_max_length,
+                        num_beams=CFG.num_beams,
+                        num_return_sequences=CFG.num_return_sequences,
+                        return_dict_in_generate=True,
+                        output_scores=True,
+                    )
+                sequences, scores = decode_output(output, CFG)
+                all_sequences.extend(sequences)
+                if scores:
+                    all_scores.extend(scores)
+
+                del output
+                if device.type == "cuda":
+                    torch.cuda.empty_cache()
+                gc.collect()
+
+                progress.progress(i / total, text=f"Generating predictions... {i}/{total}")
+                info_placeholder.caption(f"Processed batch {i} of {total}")
+
+            progress.empty()
+            info_placeholder.empty()
+
+            # Save predictions
+            try:
+                output_df = save_multiple_predictions(
+                    input_df, all_sequences, all_scores, CFG
+                )
+                st.session_state["results_df"] = output_df
+                st.success("Prediction complete.")
+            except Exception as e:
+                st.session_state["last_error"] = f"Failed to assemble output: {e}"
+                st.error(st.session_state["last_error"])
+                st.stop()
 
 # ------------------------------
 # Results
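
The yield branch delegates to inference_fn from task_yield.prediction, which this diff does not show. A minimal sketch of the contract the app appears to assume, namely batched regression returning one predicted yield per input row; the exact model call signature is an assumption:

import numpy as np
import torch

def inference_fn(dataloader, model, cfg):
    # Sketch only: run the regression model over every tokenized batch
    # and return a flat array with one predicted yield per input row.
    device = next(model.parameters()).device
    model.eval()
    preds = []
    with torch.no_grad():
        for inputs in dataloader:
            inputs = {k: v.to(device) for k, v in inputs.items()}
            output = model(inputs)  # assumed: ReactionT5Yield2 returns yield values
            preds.append(output.detach().cpu().numpy().reshape(-1))
    return np.concatenate(preds)

The caller then clips the result to [0, 100], consistent with a percent-yield target.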
 