Singularity666 commited on
Commit
925e018
·
1 Parent(s): 826a746

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -100
app.py CHANGED
@@ -4,15 +4,8 @@ import pandas as pd
4
  import torch
5
  from PIL import Image
6
  import numpy as np
7
- from main import predict_caption, CLIPModel, get_text_embeddings
8
- import openai
9
- import base64
10
- from docx import Document
11
- from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
12
- from io import BytesIO
13
- import re
14
 
15
- openai.api_key = "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"
16
 
17
  st.markdown(
18
  """
@@ -20,35 +13,12 @@ st.markdown(
20
  body {
21
  background-color: transparent;
22
  }
23
- .container {
24
- display: flex;
25
- justify-content: center;
26
- align-items: center;
27
- background-color: rgba(255, 255, 255, 0.7);
28
- border-radius: 15px;
29
- padding: 20px;
30
- }
31
- .stApp {
32
- background-color: transparent;
33
- }
34
- .stText, .stMarkdown, .stTextInput>label, .stButton>button>span {
35
- color: #1c1c1c !important; /* Set the dark text color for text elements */
36
- }
37
- .stButton>button>span {
38
- color: initial !important; /* Reset the text color for the 'Generate Caption' button */
39
- }
40
- .stMarkdown h1, .stMarkdown h2 {
41
- color: #ff6b81 !important; /* Set the text color of h1 and h2 elements to soft red-pink */
42
- font-weight: bold; /* Set the font weight to bold */
43
- border: 2px solid #ff6b81; /* Add a bold border around the headers */
44
- padding: 10px; /* Add padding to the headers */
45
- border-radius: 5px; /* Add border-radius to the headers */
46
- }
47
  </style>
48
  """,
49
  unsafe_allow_html=True,
50
  )
51
 
 
52
  device = torch.device("cpu")
53
 
54
  testing_df = pd.read_csv("testing_df.csv")
@@ -56,85 +26,26 @@ model = CLIPModel().to(device)
56
  model.load_state_dict(torch.load("weights.pt", map_location=torch.device('cpu')))
57
  text_embeddings = torch.load('saved_text_embeddings.pt', map_location=device)
58
 
59
- def download_link(content, filename, link_text):
60
- b64 = base64.b64encode(content).decode()
61
- href = f'<a href="data:application/octet-stream;base64,{b64}" download="{filename}">{link_text}</a>'
62
- return href
63
 
64
- def show_predicted_caption(image, top_k=1):
65
  matches = predict_caption(
66
  image, model, text_embeddings, testing_df["caption"]
67
- )[:top_k]
68
- cleaned_matches = [re.sub(r'\s\(ROCO_\d+\)', '', match) for match in matches] # Add this line to clean the matches
69
- return cleaned_matches # Return the cleaned_matches instead of matches
70
-
71
- def generate_radiology_report(prompt):
72
- response = openai.Completion.create(
73
- engine="text-davinci-003",
74
- prompt=prompt,
75
- max_tokens=800,
76
- n=1,
77
- stop=None,
78
- temperature=1,
79
- )
80
- report = response.choices[0].text.strip()
81
- # Remove reference string from the report
82
- report = re.sub(r'\(ROCO_\d+\)', '', report).strip()
83
- return report
84
-
85
 
86
- def save_as_docx(text, filename):
87
- document = Document()
88
- document.add_paragraph(text)
89
- with BytesIO() as output:
90
- document.save(output)
91
- output.seek(0)
92
- return output.getvalue()
93
 
94
- st.title("RadiXGPT: An Evolution of machine doctors towards Radiology")
95
-
96
- # Collect user's personal information
97
- st.subheader("Personal Information")
98
- first_name = st.text_input("First Name")
99
- last_name = st.text_input("Last Name")
100
- age = st.number_input("Age", min_value=0, max_value=120, value=25, step=1)
101
- gender = st.selectbox("Gender", ["Male", "Female", "Other"])
102
-
103
- st.write("Upload Scan to get Radiological Report:")
104
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
105
  if uploaded_file is not None:
106
  image = Image.open(uploaded_file)
107
  st.image(image, caption="Uploaded Image", use_column_width=True)
108
  st.write("")
109
 
110
- if st.button("Generate Report"):
111
- with st.spinner("Generating Report..."):
112
  image_np = np.array(image)
113
- caption = show_predicted_caption(image_np)[0]
114
-
115
  st.success(f"Caption: {caption}")
116
 
117
- # Generate the radiology report
118
- radiology_report = generate_radiology_report(f"Write Complete Radiology Report for this with clinical info, subjective, Assessment, Finding, Impressions, Conclusion and more in proper order : {caption}")
119
-
120
- # Add personal information to the radiology report
121
- radiology_report_with_personal_info = f"Patient Name: {first_name} {last_name}\nAge: {age}\nGender: {gender}\n\n{radiology_report}"
122
-
123
- st.header("Radiology Report")
124
- st.write(radiology_report_with_personal_info)
125
- st.markdown(download_link(save_as_docx(radiology_report_with_personal_info, "radiology_report.docx"), "radiology_report.docx", "Download Report as DOCX"), unsafe_allow_html=True)
126
-
127
- feedback_options = ["Satisfied", "Not Satisfied"]
128
- selected_feedback = st.radio("Please provide feedback on the generated report:", feedback_options)
129
-
130
- if selected_feedback == "Not Satisfied":
131
- if st.button("Regenerate Report"):
132
- with st.spinner("Regenerating report..."):
133
- alternative_caption = get_alternative_caption(image_np, model, text_embeddings, testing_df["caption"])
134
- regenerated_radiology_report = generate_radiology_report(f"Write Complete Radiology Report for this with clinical info, subjective, Assessment, Finding, Impressions, Conclusion and more in proper order : {alternative_caption}")
135
-
136
- regenerated_radiology_report_with_personal_info = f"Patient Name: {first_name} {last_name}\nAge: {age}\nGender: {gender}\n\n{regenerated_radiology_report}"
137
-
138
- st.header("Regenerated Radiology Report")
139
- st.write(regenerated_radiology_report_with_personal_info)
140
- st.markdown(download_link(save_as_docx(regenerated_radiology_report_with_personal_info, "regenerated_radiology_report.docx"), "regenerated_radiology_report.docx", "Download Regenerated Report as DOCX"), unsafe_allow_html=True)
 
4
  import torch
5
  from PIL import Image
6
  import numpy as np
7
+ from main import predict_caption, CLIPModel , get_text_embeddings
 
 
 
 
 
 
8
 
 
9
 
10
  st.markdown(
11
  """
 
13
  body {
14
  background-color: transparent;
15
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  </style>
17
  """,
18
  unsafe_allow_html=True,
19
  )
20
 
21
+
22
  device = torch.device("cpu")
23
 
24
  testing_df = pd.read_csv("testing_df.csv")
 
26
  model.load_state_dict(torch.load("weights.pt", map_location=torch.device('cpu')))
27
  text_embeddings = torch.load('saved_text_embeddings.pt', map_location=device)
28
 
 
 
 
 
29
 
30
+ def show_predicted_caption(image):
31
  matches = predict_caption(
32
  image, model, text_embeddings, testing_df["caption"]
33
+ )[0]
34
+ return matches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ st.title("Medical Image Captioning")
37
+ st.write("Upload an image to get a caption:")
 
 
 
 
 
38
 
 
 
 
 
 
 
 
 
 
 
39
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
40
  if uploaded_file is not None:
41
  image = Image.open(uploaded_file)
42
  st.image(image, caption="Uploaded Image", use_column_width=True)
43
  st.write("")
44
 
45
+ if st.button("Generate Caption"):
46
+ with st.spinner("Generating caption..."):
47
  image_np = np.array(image)
48
+ caption = show_predicted_caption(image_np)
 
49
  st.success(f"Caption: {caption}")
50
 
51
+