bryanmildort commited on
Commit
a8a1c1f
·
1 Parent(s): a9f3afd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -1,9 +1,12 @@
1
  import streamlit as st
2
 
3
  def summarize_function(notes):
4
- gen_text = pipe(notes, max_length=(len(notes.split(' '))*2*1.225), temperature=0.8, num_return_sequences=1, top_p=0.2)[0]['generated_text']
 
 
 
5
  st.write('Summary: ')
6
- return gen_text[len(notes):]
7
 
8
  st.markdown("<h1 style='text-align: center; color: #489DDB;'>GPT Clinical Notes Summarizer 0.1v</h1>", unsafe_allow_html=True)
9
  st.markdown("<h6 style='text-align: center; color: #489DDB;'>by Bryan Mildort</h1>", unsafe_allow_html=True)
@@ -17,8 +20,8 @@ import torch
17
  # device_map = infer_auto_device_map(model, dtype="float16")
18
  # st.write(device_map)
19
 
20
- model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt_neo_notes_summary", low_cpu_mem_usage=True).cuda()
21
- tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt_neo_notes_summary")
22
  # model = model.to(device)
23
 
24
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
1
  import streamlit as st
2
 
3
  def summarize_function(notes):
4
+ gen_text = pipe(notes, max_length=(len(notes.split(' '))*2*1.225), temperature=0.8, num_return_sequences=1, top_p=0.2)[0]['generated_text'][len(notes):]
5
+ for i in range(len(gen_text)):
6
+ if gen_text[-i-8:].startswith('[Notes]:'):
7
+ gen_text = gen_text[:-i-8]
8
  st.write('Summary: ')
9
+ return gen_text
10
 
11
  st.markdown("<h1 style='text-align: center; color: #489DDB;'>GPT Clinical Notes Summarizer 0.1v</h1>", unsafe_allow_html=True)
12
  st.markdown("<h6 style='text-align: center; color: #489DDB;'>by Bryan Mildort</h1>", unsafe_allow_html=True)
 
20
  # device_map = infer_auto_device_map(model, dtype="float16")
21
  # st.write(device_map)
22
 
23
+ model = AutoModelForCausalLM.from_pretrained("bryanmildort/gpt_neo_notes", low_cpu_mem_usage=True).cuda()
24
+ tokenizer = AutoTokenizer.from_pretrained("bryanmildort/gpt_neo_notes")
25
  # model = model.to(device)
26
 
27
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)