Update app.py
Browse files
app.py
CHANGED
@@ -2,23 +2,19 @@
|
|
2 |
|
3 |
import os
|
4 |
import pandas as pd
|
5 |
-
import gradio as gr
|
6 |
import torch
|
7 |
-
from transformers import (
|
8 |
-
AutoProcessor,
|
9 |
-
AutoTokenizer,
|
10 |
-
AutoModelForImageTextToText
|
11 |
-
)
|
12 |
from sklearn.model_selection import train_test_split
|
13 |
|
14 |
-
# 1)
|
15 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
16 |
if not HF_TOKEN:
|
17 |
-
raise RuntimeError("Missing HF_TOKEN env
|
18 |
|
19 |
MODEL_ID = "google/gemma-3n-e2b-it"
|
20 |
|
21 |
-
# 2) Eagerly load the
|
22 |
processor = AutoProcessor.from_pretrained(
|
23 |
MODEL_ID, trust_remote_code=True, token=HF_TOKEN
|
24 |
)
|
@@ -26,9 +22,12 @@ tokenizer = AutoTokenizer.from_pretrained(
|
|
26 |
MODEL_ID, trust_remote_code=True, token=HF_TOKEN
|
27 |
)
|
28 |
|
29 |
-
def generate_all_and_split():
|
30 |
-
"""
|
31 |
-
|
|
|
|
|
|
|
32 |
model = AutoModelForImageTextToText.from_pretrained(
|
33 |
MODEL_ID,
|
34 |
trust_remote_code=True,
|
@@ -39,6 +38,7 @@ def generate_all_and_split():
|
|
39 |
device = next(model.parameters()).device
|
40 |
|
41 |
def to_soap(text: str) -> str:
|
|
|
42 |
inputs = processor.apply_chat_template(
|
43 |
[
|
44 |
{"role":"system","content":[{"type":"text","text":"You are a medical AI assistant."}]},
|
@@ -57,10 +57,11 @@ def generate_all_and_split():
|
|
57 |
temperature=0.1,
|
58 |
pad_token_id=processor.tokenizer.eos_token_id
|
59 |
)
|
|
|
60 |
prompt_len = inputs["input_ids"].shape[-1]
|
61 |
return processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0].strip()
|
62 |
|
63 |
-
# b) Generate 100 doc_notes
|
64 |
docs, gts = [], []
|
65 |
for i in range(1, 101):
|
66 |
doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
|
@@ -69,13 +70,14 @@ def generate_all_and_split():
|
|
69 |
if i % 20 == 0:
|
70 |
torch.cuda.empty_cache()
|
71 |
|
72 |
-
# c)
|
73 |
df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
|
74 |
train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)
|
75 |
|
|
|
76 |
os.makedirs("outputs", exist_ok=True)
|
77 |
|
78 |
-
# d) Inference on train → inference.tsv
|
79 |
train_preds = [to_soap(d) for d in train_df["doc_note"]]
|
80 |
inf = train_df.reset_index(drop=True).copy()
|
81 |
inf["id"] = inf.index + 1
|
@@ -84,25 +86,33 @@ def generate_all_and_split():
|
|
84 |
"outputs/inference.tsv", sep="\t", index=False
|
85 |
)
|
86 |
|
87 |
-
# e) Inference on test → eval.csv
|
88 |
test_preds = [to_soap(d) for d in test_df["doc_note"]]
|
89 |
pd.DataFrame({
|
90 |
"id": range(1, len(test_preds)+1),
|
91 |
"predicted_soap": test_preds
|
92 |
}).to_csv("outputs/eval.csv", index=False)
|
93 |
|
|
|
94 |
return (
|
95 |
-
"✅
|
96 |
-
|
97 |
-
|
98 |
)
|
99 |
|
100 |
-
# 3) Gradio UI
|
101 |
with gr.Blocks() as demo:
|
102 |
-
gr.Markdown("
|
103 |
-
|
104 |
-
status
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
if __name__ == "__main__":
|
108 |
demo.launch()
|
|
|
2 |
|
3 |
import os
|
4 |
import pandas as pd
|
|
|
5 |
import torch
|
6 |
+
import gradio as gr
|
7 |
+
from transformers import AutoProcessor, AutoTokenizer, AutoModelForImageTextToText
|
|
|
|
|
|
|
8 |
from sklearn.model_selection import train_test_split
|
9 |
|
10 |
+
# 1) Configuration
|
11 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
12 |
if not HF_TOKEN:
|
13 |
+
raise RuntimeError("Missing HF_TOKEN in env vars – please add it under Settings → Secrets")
|
14 |
|
15 |
MODEL_ID = "google/gemma-3n-e2b-it"
|
16 |
|
17 |
+
# 2) Eagerly load only the processor & tokenizer (fast startup)
|
18 |
processor = AutoProcessor.from_pretrained(
|
19 |
MODEL_ID, trust_remote_code=True, token=HF_TOKEN
|
20 |
)
|
|
|
22 |
MODEL_ID, trust_remote_code=True, token=HF_TOKEN
|
23 |
)
|
24 |
|
25 |
+
def generate_and_export():
|
26 |
+
"""
|
27 |
+
On button click: lazily load the 8‑bit model, generate 100 doc→SOAP pairs,
|
28 |
+
split 70/30, run inference & eval, write files, and return download links.
|
29 |
+
"""
|
30 |
+
# a) Load full model in 8‑bit
|
31 |
model = AutoModelForImageTextToText.from_pretrained(
|
32 |
MODEL_ID,
|
33 |
trust_remote_code=True,
|
|
|
38 |
device = next(model.parameters()).device
|
39 |
|
40 |
def to_soap(text: str) -> str:
|
41 |
+
# wrap the chat‐template + generate call
|
42 |
inputs = processor.apply_chat_template(
|
43 |
[
|
44 |
{"role":"system","content":[{"type":"text","text":"You are a medical AI assistant."}]},
|
|
|
57 |
temperature=0.1,
|
58 |
pad_token_id=processor.tokenizer.eos_token_id
|
59 |
)
|
60 |
+
# strip off prompt tokens
|
61 |
prompt_len = inputs["input_ids"].shape[-1]
|
62 |
return processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0].strip()
|
63 |
|
64 |
+
# b) Generate 100 synthetic doc_notes & ground_truth SOAPs
|
65 |
docs, gts = [], []
|
66 |
for i in range(1, 101):
|
67 |
doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
|
|
|
70 |
if i % 20 == 0:
|
71 |
torch.cuda.empty_cache()
|
72 |
|
73 |
+
# c) 70/30 split
|
74 |
df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
|
75 |
train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)
|
76 |
|
77 |
+
# ensure outputs dir exists
|
78 |
os.makedirs("outputs", exist_ok=True)
|
79 |
|
80 |
+
# d) Inference on train split → outputs/inference.tsv
|
81 |
train_preds = [to_soap(d) for d in train_df["doc_note"]]
|
82 |
inf = train_df.reset_index(drop=True).copy()
|
83 |
inf["id"] = inf.index + 1
|
|
|
86 |
"outputs/inference.tsv", sep="\t", index=False
|
87 |
)
|
88 |
|
89 |
+
# e) Inference on test split → outputs/eval.csv
|
90 |
test_preds = [to_soap(d) for d in test_df["doc_note"]]
|
91 |
pd.DataFrame({
|
92 |
"id": range(1, len(test_preds)+1),
|
93 |
"predicted_soap": test_preds
|
94 |
}).to_csv("outputs/eval.csv", index=False)
|
95 |
|
96 |
+
# return status + file paths for download
|
97 |
return (
|
98 |
+
"✅ Generation complete!",
|
99 |
+
"outputs/inference.tsv",
|
100 |
+
"outputs/eval.csv"
|
101 |
)
|
102 |
|
103 |
+
# 3) Gradio UI
|
104 |
with gr.Blocks() as demo:
|
105 |
+
gr.Markdown("# Gemma‑3n SOAP Generator 🩺")
|
106 |
+
generate_btn = gr.Button("Generate & Export 100 Notes")
|
107 |
+
status = gr.Textbox(interactive=False, label="Status")
|
108 |
+
inf_file = gr.File(label="Download inference.tsv")
|
109 |
+
eval_file = gr.File(label="Download eval.csv")
|
110 |
+
|
111 |
+
generate_btn.click(
|
112 |
+
fn=generate_and_export,
|
113 |
+
inputs=None,
|
114 |
+
outputs=[status, inf_file, eval_file]
|
115 |
+
)
|
116 |
|
117 |
if __name__ == "__main__":
|
118 |
demo.launch()
|