Bonosa2 committed on
Commit
7fcc7f1
·
verified ·
1 Parent(s): 50830c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -20
app.py CHANGED
@@ -13,7 +13,7 @@ from transformers import (
13
  )
14
  from sklearn.model_selection import train_test_split
15
 
16
- # ─── Silence unrecognized‐flag warnings ────────────────────────────────────────
17
  logging.set_verbosity_error()
18
 
19
  # ─── Configuration ────────────────────────────────────────────────────────────
@@ -22,7 +22,7 @@ if not HF_TOKEN:
22
  raise RuntimeError("Missing HF_TOKEN in env vars – set it under Space Settings β†’ Secrets")
23
  MODEL_ID = "google/gemma-3n-e2b-it"
24
 
25
- # ─── Fast startup: load only the small pieces ──────────────────────────────────
26
  processor = AutoProcessor.from_pretrained(
27
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
28
  )
@@ -30,10 +30,10 @@ tokenizer = AutoTokenizer.from_pretrained(
30
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
31
  )
32
 
33
- # ─── Heavy work deferred until button click ───────────────────────────────────
34
  def generate_and_export():
35
  try:
36
- # 1) Lazy-load the full FP16 model (heavy)
37
  model = AutoModelForImageTextToText.from_pretrained(
38
  MODEL_ID,
39
  trust_remote_code=True,
@@ -43,7 +43,7 @@ def generate_and_export():
43
  )
44
  device = next(model.parameters()).device
45
 
46
- # 2) Helper to generate a SOAP note from arbitrary text
47
  def to_soap(text: str) -> str:
48
  inputs = processor.apply_chat_template(
49
  [
@@ -62,30 +62,29 @@ def generate_and_export():
62
  top_p=0.95,
63
  temperature=0.1,
64
  pad_token_id=processor.tokenizer.eos_token_id,
65
- use_cache=False # disable HybridCache
66
  )
67
  prompt_len = inputs["input_ids"].shape[-1]
68
  return processor.batch_decode(
69
  out[:, prompt_len:], skip_special_tokens=True
70
  )[0].strip()
71
 
72
- # 3) Generate 100 synthetic doc notes & ground-truth SOAPs
73
  docs, gts = [], []
74
- for i in range(1, 101):
75
  doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
76
  docs.append(doc)
77
  gts.append(to_soap(doc))
78
- if i % 20 == 0:
79
  torch.cuda.empty_cache()
80
 
81
- # 4) 70/30 train-test split
82
  df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
83
- train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)
84
 
85
- # ensure outputs folder
86
  os.makedirs("outputs", exist_ok=True)
87
 
88
- # 5) Run inference on train split → inference.tsv
89
  train_preds = [to_soap(d) for d in train_df["doc_note"]]
90
  inf = train_df.reset_index(drop=True).copy()
91
  inf["id"] = inf.index + 1
@@ -94,30 +93,28 @@ def generate_and_export():
94
  "outputs/inference.tsv", sep="\t", index=False
95
  )
96
 
97
- # 6) Run inference on test split → eval.csv
98
  test_preds = [to_soap(d) for d in test_df["doc_note"]]
99
  pd.DataFrame({
100
- "id": range(1, len(test_preds)+1),
101
  "predicted_soap": test_preds
102
  }).to_csv("outputs/eval.csv", index=False)
103
 
104
- # 7) Success: return status + file paths
105
  return (
106
- "βœ… Generation complete! Download below ‡",
107
  "outputs/inference.tsv",
108
  "outputs/eval.csv"
109
  )
110
 
111
  except Exception as e:
112
- # Print full traceback to the Space logs
113
  traceback.print_exc()
114
- # Return the error message to the UI
115
  return (f"❌ Error: {e}", None, None)
116
 
117
  # ─── Gradio UI ─────────────────────────────────────────────────────────────────
118
  with gr.Blocks() as demo:
119
  gr.Markdown("# Gemma‑3n SOAP Generator 🩺")
120
- btn = gr.Button("Generate & Export 100 Notes")
121
  status = gr.Textbox(interactive=False, label="Status")
122
  inf_file = gr.File(label="Download inference.tsv")
123
  eval_file= gr.File(label="Download eval.csv")
 
13
  )
14
  from sklearn.model_selection import train_test_split
15
 
16
+ # ─── Silence irrelevant warnings ───────────────────────────────────────────────
17
  logging.set_verbosity_error()
18
 
19
  # ─── Configuration ────────────────────────────────────────────────────────────
 
22
  raise RuntimeError("Missing HF_TOKEN in env vars – set it under Space Settings β†’ Secrets")
23
  MODEL_ID = "google/gemma-3n-e2b-it"
24
 
25
+ # ─── Fast startup: load only processor & tokenizer ─────────────────────────────
26
  processor = AutoProcessor.from_pretrained(
27
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
28
  )
 
30
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
31
  )
32
 
33
+ # ─── Heavy work runs on button click ───────────────────────────────────────────
34
  def generate_and_export():
35
  try:
36
+ # 1) Lazy‑load the full FP16 model
37
  model = AutoModelForImageTextToText.from_pretrained(
38
  MODEL_ID,
39
  trust_remote_code=True,
 
43
  )
44
  device = next(model.parameters()).device
45
 
46
+ # 2) Text→SOAP helper
47
  def to_soap(text: str) -> str:
48
  inputs = processor.apply_chat_template(
49
  [
 
62
  top_p=0.95,
63
  temperature=0.1,
64
  pad_token_id=processor.tokenizer.eos_token_id,
65
+ use_cache=False
66
  )
67
  prompt_len = inputs["input_ids"].shape[-1]
68
  return processor.batch_decode(
69
  out[:, prompt_len:], skip_special_tokens=True
70
  )[0].strip()
71
 
72
+ # 3) Generate 20 doc notes + ground truths
73
  docs, gts = [], []
74
+ for i in range(1, 21):
75
  doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
76
  docs.append(doc)
77
  gts.append(to_soap(doc))
78
+ if i % 5 == 0:
79
  torch.cuda.empty_cache()
80
 
81
+ # 4) Split into 15 train / 5 test
82
  df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
83
+ train_df, test_df = train_test_split(df, test_size=5, random_state=42)
84
 
 
85
  os.makedirs("outputs", exist_ok=True)
86
 
87
+ # 5) Inference on train split → outputs/inference.tsv
88
  train_preds = [to_soap(d) for d in train_df["doc_note"]]
89
  inf = train_df.reset_index(drop=True).copy()
90
  inf["id"] = inf.index + 1
 
93
  "outputs/inference.tsv", sep="\t", index=False
94
  )
95
 
96
+ # 6) Inference on test split → outputs/eval.csv
97
  test_preds = [to_soap(d) for d in test_df["doc_note"]]
98
  pd.DataFrame({
99
+ "id": range(1, len(test_preds) + 1),
100
  "predicted_soap": test_preds
101
  }).to_csv("outputs/eval.csv", index=False)
102
 
103
+ # 7) Return status + file paths for download
104
  return (
105
+ "βœ… Done with 20 notes (15 train / 5 test)!",
106
  "outputs/inference.tsv",
107
  "outputs/eval.csv"
108
  )
109
 
110
  except Exception as e:
 
111
  traceback.print_exc()
 
112
  return (f"❌ Error: {e}", None, None)
113
 
114
  # ─── Gradio UI ─────────────────────────────────────────────────────────────────
115
  with gr.Blocks() as demo:
116
  gr.Markdown("# Gemma‑3n SOAP Generator 🩺")
117
+ btn = gr.Button("Generate & Export 20 Notes")
118
  status = gr.Textbox(interactive=False, label="Status")
119
  inf_file = gr.File(label="Download inference.tsv")
120
  eval_file= gr.File(label="Download eval.csv")