Delete Vit_NN_app
Vit_NN_app  +0 -1
DELETED
@@ -1 +0,0 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"mount_file_id":"1a0xxjADeOQpTPyfGzsiiTQPHv9PErOfL","authorship_tag":"ABX9TyNQRXuzLf07TW5H2pWsmV6V"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","execution_count":null,"metadata":{"id":"UmuuIaySGwRC"},"outputs":[],"source":["import torch\n","import torchvision.transforms as transforms\n","import torchvision.models as models\n","import gradio as gr\n","import numpy as np\n","from PIL import Image\n","from sklearn.preprocessing import StandardScaler # Required for feature scaling\n","import joblib # To load the scaler\n","\n","# Set device\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","\n","# Load trained ViT model\n","vit_model = models.vit_b_16(pretrained=False)\n","vit_model.heads = torch.nn.Linear(in_features=768, out_features=2) # Binary classification\n","\n","# Load ViT model weights\n","vit_model_path = \"/content/drive/MyDrive/ViT_BCC/vit_bc\"\n","vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))\n","vit_model.to(device)\n","vit_model.eval()\n","\n","# Define ViT image transformations\n","transform = transforms.Compose([\n"," transforms.Resize((224, 224)),\n"," transforms.ToTensor(),\n"," transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n","])\n","\n","# Class labels\n","class_names = [\"Benign\", \"Malignant\"]\n","\n","# Load trained Neural Network model (ensure this is properly trained)\n","nn_model_path = \"/content/drive/MyDrive/NN_BCC/nn_bc.pth\" # Update path\n","nn_model = torch.load(nn_model_path, map_location=device) # Assuming a PyTorch model\n","nn_model.to(device)\n","nn_model.eval()\n","\n","# Load scaler for feature normalization\n","scaler_path = \"/content/drive/MyDrive/NN_BCC/scaler.pkl\" # Update path\n","scaler = joblib.load(scaler_path) # Load pre-fitted scaler\n","\n","# Define feature names for NN model\n","feature_names = [\n"," \"Mean Radius\", \"Mean Texture\", \"Mean Perimeter\", \"Mean Area\", \"Mean Smoothness\",\n"," \"Mean Compactness\", \"Mean Concavity\", \"Mean Concave Points\", \"Mean Symmetry\", \"Mean Fractal Dimension\",\n"," \"SE Radius\", \"SE Texture\", \"SE Perimeter\", \"SE Area\", \"SE Smoothness\",\n"," \"SE Compactness\", \"SE Concavity\", \"SE Concave Points\", \"SE Symmetry\", \"SE Fractal Dimension\",\n"," \"Worst Radius\", \"Worst Texture\", \"Worst Perimeter\", \"Worst Area\", \"Worst Smoothness\",\n"," \"Worst Compactness\", \"Worst Concavity\", \"Worst Concave Points\", \"Worst Symmetry\", \"Worst Fractal Dimension\"\n","]\n","\n","def classify(model_choice, image=None, *features):\n"," \"\"\"Classify using ViT (image) or NN (features).\"\"\"\n"," if model_choice == \"ViT\":\n"," if image is None:\n"," return \"Please upload an image for ViT classification.\"\n"," image = image.convert(\"RGB\") # Ensure RGB format\n"," input_tensor = transform(image).unsqueeze(0).to(device) # Preprocess image\n","\n"," with torch.no_grad():\n"," output = vit_model(input_tensor)\n"," predicted_class = torch.argmax(output, dim=1).item()\n","\n"," return class_names[predicted_class]\n","\n"," elif model_choice == \"Neural Network\":\n"," if any(f is None for f in features):\n"," return \"Please enter all 30 numerical features.\"\n","\n"," # Convert input features to NumPy array\n"," input_data = np.array(features).reshape(1, -1)\n","\n"," # Standardize using pre-trained scaler\n"," input_data_std = 
Cell 2 — dependency install (execution count 3):

!pip install gradio

Output (trimmed):
Requirement already satisfied: gradio in /usr/local/lib/python3.11/dist-packages (5.17.1)
…followed by matching "Requirement already satisfied" lines for its dependencies (fastapi, gradio-client, httpx, huggingface-hub, numpy, pandas, pydantic, typer, uvicorn, and others).
Cell 3 — revised app (execution count 7). The ViT image path is unchanged; the tabular model is now a TensorFlow/Keras network. Everything not shown below repeats Cell 1 verbatim up to comments (imports, device setup, ViT weights, transforms, class labels, feature names, the ViT branch of classify, and the Gradio UI and Interface); the only structural change is that this version stops after building iface and defers the launch to Cell 4.

import tensorflow as tf  # new import in this version

# Load trained Neural Network model (TensorFlow/Keras)
nn_model_path = "/content/drive/MyDrive/Breast_Cancer_Prediction_2024/DIR_NN_BC/my_NN_BC_model.keras"  # Ensure the correct path
nn_model = tf.keras.models.load_model(nn_model_path)

# Load scaler for feature normalization
scaler_path = "/content/drive/MyDrive/Breast_Cancer_Prediction_2024/DIR_NN_BC/nn_bc_scaler.pkl"  # Update path
scaler = joblib.load(scaler_path)  # Load pre-fitted scaler

The "Neural Network" branch of classify() now ends with Keras inference in place of the PyTorch no_grad block:

        # Standardize using pre-trained scaler
        input_data_std = scaler.transform(input_data)

        # Run prediction using TensorFlow model
        prediction = nn_model.predict(input_data_std)
        predicted_class = np.argmax(prediction)

        return class_names[predicted_class]
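One assumption in this branch: np.argmax(prediction) presumes my_NN_BC_model.keras outputs two scores per sample (e.g., a softmax over benign/malignant). If the network instead ends in a single sigmoid unit, argmax over one column always returns 0. A hedged helper covering both layouts, since the saved model's head cannot be verified here:

import numpy as np

def to_class_index(prediction: np.ndarray) -> int:
    """Map one row of Keras output to a class index.

    Assumes the model ends either in a single sigmoid unit (output shape
    (1, 1)) or a two-unit softmax (output shape (1, 2)).
    """
    if prediction.shape[-1] == 1:
        return int(prediction[0, 0] > 0.5)  # single-probability output
    return int(np.argmax(prediction))       # per-class score output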
Executing the cell emitted this stderr warning:

<ipython-input-7-50666140025a>:20: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the pickle module implicitly; malicious pickle data can execute arbitrary code during unpickling (see pytorch's SECURITY.md on untrusted models). In a future release the default flips to `True`; set `weights_only=True` for any file you don't fully control.
  vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
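The warning stems from torch.load defaulting to unrestricted pickle deserialization. Assuming vit_bc holds only a state_dict (which the load_state_dict call suggests), the ViT load in both app cells could be hardened as in this sketch; weights_only=True restricts unpickling to tensors and primitive containers:

import torch
import torchvision.models as models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_model = models.vit_b_16(weights=None)  # weights=None replaces the deprecated pretrained=False
vit_model.heads = torch.nn.Linear(768, 2)
state_dict = torch.load("/content/drive/MyDrive/ViT_BCC/vit_bc",
                        map_location=device, weights_only=True)  # safe unpickling
vit_model.load_state_dict(state_dict)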
The cell's stdout then shows Gradio auto-enabling sharing:

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://f788e88ec7f086353b.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)

[The cell also rendered the app inline in an iframe pointing at the gradio.live URL.]

Cell 4 — launch (not executed in the saved state):

# launch app
iface.launch()
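The 30 feature labels mirror scikit-learn's Wisconsin Diagnostic Breast Cancer dataset, so the tabular branch can be smoke-tested without the UI. A sketch, assuming the app cell above has run so classify() is in scope; note that sklearn encodes 0 = malignant and 1 = benign, the reverse of the app's class_names order, so the label mapping is worth double-checking against the training code:

# Smoke test for the "Neural Network" branch (assumes classify() from the app cell is defined).
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()
sample = data.data[0]  # 30 features in sklearn's canonical order
print(classify("Neural Network", None, *sample))
print("sklearn label:", data.target_names[data.target[0]])  # 0 = malignant in sklearn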