Spaces:
Sleeping
Sleeping
File size: 2,738 Bytes
b2ffc9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import os
from hashlib import sha1
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
from atoms_detection.dl_detection import DLDetection
from atoms_detection.dataset import CoordinatesDataset
from utils.constants import Split, ModelArgs
from utils.paths import PT_DATASET, PREDS_PATH, DETECTION_PATH,LANDS_VIS_PATH
# Run configuration. ``threshold`` is handed to DLDetection and, together
# with the ``extension_name`` tag, is embedded in every output name below
# and in the landscape image file names.
threshold = 0.89
extension_name = "replicate"

# Detections are written under DETECTION_PATH/<run dir>; cached prediction
# maps live in a directory with the same <run dir> name under PREDS_PATH.
_run_dir_name = f"dl_detection_{extension_name}_{threshold}"
detections_path = os.path.join(DETECTION_PATH, _run_dir_name)
inference_cache_path = os.path.join(PREDS_PATH, _run_dir_name)
def get_pred_map(img_filename: str) -> np.ndarray:
    """Return the DL prediction map for an image, using an on-disk cache.

    The cache entry is ``<inference_cache_path>/<sha1(img_filename)>.npy``.
    The key is the SHA-1 of the filename *string*, not of the image contents
    or of the model checkpoint. An image changed in place under the same path,
    or a different checkpoint, will therefore still hit the old cache entry.

    On a cache miss, a ``DLDetection`` model is built and the image is opened
    and run through ``image_to_pred_map``. The result is saved to the cache
    before it is returned.

    :param img_filename: path of the image to run detection on.
    :return: the array produced by ``DLDetection.image_to_pred_map`` (or the
        previously cached copy of it).
    """
    img_hash = sha1(img_filename.encode()).hexdigest()
    prediction_cache = os.path.join(inference_cache_path, f"{img_hash}.npy")
    if os.path.exists(prediction_cache):
        return np.load(prediction_cache)

    detection = DLDetection(
        model_name=ModelArgs.BASICCNN,
        ckpt_filename="/home/fpares/PycharmProjects/stem_atoms/models/basic_replicate.ckpt",
        dataset_csv="/home/fpares/PycharmProjects/stem_atoms/dataset/Coordinate_image_pairs.csv",
        threshold=threshold,
        detections_path=detections_path,
    )
    # Open the image named by the argument rather than relying on a global
    # loop variable, so the function is correct for any caller.
    img = DLDetection.open_image(img_filename)
    pred_map = detection.image_to_pred_map(img)
    # np.save does not create missing parent directories.
    os.makedirs(inference_cache_path, exist_ok=True)
    np.save(prediction_cache, pred_map)
    return pred_map
def short_proj():
    """Return the current 3D axes' projection matrix post-multiplied by ``scale``.

    This function is installed as ``ax.get_proj`` in the plotting loop. It
    takes no arguments and reads the module-level ``ax`` and ``scale`` when it
    is called, so it always uses whichever axes object is currently bound to
    ``ax``. It calls the class attribute ``Axes3D.get_proj`` explicitly
    because ``ax.get_proj`` has been replaced by this function, and calling
    that would recurse.
    """
    unscaled = Axes3D.get_proj(ax)
    return unscaled @ scale
# Render a 3D surface ("landscape") of the prediction map for every test image.
os.makedirs(LANDS_VIS_PATH, exist_ok=True)

# Per-axis scaling applied to the 3D projection by ``short_proj``: x and y
# keep their extent while z is compressed by ``z_scale``. ``scale`` must stay
# a module-level global because ``short_proj`` reads it at draw time.
x_scale = 1
y_scale = 1
z_scale = 0.1
scale = np.diag([x_scale, y_scale, z_scale, 1.0])
scale = scale * (1.0 / scale.max())
scale[3, 3] = 1.0

coordinates_dataset = CoordinatesDataset(PT_DATASET)
for image_path, _ in coordinates_dataset.iterate_data(Split.TEST):
    pred_map = get_pred_map(image_path)

    # Build the grid from the map's own shape instead of assuming 512x512;
    # meshgrid(cols, rows) yields arrays shaped like pred_map.
    n_rows, n_cols = pred_map.shape
    X, Y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    fig = plt.figure(figsize=(10, 10))
    # Figure.gca() no longer accepts ``projection=`` in matplotlib >= 3.6;
    # add_subplot works on both older and newer versions.
    ax = fig.add_subplot(111, projection='3d')
    # short_proj reads this module-level ``ax`` when the figure is drawn.
    ax.get_proj = short_proj
    ax.plot_surface(X, Y, pred_map, cmap=cm.coolwarm,
                    rstride=2, cstride=2,
                    linewidth=0.2, antialiased=True)
    ax.set_axis_off()

    img_name = os.path.splitext(os.path.basename(image_path))[0]
    landscape_output_path = os.path.join(
        LANDS_VIS_PATH, f"{img_name}_landscape_{extension_name}_{threshold}.png"
    )
    fig.savefig(landscape_output_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
    # Without this, every iteration leaves one more figure open.
    plt.close(fig)
|