Bonosa2 committed on
Commit
3314cdc
·
verified ·
1 Parent(s): ed76a54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -31
app.py CHANGED
@@ -1,31 +1,44 @@
 
 
1
  import os
2
  import pandas as pd
3
  import gradio as gr
4
- from kaggle_secrets import UserSecretsClient
5
- from transformers import AutoProcessor, AutoTokenizer, AutoModelForImageTextToText
6
- from sklearn.model_selection import train_test_split
7
  import torch
 
 
 
 
 
 
8
 
9
- HF_TOKEN = UserSecretsClient().get_secret("HF_TOKEN")
10
- MODEL_ID = "google/gemma-3n-e2b-it"
 
 
11
 
12
- # Only load small pieces at startup
13
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, use_auth_token=HF_TOKEN)
14
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_auth_token=HF_TOKEN)
15
 
16
- def generate_all_notes():
17
- # 1) Load the full 8‑bit model on demand
 
 
 
 
 
 
 
 
 
18
  model = AutoModelForImageTextToText.from_pretrained(
19
  MODEL_ID,
20
  trust_remote_code=True,
21
- use_auth_token=HF_TOKEN,
22
  load_in_8bit=True,
23
  device_map="auto"
24
  )
25
  device = next(model.parameters()).device
26
 
27
- # helper to turn text→SOAP
28
- def to_soap(text):
29
  inputs = processor.apply_chat_template(
30
  [
31
  {"role":"system","content":[{"type":"text","text":"You are a medical AI assistant."}]},
@@ -40,28 +53,29 @@ def generate_all_notes():
40
  **inputs,
41
  max_new_tokens=400,
42
  do_sample=True,
43
- temperature=0.1,
44
  top_p=0.95,
 
45
  pad_token_id=processor.tokenizer.eos_token_id
46
  )
47
- return processor.batch_decode(out[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0].strip()
 
48
 
49
- # 2) Generate 100 raw docs + ground truths
50
  docs, gts = [], []
51
- for i in range(100):
52
  doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
53
  docs.append(doc)
54
  gts.append(to_soap(doc))
55
- if (i+1) % 20 == 0:
56
  torch.cuda.empty_cache()
57
 
58
- # 3) Split 70/30
59
- full_df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
60
- train_df, test_df = train_test_split(full_df, test_size=0.3, random_state=42)
61
 
62
  os.makedirs("outputs", exist_ok=True)
63
 
64
- # 4) Inference on train split → inference.tsv
65
  train_preds = [to_soap(d) for d in train_df["doc_note"]]
66
  inf = train_df.reset_index(drop=True).copy()
67
  inf["id"] = inf.index + 1
@@ -70,24 +84,25 @@ def generate_all_notes():
70
  "outputs/inference.tsv", sep="\t", index=False
71
  )
72
 
73
- # 5) Inference on test split → eval.csv
74
  test_preds = [to_soap(d) for d in test_df["doc_note"]]
75
  pd.DataFrame({
76
- "id": range(1, len(test_preds)+1),
77
  "predicted_soap": test_preds
78
  }).to_csv("outputs/eval.csv", index=False)
79
 
80
  return (
81
  "✅ Done!\n"
82
- f" outputs/inference.tsv (70 rows with id, GT & pred)\n"
83
- f" outputs/eval.csv (30 rows with id & pred)"
84
  )
85
 
 
86
  with gr.Blocks() as demo:
87
- gr.Markdown("## Gemma‑3n SOAP Generator")
88
- btn = gr.Button("Generate 100 → split 70/30 → inference & eval")
89
- out = gr.Textbox(interactive=False, label="Status")
90
- btn.click(fn=generate_all_notes, inputs=None, outputs=out)
91
 
92
- if __name__=="__main__":
93
  demo.launch()
 
1
+ # app.py
2
+
3
  import os
4
  import pandas as pd
5
  import gradio as gr
 
 
 
6
  import torch
7
+ from transformers import (
8
+ AutoProcessor,
9
+ AutoTokenizer,
10
+ AutoModelForImageTextToText
11
+ )
12
+ from sklearn.model_selection import train_test_split
13
 
14
+ # 1) Retrieve your HF_TOKEN from environment (set in Space Settings → Secrets)
15
+ HF_TOKEN = os.environ.get("HF_TOKEN")
16
+ if not HF_TOKEN:
17
+ raise RuntimeError("Missing HF_TOKEN env var! Please add it in your Space settings → Secrets.")
18
 
19
+ MODEL_ID = "google/gemma-3n-e2b-it"
 
 
20
 
21
+ # 2) Eagerly load the small bits (processor & tokenizer) so the UI starts fast
22
+ processor = AutoProcessor.from_pretrained(
23
+ MODEL_ID, trust_remote_code=True, token=HF_TOKEN
24
+ )
25
+ tokenizer = AutoTokenizer.from_pretrained(
26
+ MODEL_ID, trust_remote_code=True, token=HF_TOKEN
27
+ )
28
+
29
+ def generate_all_and_split():
30
+ """Called when the user clicks the button—loads full model, generates & saves files."""
31
+ # a) Lazy‑load the 8‑bit quantized model (heavy)
32
  model = AutoModelForImageTextToText.from_pretrained(
33
  MODEL_ID,
34
  trust_remote_code=True,
35
+ token=HF_TOKEN,
36
  load_in_8bit=True,
37
  device_map="auto"
38
  )
39
  device = next(model.parameters()).device
40
 
41
+ def to_soap(text: str) -> str:
 
42
  inputs = processor.apply_chat_template(
43
  [
44
  {"role":"system","content":[{"type":"text","text":"You are a medical AI assistant."}]},
 
53
  **inputs,
54
  max_new_tokens=400,
55
  do_sample=True,
 
56
  top_p=0.95,
57
+ temperature=0.1,
58
  pad_token_id=processor.tokenizer.eos_token_id
59
  )
60
+ prompt_len = inputs["input_ids"].shape[-1]
61
+ return processor.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0].strip()
62
 
63
+ # b) Generate 100 doc_notes + ground_truth SOAPs
64
  docs, gts = [], []
65
+ for i in range(1, 101):
66
  doc = to_soap("Generate a realistic, concise doctor's progress note for a single patient encounter.")
67
  docs.append(doc)
68
  gts.append(to_soap(doc))
69
+ if i % 20 == 0:
70
  torch.cuda.empty_cache()
71
 
72
+ # c) Split 70/30
73
+ df = pd.DataFrame({"doc_note": docs, "ground_truth_soap": gts})
74
+ train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)
75
 
76
  os.makedirs("outputs", exist_ok=True)
77
 
78
+ # d) Inference on train → inference.tsv
79
  train_preds = [to_soap(d) for d in train_df["doc_note"]]
80
  inf = train_df.reset_index(drop=True).copy()
81
  inf["id"] = inf.index + 1
 
84
  "outputs/inference.tsv", sep="\t", index=False
85
  )
86
 
87
+ # e) Inference on test → eval.csv
88
  test_preds = [to_soap(d) for d in test_df["doc_note"]]
89
  pd.DataFrame({
90
+ "id": range(1, len(test_preds)+1),
91
  "predicted_soap": test_preds
92
  }).to_csv("outputs/eval.csv", index=False)
93
 
94
  return (
95
  "✅ Done!\n"
96
+ f" outputs/inference.tsv (70 rows with id, GT, pred)\n"
97
+ f" outputs/eval.csv (30 rows with id, pred)"
98
  )
99
 
100
+ # 3) Gradio UI—instant startup
101
  with gr.Blocks() as demo:
102
+ gr.Markdown("## Gemma‑3n SOAP Generator 🩺")
103
+ btn = gr.Button("Generate & Save 100 Notes → 70/30 Split → inference & eval")
104
+ status = gr.Textbox(interactive=False, label="Status")
105
+ btn.click(fn=generate_all_and_split, inputs=None, outputs=status)
106
 
107
+ if __name__ == "__main__":
108
  demo.launch()