{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5edcb7d2-53dc-4170-9f2f-619c0da0ae4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f839c8fb-b018-4ab6-86a9-7d5bf7883b45",
   "metadata": {},
   "source": [
    "# Load OpenPhenom"
   ]
  },
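  {
   "cell_type": "markdown",
   "id": "3f9a1c2e-1111-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "source": [
    "The next cell loads the model from the current directory, so the OpenPhenom repository files (`huggingface_mae.py` plus the pretrained weights) must already be present there. A minimal setup sketch, assuming the model is published on the Hugging Face Hub as `recursionpharma/OpenPhenom`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9a1c2e-2222-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup sketch: download the OpenPhenom code and weights into the working\n",
    "# directory. The repo id below is an assumption; adjust it if it differs.\n",
    "from huggingface_hub import snapshot_download\n",
    "snapshot_download(repo_id=\"recursionpharma/OpenPhenom\", local_dir=\".\")"
   ]
  },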
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84b9324d-fde9-4c43-bc5a-eb66cdb4f891",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model directly\n",
    "from huggingface_mae import MAEModel\n",
    "open_phenom = MAEModel.from_pretrained(\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57d918c5-78de-4b36-9f46-4652c5da93f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "open_phenom.eval()\n",
    "cuda_available = torch.cuda.is_available()\n",
    "if cuda_available:\n",
    "    open_phenom.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c89d82d-5365-4492-b496-adb3bbd71b32",
   "metadata": {},
   "source": [
    "# Load Rxrx3-core"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deeff3a8-db67-4905-a7e9-c43aad614a84",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "rxrx3_core = load_dataset(\"recursionpharma/rxrx3-core\")['train']"
   ]
  },
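  {
   "cell_type": "markdown",
   "id": "3f9a1c2e-3333-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "source": [
    "A quick sanity check of the record layout assumed below: each record should expose a single-channel `jp2` image plane and a `__key__` path, with six consecutive records forming one 6-channel well image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9a1c2e-4444-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one record to confirm the fields the collate function relies on\n",
    "sample = rxrx3_core[0]\n",
    "print(sample['__key__'])\n",
    "print(np.array(sample['jp2']).shape)  # expected: (512, 512) single-channel plane"
   ]
  },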
  {
   "cell_type": "markdown",
   "id": "8f2226ce-9415-4dd8-932e-54e4e1bd8c1a",
   "metadata": {},
   "source": [
    "# Infernce loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa1218ab-f9cd-413b-9228-c1146df978be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_path_to_well_id(path_str):\n",
    "    \n",
    "    return path_str.split('_')[0].replace('/','_').replace('Plate','')\n",
    "    \n",
    "def collate_rxrx3_core(batch):\n",
    "    \n",
    "    images = np.stack([np.array(i['jp2']) for i in batch]).reshape(-1,6,512,512)\n",
    "    images = np.vstack([patch_image(i) for i in images]) # convert to 4 256x256 patches\n",
    "    images = torch.from_numpy(images)\n",
    "    well_ids = [convert_path_to_well_id(i['__key__']) for i in batch[::6]]\n",
    "    return images, well_ids\n",
    "\n",
    "def iter_border_patches(width, height, patch_size):\n",
    "    \n",
    "    x_start, x_end, y_start, y_end = (0, width, 0, height)\n",
    "\n",
    "    for x in range(x_start, x_end - patch_size + 1, patch_size):\n",
    "        for y in range(y_start, y_end - patch_size + 1, patch_size):\n",
    "            yield x, y\n",
    "\n",
    "def patch_image(image_array, patch_size=256):\n",
    "    \n",
    "    _, width, height = image_array.shape\n",
    "    output_patches = []\n",
    "    patch_count = 0\n",
    "    for x, y in iter_border_patches(width, height, patch_size):\n",
    "        patch = image_array[:, y : y + patch_size, x : x + patch_size].copy()\n",
    "        output_patches.append(patch)\n",
    "    \n",
    "    output_patches = np.stack(output_patches)\n",
    "    \n",
    "    return output_patches"
   ]
  },
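  {
   "cell_type": "markdown",
   "id": "3f9a1c2e-5555-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "source": [
    "A shape check on dummy data (no download or GPU needed): `patch_image` should split one `(6, 512, 512)` image into four non-overlapping `(6, 256, 256)` tiles, which the collate function then stacks into `(n_images * 4, 6, 256, 256)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9a1c2e-6666-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify the patching logic on a zero-filled 6-channel image\n",
    "dummy = np.zeros((6, 512, 512), dtype=np.float32)\n",
    "print(patch_image(dummy).shape)  # expected: (4, 6, 256, 256)"
   ]
  },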
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de308003-bcfc-4b59-9715-dd884b9b2536",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to PyTorch DataLoader\n",
    "batch_size = 128\n",
    "num_workers = 4\n",
    "rxrx3_core_dataloader = DataLoader(rxrx3_core, batch_size=batch_size*6, shuffle=False, \n",
    "                                   collate_fn=collate_rxrx3_core, num_workers=num_workers)"
   ]
  },
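  {
   "cell_type": "markdown",
   "id": "3f9a1c2e-7777-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "source": [
    "Optionally pull a single batch to confirm the shapes the inference loop expects; note this streams real data, so it is slower than the dummy check above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9a1c2e-8888-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One batch: imgs should be (batch_size * 4, 6, 256, 256) patch tensors,\n",
    "# with one well id per 6-channel image (up to batch_size ids)\n",
    "imgs, ids = next(iter(rxrx3_core_dataloader))\n",
    "print(imgs.shape, len(ids))"
   ]
  },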
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e3ea6c2-d1aa-4e20-a175-d72ea636153e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference loop\n",
    "num_features = 384\n",
    "n_crops = 4\n",
    "well_ids = []\n",
    "emb_ind = 0\n",
    "embeddings = np.zeros(\n",
    "    ((len(rxrx3_core_dataloader.dataset)//6), num_features), dtype=np.float32\n",
    ")\n",
    "forward_pass_counter = 0\n",
    "\n",
    "for imgs, batch_well_ids in rxrx3_core_dataloader:\n",
    "\n",
    "    if cuda_available:\n",
    "        with torch.amp.autocast(\"cuda\"), torch.no_grad():\n",
    "            latent = open_phenom.predict(imgs.cuda())\n",
    "    else:\n",
    "        latent = open_phenom.predict(imgs)\n",
    "    \n",
    "    latent = latent.view(-1, n_crops, num_features).mean(dim=1)  # average over 4 256x256 crops per image\n",
    "    embeddings[emb_ind : (emb_ind + len(latent))] = latent.detach().cpu().numpy()\n",
    "    well_ids.extend(batch_well_ids)\n",
    "\n",
    "    emb_ind += len(latent)\n",
    "    forward_pass_counter += 1\n",
    "    if forward_pass_counter % 5 == 0:\n",
    "        print(f\"forward pass {forward_pass_counter} of {len(rxrx3_core_dataloader)} done, wells inferenced {emb_ind}\")\n",
    "\n",
    "embedding_df = embeddings[:emb_ind]\n",
    "embedding_df = pd.DataFrame(embedding_df)\n",
    "embedding_df.columns = [f\"feature_{i}\" for i in range(num_features)]\n",
    "embedding_df['well_id'] = well_ids\n",
    "embedding_df = embedding_df[['well_id']+[f\"feature_{i}\" for i in range(num_features)]]\n",
    "embedding_df.to_parquet('OpenPhenom_rxrx3-core_embeddings.parquet')"
   ]
  }
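,
  {
   "cell_type": "markdown",
   "id": "3f9a1c2e-9999-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "source": [
    "To spot-check the export, read the parquet file back: one row per well, with a `well_id` column followed by the 384 feature columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9a1c2e-aaaa-4aaa-8bbb-0123456789ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved embeddings and confirm the expected layout\n",
    "df = pd.read_parquet('OpenPhenom_rxrx3-core_embeddings.parquet')\n",
    "print(df.shape)  # (n_wells, 1 + num_features)\n",
    "df.head()"
   ]
  }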
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "photo2",
   "language": "python",
   "name": "photo2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}