Bonosa2 committed on
Commit
50830c3
Β·
verified Β·
1 Parent(s): 724fa34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -84
app.py CHANGED
@@ -1,23 +1,28 @@
 
 
1
  import os
 
2
  import pandas as pd
3
  import torch
4
  import gradio as gr
5
  from transformers import (
 
6
  AutoProcessor,
7
  AutoTokenizer,
8
  AutoModelForImageTextToText
9
  )
10
  from sklearn.model_selection import train_test_split
11
 
 
 
 
 
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
  if not HF_TOKEN:
14
- raise RuntimeError(
15
- "Missing HF_TOKEN in env vars – add it under Settings β†’ Secrets"
16
- )
17
-
18
  MODEL_ID = "google/gemma-3n-e2b-it"
19
 
20
- # Load processor & tokenizer at top level for fast startup
21
  processor = AutoProcessor.from_pretrained(
22
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
23
  )
@@ -25,93 +30,97 @@ tokenizer = AutoTokenizer.from_pretrained(
25
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
26
  )
27
 
 
28
def generate_and_export():
    """Button-click handler for the Space.

    Loads the full model, generates 100 synthetic doctor notes with
    ground-truth SOAP summaries, splits them 70/30, runs inference on
    both splits, writes ``outputs/inference.tsv`` and ``outputs/eval.csv``,
    and returns ``(status_message, inference_path, eval_path)`` for the
    Gradio outputs.
    """
    # The heavy checkpoint is loaded here, not at import time, so the
    # Space UI starts quickly.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        token=HF_TOKEN,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    device = next(model.parameters()).device

    def run_soap(prompt_text: str) -> str:
        """Chat-template the prompt, sample a completion, and return only
        the newly generated text (prompt tokens stripped)."""
        batch = processor.apply_chat_template(
            [
                {"role": "system", "content": [{"type": "text", "text": "You are a medical AI assistant."}]},
                {"role": "user", "content": [{"type": "text", "text": prompt_text}]},
            ],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True,
        ).to(device)
        generated = model.generate(
            **batch,
            max_new_tokens=400,
            do_sample=True,
            top_p=0.95,
            temperature=0.1,
            pad_token_id=processor.tokenizer.eos_token_id,
            use_cache=False,
        )
        n_prompt = batch["input_ids"].shape[-1]
        decoded = processor.batch_decode(
            generated[:, n_prompt:], skip_special_tokens=True
        )
        return decoded[0].strip()

    # Build 100 (note, ground-truth SOAP) pairs.
    note_texts, gold_soaps = [], []
    for idx in range(1, 101):
        note = run_soap(
            "Generate a realistic, concise doctor's progress note for a single patient encounter."
        )
        note_texts.append(note)
        gold_soaps.append(run_soap(note))
        # Free cached CUDA blocks periodically to keep memory flat.
        if idx % 20 == 0:
            torch.cuda.empty_cache()

    # 70/30 split with a fixed seed for reproducibility.
    frame = pd.DataFrame({"doc_note": note_texts, "ground_truth_soap": gold_soaps})
    train_part, test_part = train_test_split(frame, test_size=0.3, random_state=42)

    os.makedirs("outputs", exist_ok=True)

    # Train-split predictions -> outputs/inference.tsv
    train_out = train_part.reset_index(drop=True).copy()
    train_out["id"] = train_out.index + 1
    train_out["predicted_soap"] = [run_soap(text) for text in train_part["doc_note"]]
    train_out[["id", "ground_truth_soap", "predicted_soap"]].to_csv(
        "outputs/inference.tsv", sep="\t", index=False
    )

    # Test-split predictions -> outputs/eval.csv
    eval_rows = [run_soap(text) for text in test_part["doc_note"]]
    pd.DataFrame({
        "id": range(1, len(eval_rows) + 1),
        "predicted_soap": eval_rows,
    }).to_csv("outputs/eval.csv", index=False)

    return (
        "βœ… Done!",
        "outputs/inference.tsv",
        "outputs/eval.csv",
    )
 
 
 
 
 
 
 
107
 
108
- # Build Gradio interface (starts immediately)
109
  with gr.Blocks() as demo:
110
  gr.Markdown("# Gemma‑3n SOAP Generator 🩺")
111
- btn = gr.Button("Generate & Export 100 Notes")
112
- status = gr.Textbox(interactive=False, label="Status")
113
  inf_file = gr.File(label="Download inference.tsv")
114
- eval_file = gr.File(label="Download eval.csv")
115
 
116
  btn.click(
117
  fn=generate_and_export,
 
1
+ # app.py
2
+
3
  import os
4
+ import traceback
5
  import pandas as pd
6
  import torch
7
  import gradio as gr
8
  from transformers import (
9
+ logging,
10
  AutoProcessor,
11
  AutoTokenizer,
12
  AutoModelForImageTextToText
13
  )
14
  from sklearn.model_selection import train_test_split
15
 
16
+ # ─── Silence unrecognized‐flag warnings ────────────────────────────────────────
17
+ logging.set_verbosity_error()
18
+
19
+ # ─── Configuration ────────────────────────────────────────────────────────────
20
  HF_TOKEN = os.environ.get("HF_TOKEN")
21
  if not HF_TOKEN:
22
+ raise RuntimeError("Missing HF_TOKEN in env vars – set it under Space Settings β†’ Secrets")
 
 
 
23
  MODEL_ID = "google/gemma-3n-e2b-it"
24
 
25
+ # ─── Fast startup: load only the small pieces ──────────────────────────────────
26
  processor = AutoProcessor.from_pretrained(
27
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
28
  )
 
30
  MODEL_ID, trust_remote_code=True, token=HF_TOKEN
31
  )
32
 
33
+ # ─── Heavy work deferred until button click ───────────────────────────────────
34
def generate_and_export():
    """Button-click handler: generate 100 notes, split 70/30, run inference,
    and export the results.

    Returns a 3-tuple for the Gradio outputs:
        (status message, path to outputs/inference.tsv or None,
         path to outputs/eval.csv or None)
    """
    try:
        # 1) Lazy-load the full FP16 model ONCE and cache it on the function
        #    itself. The original re-loaded the multi-GB checkpoint on every
        #    click and never freed it, which is both slow and memory-hungry.
        model = getattr(generate_and_export, "_model", None)
        if model is None:
            model = AutoModelForImageTextToText.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                token=HF_TOKEN,
                torch_dtype=torch.float16,
                device_map="auto",
            )
            generate_and_export._model = model
        device = next(model.parameters()).device

        # 2) Helper to generate a SOAP note from arbitrary text.
        def to_soap(text: str) -> str:
            inputs = processor.apply_chat_template(
                [
                    {"role": "system", "content": [{"type": "text", "text": "You are a medical AI assistant."}]},
                    {"role": "user", "content": [{"type": "text", "text": text}]},
                ],
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt",
                return_dict=True,
            ).to(device)
            out = model.generate(
                **inputs,
                max_new_tokens=400,
                do_sample=True,
                top_p=0.95,
                temperature=0.1,
                pad_token_id=processor.tokenizer.eos_token_id,
                use_cache=False,  # disable HybridCache
            )
            # Strip the prompt tokens so only the newly generated text remains.
            prompt_len = inputs["input_ids"].shape[-1]
            return processor.batch_decode(
                out[:, prompt_len:], skip_special_tokens=True
            )[0].strip()

        # 3) Generate 100 synthetic doc notes & ground-truth SOAPs.
        docs, gts = [], []
        for i in range(1, 101):
            doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
            docs.append(doc)
            gts.append(to_soap(doc))
            # Periodically release cached CUDA blocks; guarded so the same
            # code also runs cleanly on a CPU-only Space.
            if i % 20 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        # 4) 70/30 train-test split (fixed seed for reproducibility).
        df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
        train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)

        # ensure outputs folder
        os.makedirs("outputs", exist_ok=True)

        # 5) Run inference on train split -> inference.tsv
        train_preds = [to_soap(d) for d in train_df["doc_note"]]
        inf = train_df.reset_index(drop=True).copy()
        inf["id"] = inf.index + 1
        inf["predicted_soap"] = train_preds
        inf[["id", "ground_truth_soap", "predicted_soap"]].to_csv(
            "outputs/inference.tsv", sep="\t", index=False
        )

        # 6) Run inference on test split -> eval.csv
        test_preds = [to_soap(d) for d in test_df["doc_note"]]
        pd.DataFrame({
            "id": range(1, len(test_preds) + 1),
            "predicted_soap": test_preds,
        }).to_csv("outputs/eval.csv", index=False)

        # 7) Success: return status + file paths for the gr.File outputs.
        return (
            "βœ… Generation complete! Download below ‡",
            "outputs/inference.tsv",
            "outputs/eval.csv",
        )

    except Exception as e:
        # Boundary handler: print the full traceback to the Space logs and
        # surface a short error message in the UI instead of crashing.
        traceback.print_exc()
        return (f"❌ Error: {e}", None, None)
116
 
117
+ # ─── Gradio UI ─────────────────────────────────────────────────────────────────
118
  with gr.Blocks() as demo:
119
  gr.Markdown("# Gemma‑3n SOAP Generator 🩺")
120
+ btn = gr.Button("Generate & Export 100 Notes")
121
+ status = gr.Textbox(interactive=False, label="Status")
122
  inf_file = gr.File(label="Download inference.tsv")
123
+ eval_file= gr.File(label="Download eval.csv")
124
 
125
  btn.click(
126
  fn=generate_and_export,