Bonosa2 commited on
Commit
d23d60c
·
verified ·
1 Parent(s): 87e16aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -24
app.py CHANGED
@@ -2,23 +2,19 @@
2
 
3
  import os
4
  import pandas as pd
5
- import gradio as gr
6
  import torch
7
- from transformers import (
8
- AutoProcessor,
9
- AutoTokenizer,
10
- AutoModelForImageTextToText
11
- )
12
  from sklearn.model_selection import train_test_split
13
 
14
- # 1) Retrieve your HF_TOKEN from environment (set in Space Settings → Secrets)
15
  HF_TOKEN = os.environ.get("HF_TOKEN")
16
  if not HF_TOKEN:
17
- raise RuntimeError("Missing HF_TOKEN env var! Please add it in your Space settings → Secrets.")
18
 
19
  MODEL_ID = "google/gemma-3n-e2b-it"
20
 
21
- # 2) Eagerly load the small bits (processor & tokenizer) so the UI starts fast
22
  processor = AutoProcessor.from_pretrained(
23
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
24
  )
@@ -26,9 +22,12 @@ tokenizer = AutoTokenizer.from_pretrained(
26
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
27
  )
28
 
29
- def generate_all_and_split():
30
- """Called when the user clicks the button—loads full model, generates & saves files."""
31
- # a) Lazy‑load the 8‑bit quantized model (heavy)
 
 
 
32
  model = AutoModelForImageTextToText.from_pretrained(
33
  MODEL_ID,
34
  trust_remote_code=True,
@@ -39,6 +38,7 @@ def generate_all_and_split():
39
  device = next(model.parameters()).device
40
 
41
  def to_soap(text: str) -> str:
 
42
  inputs = processor.apply_chat_template(
43
  [
44
  {"role":"system","content":[{"type":"text","text":"You are a medical AI assistant."}]},
@@ -57,10 +57,11 @@ def generate_all_and_split():
57
  temperature=0.1,
58
  pad_token_id=processor.tokenizer.eos_token_id
59
  )
 
60
  prompt_len = inputs["input_ids"].shape[-1]
61
  return processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0].strip()
62
 
63
- # b) Generate 100 doc_notes + ground_truth SOAPs
64
  docs, gts = [], []
65
  for i in range(1, 101):
66
  doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
@@ -69,13 +70,14 @@ def generate_all_and_split():
69
  if i % 20 == 0:
70
  torch.cuda.empty_cache()
71
 
72
- # c) Split 70/30
73
  df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
74
  train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)
75
 
 
76
  os.makedirs("outputs", exist_ok=True)
77
 
78
- # d) Inference on train → inference.tsv
79
  train_preds = [to_soap(d) for d in train_df["doc_note"]]
80
  inf = train_df.reset_index(drop=True).copy()
81
  inf["id"] = inf.index + 1
@@ -84,25 +86,33 @@ def generate_all_and_split():
84
  "outputs/inference.tsv", sep="\t", index=False
85
  )
86
 
87
- # e) Inference on test → eval.csv
88
  test_preds = [to_soap(d) for d in test_df["doc_note"]]
89
  pd.DataFrame({
90
  "id": range(1, len(test_preds)+1),
91
  "predicted_soap": test_preds
92
  }).to_csv("outputs/eval.csv", index=False)
93
 
 
94
  return (
95
- "✅ Done!\n"
96
- f"outputs/inference.tsv (70 rows with id, GT, pred)\n"
97
- f"outputs/eval.csv (30 rows with id, pred)"
98
  )
99
 
100
- # 3) Gradio UI—instant startup
101
  with gr.Blocks() as demo:
102
- gr.Markdown("## Gemma‑3n SOAP Generator 🩺")
103
- btn = gr.Button("Generate & Save 100 Notes → 70/30 Split → inference & eval")
104
- status = gr.Textbox(interactive=False, label="Status")
105
- btn.click(fn=generate_all_and_split, inputs=None, outputs=status)
 
 
 
 
 
 
 
106
 
107
  if __name__ == "__main__":
108
  demo.launch()
 
2
 
3
  import os
4
  import pandas as pd
 
5
  import torch
6
+ import gradio as gr
7
+ from transformers import AutoProcessor, AutoTokenizer, AutoModelForImageTextToText
 
 
 
8
  from sklearn.model_selection import train_test_split
9
 
10
+ # 1) Configuration
11
  HF_TOKEN = os.environ.get("HF_TOKEN")
12
  if not HF_TOKEN:
13
+ raise RuntimeError("Missing HF_TOKEN in env vars please add it under Settings → Secrets")
14
 
15
  MODEL_ID = "google/gemma-3n-e2b-it"
16
 
17
+ # 2) Eagerly load only the processor & tokenizer (fast startup)
18
  processor = AutoProcessor.from_pretrained(
19
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
20
  )
 
22
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
23
  )
24
 
25
+ def generate_and_export():
26
+ """
27
+ On button click: lazily load the 8‑bit model, generate 100 doc→SOAP pairs,
28
+ split 70/30, run inference & eval, write files, and return download links.
29
+ """
30
+ # a) Load full model in 8‑bit
31
  model = AutoModelForImageTextToText.from_pretrained(
32
  MODEL_ID,
33
  trust_remote_code=True,
 
38
  device = next(model.parameters()).device
39
 
40
  def to_soap(text: str) -> str:
41
+ # wrap the chat‐template + generate call
42
  inputs = processor.apply_chat_template(
43
  [
44
  {"role":"system","content":[{"type":"text","text":"You are a medical AI assistant."}]},
 
57
  temperature=0.1,
58
  pad_token_id=processor.tokenizer.eos_token_id
59
  )
60
+ # strip off prompt tokens
61
  prompt_len = inputs["input_ids"].shape[-1]
62
  return processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0].strip()
63
 
64
+ # b) Generate 100 synthetic doc_notes & ground_truth SOAPs
65
  docs, gts = [], []
66
  for i in range(1, 101):
67
  doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
 
70
  if i % 20 == 0:
71
  torch.cuda.empty_cache()
72
 
73
+ # c) 70/30 split
74
  df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
75
  train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)
76
 
77
+ # ensure outputs dir exists
78
  os.makedirs("outputs", exist_ok=True)
79
 
80
+ # d) Inference on train split outputs/inference.tsv
81
  train_preds = [to_soap(d) for d in train_df["doc_note"]]
82
  inf = train_df.reset_index(drop=True).copy()
83
  inf["id"] = inf.index + 1
 
86
  "outputs/inference.tsv", sep="\t", index=False
87
  )
88
 
89
+ # e) Inference on test split outputs/eval.csv
90
  test_preds = [to_soap(d) for d in test_df["doc_note"]]
91
  pd.DataFrame({
92
  "id": range(1, len(test_preds)+1),
93
  "predicted_soap": test_preds
94
  }).to_csv("outputs/eval.csv", index=False)
95
 
96
+ # return status + file paths for download
97
  return (
98
+ "✅ Generation complete!",
99
+ "outputs/inference.tsv",
100
+ "outputs/eval.csv"
101
  )
102
 
103
+ # 3) Gradio UI
104
  with gr.Blocks() as demo:
105
+ gr.Markdown("# Gemma‑3n SOAP Generator 🩺")
106
+ generate_btn = gr.Button("Generate & Export 100 Notes")
107
+ status = gr.Textbox(interactive=False, label="Status")
108
+ inf_file = gr.File(label="Download inference.tsv")
109
+ eval_file = gr.File(label="Download eval.csv")
110
+
111
+ generate_btn.click(
112
+ fn=generate_and_export,
113
+ inputs=None,
114
+ outputs=[status, inf_file, eval_file]
115
+ )
116
 
117
  if __name__ == "__main__":
118
  demo.launch()