ga89tiy committed on
Commit ca57734 · 1 Parent(s): d9d891d
README.md CHANGED
@@ -57,7 +57,8 @@ from LLAVA_Biovil.llava.model.builder import load_pretrained_model
 from LLAVA_Biovil.llava.conversation import SeparatorStyle, conv_vicuna_v1
 
 from LLAVA_Biovil.llava.constants import IMAGE_TOKEN_INDEX
-from utils import create_chest_xray_transform_for_inference
+from utils import create_chest_xray_transform_for_inference, init_chexpert_predictor
+
 
 def load_model_from_huggingface(repo_id):
     # Download model files
@@ -67,18 +68,31 @@ def load_model_from_huggingface(repo_id):
     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
                                                                            model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, load_4bit=False)
 
+
     return tokenizer, model, image_processor, context_len
 
-tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation")
 
 
 if __name__ == '__main__':
-    # config = None
-    # model_path = "/home/guests/chantal_pellegrini/RaDialog_LLaVA/LLAVA/checkpoints/llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5/checkpoint-21000" #TODO hardcoded in huggingface repo probably
-    # model_name = get_model_name_from_path(model_path)
+    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/10/10/CXR10_IM-0002-2001.png?keywords=Calcified%20Granuloma" #TODO find good image
+
+    response = requests.get(sample_img_path)
+    image = Image.open(io.BytesIO(response.content))
+    image = remap_to_uint8(np.array(image))
+    image = Image.fromarray(image).convert("L")
+
+    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="Chantal/RaDialog-interactive-radiology-report-generation")
+    cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()
+
     model.config.tokenizer_padding_side = "left"
 
-    findings = "edema, pleural effusion" #TODO should these come from chexpert classifier? Or not needed for this demo/test?
+    cp_image = cp_transforms(image)
+    logits = cp_model(cp_image[None].half().cuda())
+    preds_probs = torch.sigmoid(logits)
+    preds = preds_probs > 0.5
+    pred = preds[0].cpu().numpy()
+    findings = cp_class_names[pred].tolist()
+    findings = ', '.join(findings).lower().strip()
 
     conv = conv_vicuna_v1.copy()
     REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
@@ -89,12 +103,6 @@ if __name__ == '__main__':
 
     # get the image
    vis_transforms_biovil = create_chest_xray_transform_for_inference(512, center_crop_size=448)
-    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/10/10/CXR10_IM-0002-2001.png?keywords=Calcified%20Granuloma" #TODO find good image
-
-    response = requests.get(sample_img_path)
-    image = Image.open(io.BytesIO(response.content))
-    image = remap_to_uint8(np.array(image))
-    image = Image.fromarray(image).convert("L")
     image_tensor = vis_transforms_biovil(image).unsqueeze(0)
 
     image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)
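The change above replaces the hard-coded `findings` string with labels thresholded from a CheXpert-style multi-label classifier. Below is a minimal sketch of just that thresholding step, with a dummy linear model and made-up class names standing in for the repo's `init_chexpert_predictor()` (whose internals are not part of this diff):

```python
import numpy as np
import torch

# Dummy stand-ins for the classifier and class names returned by
# init_chexpert_predictor(); the real ones live in the repo's utils module.
cp_class_names = np.array(["Cardiomegaly", "Edema", "Pleural Effusion", "Pneumonia"])
cp_model = torch.nn.Linear(448 * 448, len(cp_class_names))

cp_image = torch.randn(448 * 448)     # placeholder for the transformed X-ray
logits = cp_model(cp_image[None])     # shape: (1, num_classes)
preds_probs = torch.sigmoid(logits)   # independent per-label probabilities
preds = preds_probs > 0.5             # multi-label threshold, as in the diff
pred = preds[0].cpu().numpy()         # boolean mask over class names
findings = ', '.join(cp_class_names[pred].tolist()).lower().strip()
print(findings)                       # e.g. "edema, pleural effusion"
```

Because each label gets an independent sigmoid, any subset of findings can fire; an empty string just means no label crossed the 0.5 threshold.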
simple_test.py → example_code.py RENAMED
@@ -22,7 +22,7 @@ def load_model_from_huggingface(repo_id):
     model_path = Path(model_path)
 
     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
-                                                                           model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, load_4bit=False)
+                                                                           model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, $
 
 
     return tokenizer, model, image_processor, context_len
@@ -30,10 +30,7 @@ def load_model_from_huggingface(repo_id):
 
 
 if __name__ == '__main__':
-    # config = None
-    # model_path = "/home/guests/chantal_pellegrini/RaDialog_LLaVA/LLAVA/checkpoints/llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5/checkpoint-21000" #TODO hardcoded in huggingface repo probably
-    # model_name = get_model_name_from_path(model_path)
-    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/10/10/CXR10_IM-0002-2001.png?keywords=Calcified%20Granuloma" #TODO find good image
+    sample_img_path = "https://openi.nlm.nih.gov/imgs/512/294/3502/CXR3502_IM-1707-1001.png?keywords=Surgical%20Instruments,Cardiomegaly,Pulmonary%20Congestion,Diaphragm"
 
     response = requests.get(sample_img_path)
     image = Image.open(io.BytesIO(response.content))
@@ -54,7 +51,7 @@ if __name__ == '__main__':
     findings = ', '.join(findings).lower().strip()
 
     conv = conv_vicuna_v1.copy()
-    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
+    REPORT_GEN_PROMPT = f"<image>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predi$
     print("USER: ", REPORT_GEN_PROMPT)
     conv.append_message("USER", REPORT_GEN_PROMPT)
     conv.append_message("ASSISTANT", None)
@@ -85,6 +82,27 @@ if __name__ == '__main__':
     pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
     print("ASSISTANT: ", pred)
 
+    # add prediction to conversation
+    conv.messages.pop()
+    conv.append_message("ASSISTANT", pred)
+    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
+
+    # generate a report
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids,
+            images=image_tensor,
+            do_sample=False,
+            use_cache=True,
+            max_new_tokens=300,
+            stopping_criteria=[stopping_criteria],
+            pad_token_id=tokenizer.pad_token_id
+        )
+
+    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
+    print("ASSISTANT: ", pred)
+
     # add prediction to conversation
     conv.messages.pop()
     conv.append_message("ASSISTANT", pred)
@@ -108,4 +126,3 @@ if __name__ == '__main__':
 
     pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
     print("ASSISTANT: ", pred)
-
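The block added here duplicates the generate → decode → append cycle, giving the demo conversation an extra assistant turn: the `None` assistant placeholder is popped, the decoded reply is stored, and generation runs again (the added block appears to reuse the original `input_ids` rather than re-tokenizing the grown conversation). Below is a schematic sketch of that conversation bookkeeping, with stubs in place of the real LLaVA `conv_vicuna_v1` object and `model.generate`/`tokenizer.decode` calls:

```python
# Schematic multi-turn loop mirroring the pattern in example_code.py;
# fake_generate stands in for tokenize -> model.generate -> decode.
class Conversation:
    def __init__(self):
        self.messages = []

    def append_message(self, role, text):
        self.messages.append([role, text])

def fake_generate(messages):
    # Placeholder for the real generation call on the serialized prompt.
    return f"(reply to turn {sum(1 for role, _ in messages if role == 'USER')})"

conv = Conversation()
for user_msg in ["Write the findings section.", "Is there pleural effusion?"]:
    conv.append_message("USER", user_msg)
    conv.append_message("ASSISTANT", None)   # placeholder while generating
    pred = fake_generate(conv.messages)
    conv.messages.pop()                      # drop the None placeholder...
    conv.append_message("ASSISTANT", pred)   # ...and store the decoded reply
    print("ASSISTANT:", pred)
```

In an interactive loop one would rebuild `input_ids` from the updated conversation before each `generate` call, so every turn conditions on the full history.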
 
requirements.txt ADDED
@@ -0,0 +1,10 @@
+accelerate==0.21.0
+huggingface-hub==0.19.4
+timm==0.6.13
+transformers==4.31.0
+scikit-image==0.18.1
+peft==0.4.0
+pytorch_lightning==1.6.5
+chardet
+scikit-learn==1.2.2
+sentencepiece==0.1.99
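Every pin except `chardet` is exact. A small sanity check, assuming only the Python standard library, that compares installed versions against these pins before running `example_code.py`:

```python
from importlib import metadata

# Pins copied from requirements.txt; chardet is intentionally unpinned.
PINS = {
    "accelerate": "0.21.0",
    "huggingface-hub": "0.19.4",
    "timm": "0.6.13",
    "transformers": "4.31.0",
    "scikit-image": "0.18.1",
    "peft": "0.4.0",
    "pytorch_lightning": "1.6.5",
    "scikit-learn": "1.2.2",
    "sentencepiece": "0.1.99",
}

for pkg, want in PINS.items():
    try:
        have = metadata.version(pkg)
    except metadata.PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED (want {want})")
        continue
    print(f"{pkg}: {have}", "ok" if have == want else f"MISMATCH (want {want})")
```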
requirements.txt.py DELETED
File without changes