"
+ "source": [
+ "!mkdir static/\n",
+ "!mkdir static/training_data\n",
+ "!curl https://python.langchain.com/docs/tutorials/rag/ -o static/training_data/langchain_rag_tutorial.html"
]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import wandb\n",
- "wandb.init(mode=\"disabled\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "w1jVcY__pEC7",
+ "outputId": "94d2ef01-7c66-4836-8c9c-be06a83fe65c"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "1"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 7
+ }
+ ],
+ "source": [
+ "from langchain_community.document_loaders import DirectoryLoader\n",
+ "from langchain_community.document_loaders import BSHTMLLoader\n",
+ "\n",
+ "path = \"static/training_data/\"\n",
+ "text_loader = DirectoryLoader(path, glob=\"*.html\", loader_cls=BSHTMLLoader)\n",
+ "docs = text_loader.load()\n",
+ "len(docs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "m6-krgUrpEC7",
+ "outputId": "c566f00e-9842-48b8-f84a-b9bb63516444"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "81"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 8
+ }
+ ],
+ "source": [
+ "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
+ "\n",
+ "text_splitter = RecursiveCharacterTextSplitter(\n",
+ " chunk_size = 750,\n",
+ " chunk_overlap = 20,\n",
+ " length_function = len\n",
+ ")\n",
+ "training_documents = text_splitter.split_documents(text_loader.load())\n",
+ "len(training_documents)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "id": "JSIE4CqYpEC7"
+ },
+ "outputs": [],
+ "source": [
+ "import uuid\n",
+ "\n",
+ "id_set = set()\n",
+ "\n",
+ "for document in training_documents:\n",
+ " id = str(uuid.uuid4())\n",
+ " while id in id_set:\n",
+ " id = uuid.uuid4()\n",
+ " id_set.add(id)\n",
+ " document.metadata[\"id\"] = id"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "-lPkNLzcpEC8"
+ },
+ "outputs": [],
+ "source": [
+ "# break up training documents into training, validation, and test sets\n",
+ "import random\n",
+ "\n",
+ "# set seed for reproducibility\n",
+ "random.seed(42)\n",
+ "\n",
+ "random.shuffle(training_documents)\n",
+ "\n",
+ "training_split_documents = training_documents[:int(0.8 * len(training_documents))]\n",
+ "val_split_documents = training_documents[int(0.8 * len(training_documents)):int(0.9 * len(training_documents))]\n",
+ "test_split_documents = training_documents[int(0.9 * len(training_documents)):]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "5ckUQad1pEC8"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "nNS_UDSOpEC8"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "id": "3q7YaP_lpEC8"
+ },
+ "outputs": [],
+ "source": [
+ "from langchain_openai import ChatOpenAI\n",
+ "\n",
+ "qa_chat_model = ChatOpenAI(\n",
+ " model=\"gpt-4o-mini\",\n",
+ " temperature=0\n",
+ ")\n",
+ "\n",
+ "from langchain_core.prompts import ChatPromptTemplate\n",
+ "\n",
+ "qa_prompt = \"\"\"\\\n",
+ "Given the following context, you must generate questions based on only the provided context.\n",
+ "\n",
+ "You are to generate {n_questions} questions which should be provided in the following format:\n",
+ "\n",
+ "1. QUESTION #1\n",
+ "2. QUESTION #2\n",
+ "...\n",
+ "\n",
+ "Context:\n",
+ "{context}\n",
+ "\"\"\"\n",
+ "\n",
+ "qa_prompt_template = ChatPromptTemplate.from_template(qa_prompt)\n",
+ "question_generation_chain = qa_prompt_template | qa_chat_model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "id": "EocviZT0pEC8"
+ },
+ "outputs": [],
+ "source": [
+ "import tqdm\n",
+ "\n",
+ "async def create_questions(documents, n_questions):\n",
+ " questions = {}\n",
+ " contexts = {}\n",
+ " for document in documents:\n",
+ " question = await question_generation_chain.ainvoke({\"context\": document.page_content, \"n_questions\": n_questions})\n",
+ " questions[document.metadata[\"id\"]] = question\n",
+ " contexts[document.metadata[\"id\"]] = [document.metadata[\"id\"]]\n",
+ " return questions, contexts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "id": "3QBN1GDUpEC8"
+ },
+ "outputs": [],
+ "source": [
+ "training_questions, training_relevant_contexts = await create_questions(training_split_documents, 2)\n",
+ "val_questions, val_relevant_contexts = await create_questions(val_split_documents, 2)\n",
+ "test_questions, test_relevant_contexts = await create_questions(test_split_documents, 2)"
+ ]
+ },
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.\n"
- ]
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "id": "kEfNTkqXpEC8"
+ },
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "\n",
+ "training_corpus = {train_item.metadata[\"id\"] : train_item.page_content for train_item in training_split_documents}\n",
+ "\n",
+ "# Convert AIMessage objects to their string content\n",
+ "training_questions_serializable = {k: v.content for k, v in training_questions.items()}\n",
+ "\n",
+ "train_dataset = {\n",
+ " \"questions\": training_questions_serializable,\n",
+ " \"relevant_contexts\": training_relevant_contexts,\n",
+ " \"corpus\": training_corpus\n",
+ "}\n",
+ "\n",
+ "with open(\"static/training_data/training_dataset.jsonl\", \"w\") as f:\n",
+ " json.dump(train_dataset, f)\n",
+ "\n",
+ "val_corpus = {val_item.metadata[\"id\"] : val_item.page_content for val_item in val_split_documents}\n",
+ "\n",
+ "# Convert AIMessage objects to their string content\n",
+ "val_questions_serializable = {k: v.content for k, v in val_questions.items()}\n",
+ "\n",
+ "val_dataset = {\n",
+ " \"questions\": val_questions_serializable,\n",
+ " \"relevant_contexts\": val_relevant_contexts,\n",
+ " \"corpus\": val_corpus\n",
+ "}\n",
+ "\n",
+ "with open(\"static/training_data/val_dataset.jsonl\", \"w\") as f:\n",
+ " json.dump(val_dataset, f)\n",
+ "\n",
+ "test_corpus = {test_item.metadata[\"id\"] : test_item.page_content for test_item in test_split_documents}\n",
+ "\n",
+ "# Convert AIMessage objects to their string content\n",
+ "test_questions_serializable = {k: v.content for k, v in test_questions.items()}\n",
+ "\n",
+ "test_dataset = {\n",
+ " \"questions\": test_questions_serializable,\n",
+ " \"relevant_contexts\": test_relevant_contexts,\n",
+ " \"corpus\": test_corpus\n",
+ "}\n",
+ "\n",
+ "with open(\"static/training_data/test_dataset.jsonl\", \"w\") as f:\n",
+ " json.dump(test_dataset, f)"
+ ]
},
{
- "data": {
- "text/html": [
- "\n",
- " \n",
- " \n",
- "
\n",
- " [ 2/40 : < :, Epoch 0.25/10]\n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " Step \n",
- " Training Loss \n",
- " Validation Loss \n",
- " \n",
- " \n",
- " \n",
- " \n",
- "
"
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "id": "abFeG5OcpEC8"
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "id": "asyH4PlYpEC8"
+ },
+ "outputs": [],
+ "source": [
+ "from sentence_transformers import SentenceTransformer\n",
+ "\n",
+ "model_id = \"Snowflake/snowflake-arctic-embed-l\"\n",
+ "model = SentenceTransformer(model_id)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "id": "0hYIXL4KpEC8"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import DataLoader\n",
+ "from torch.utils.data import Dataset\n",
+ "from sentence_transformers import InputExample\n",
+ "\n",
+ "corpus = train_dataset['corpus']\n",
+ "queries = train_dataset['questions']\n",
+ "relevant_docs = train_dataset['relevant_contexts']\n",
+ "\n",
+ "examples = []\n",
+ "for query_id, query in queries.items():\n",
+ " doc_id = relevant_docs[query_id][0]\n",
+ " text = corpus[doc_id]\n",
+ " example = InputExample(texts=[query, text])\n",
+ " examples.append(example)\n",
+ "\n",
+ "BATCH_SIZE = 16\n",
+ "\n",
+ "loader = DataLoader(\n",
+ " examples,\n",
+ " batch_size=BATCH_SIZE,\n",
+ ")\n",
+ "\n",
+ "from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss\n",
+ "\n",
+ "matryoshka_dimensions = [768, 512, 256, 128, 64]\n",
+ "inner_train_loss = MultipleNegativesRankingLoss(model)\n",
+ "train_loss = MatryoshkaLoss(\n",
+ " model, inner_train_loss, matryoshka_dims=matryoshka_dimensions\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "id": "W6FFdpMVpEC8"
+ },
+ "outputs": [],
+ "source": [
+ "from sentence_transformers.evaluation import InformationRetrievalEvaluator\n",
+ "\n",
+ "corpus = val_dataset['corpus']\n",
+ "queries = val_dataset['questions']\n",
+ "relevant_docs = val_dataset['relevant_contexts']\n",
+ "\n",
+ "evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "id": "lfWUVDLgpEC8"
+ },
+ "outputs": [],
+ "source": [
+ "EPOCHS = 10"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 39
+ },
+ "id": "AKU7tS0ApEC8",
+ "outputId": "5243bf3d-1d8d-4814-a5dd-07c8f7993eaf"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/html": [
+ "Display W&B run "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "execution_count": 19
+ }
],
- "text/plain": [
- ""
+ "source": [
+ "import wandb\n",
+ "wandb.init(mode=\"disabled\")"
]
- },
- "metadata": {},
- "output_type": "display_data"
},
{
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[17], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m warmup_steps \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;28mlen\u001b[39m(loader) \u001b[38;5;241m*\u001b[39m EPOCHS \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m0.1\u001b[39m)\n\u001b[0;32m----> 3\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_objectives\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loss\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mEPOCHS\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mwarmup_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwarmup_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mfinetuned_arctic_ft\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mshow_progress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mevaluator\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mevaluator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mevaluation_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m50\u001b[39;49m\n\u001b[1;32m 11\u001b[0m \u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/sentence_transformers/fit_mixin.py:385\u001b[0m, in \u001b[0;36mFitMixin.fit\u001b[0;34m(self, train_objectives, evaluator, epochs, steps_per_epoch, scheduler, warmup_steps, optimizer_class, optimizer_params, weight_decay, evaluation_steps, output_path, save_best_model, max_grad_norm, use_amp, callback, show_progress_bar, checkpoint_path, checkpoint_save_steps, checkpoint_save_total_limit)\u001b[0m\n\u001b[1;32m 382\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output_path \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 383\u001b[0m trainer\u001b[38;5;241m.\u001b[39madd_callback(SaveModelCallback(output_path, evaluator, save_best_model))\n\u001b[0;32m--> 385\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/transformers/trainer.py:2241\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 2239\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 2240\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2241\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2242\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2243\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2244\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2245\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2246\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/transformers/trainer.py:2599\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2595\u001b[0m grad_norm \u001b[38;5;241m=\u001b[39m _grad_norm\n\u001b[1;32m 2597\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_pre_optimizer_step(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m-> 2599\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2601\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_optimizer_step(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 2603\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39moptimizer_step_was_skipped:\n\u001b[1;32m 2604\u001b[0m \u001b[38;5;66;03m# Delay optimizer scheduling until metrics are generated\u001b[39;00m\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/accelerate/optimizer.py:178\u001b[0m, in \u001b[0;36mAcceleratedOptimizer.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 176\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_accelerate_step_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 177\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 178\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 179\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator_state\u001b[38;5;241m.\u001b[39mdistributed_type \u001b[38;5;241m==\u001b[39m DistributedType\u001b[38;5;241m.\u001b[39mXLA:\n\u001b[1;32m 180\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgradient_state\u001b[38;5;241m.\u001b[39mis_xla_gradients_synced \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:140\u001b[0m, in \u001b[0;36mLRScheduler.__init__..patch_track_step_called..wrap_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 138\u001b[0m opt \u001b[38;5;241m=\u001b[39m opt_ref()\n\u001b[1;32m 139\u001b[0m opt\u001b[38;5;241m.\u001b[39m_opt_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__get__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:493\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 488\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 489\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 490\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 491\u001b[0m )\n\u001b[0;32m--> 493\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 494\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 496\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:91\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 90\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 91\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 93\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/torch/optim/adamw.py:243\u001b[0m, in \u001b[0;36mAdamW.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 230\u001b[0m beta1, beta2 \u001b[38;5;241m=\u001b[39m cast(Tuple[\u001b[38;5;28mfloat\u001b[39m, \u001b[38;5;28mfloat\u001b[39m], group[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbetas\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 232\u001b[0m has_complex \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_group(\n\u001b[1;32m 233\u001b[0m group,\n\u001b[1;32m 234\u001b[0m params_with_grad,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 240\u001b[0m state_steps,\n\u001b[1;32m 241\u001b[0m )\n\u001b[0;32m--> 243\u001b[0m \u001b[43madamw\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 244\u001b[0m \u001b[43m \u001b[49m\u001b[43mparams_with_grad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 245\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 246\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 247\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 248\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 249\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 250\u001b[0m \u001b[43m \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mamsgrad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 252\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 253\u001b[0m \u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlr\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 254\u001b[0m \u001b[43m \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mweight_decay\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 255\u001b[0m \u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43meps\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 256\u001b[0m \u001b[43m \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmaximize\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 257\u001b[0m \u001b[43m \u001b[49m\u001b[43mforeach\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mforeach\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43mcapturable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcapturable\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 259\u001b[0m \u001b[43m \u001b[49m\u001b[43mdifferentiable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdifferentiable\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 260\u001b[0m \u001b[43m \u001b[49m\u001b[43mfused\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgroup\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfused\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgrad_scale\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 262\u001b[0m \u001b[43m \u001b[49m\u001b[43mfound_inf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfound_inf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 263\u001b[0m \u001b[43m \u001b[49m\u001b[43mhas_complex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhas_complex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 264\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:154\u001b[0m, in \u001b[0;36m_disable_dynamo_if_unsupported..wrapper..maybe_fallback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m disabled_func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 154\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/torch/optim/adamw.py:875\u001b[0m, in \u001b[0;36madamw\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach, capturable, differentiable, fused, grad_scale, found_inf, has_complex, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)\u001b[0m\n\u001b[1;32m 872\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 873\u001b[0m func \u001b[38;5;241m=\u001b[39m _single_tensor_adamw\n\u001b[0;32m--> 875\u001b[0m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 876\u001b[0m \u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 877\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrads\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 878\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avgs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 879\u001b[0m \u001b[43m \u001b[49m\u001b[43mexp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 880\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_exp_avg_sqs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 881\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 882\u001b[0m \u001b[43m \u001b[49m\u001b[43mamsgrad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mamsgrad\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 883\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta1\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 884\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta2\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbeta2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 885\u001b[0m \u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 886\u001b[0m \u001b[43m \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweight_decay\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 887\u001b[0m \u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 888\u001b[0m \u001b[43m \u001b[49m\u001b[43mmaximize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmaximize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 889\u001b[0m \u001b[43m \u001b[49m\u001b[43mcapturable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcapturable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 890\u001b[0m \u001b[43m \u001b[49m\u001b[43mdifferentiable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdifferentiable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 891\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrad_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 892\u001b[0m \u001b[43m \u001b[49m\u001b[43mfound_inf\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfound_inf\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 893\u001b[0m \u001b[43m \u001b[49m\u001b[43mhas_complex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhas_complex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 894\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/src/Simplify/.venv/lib/python3.12/site-packages/torch/optim/adamw.py:405\u001b[0m, in \u001b[0;36m_single_tensor_adamw\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, grad_scale, found_inf, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize, capturable, differentiable, has_complex)\u001b[0m\n\u001b[1;32m 402\u001b[0m step_t \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 404\u001b[0m \u001b[38;5;66;03m# Perform stepweight decay\u001b[39;00m\n\u001b[0;32m--> 405\u001b[0m \u001b[43mparam\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmul_\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mweight_decay\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 407\u001b[0m device \u001b[38;5;241m=\u001b[39m param\u001b[38;5;241m.\u001b[39mdevice\n\u001b[1;32m 409\u001b[0m device \u001b[38;5;241m=\u001b[39m param\u001b[38;5;241m.\u001b[39mdevice\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
- ]
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 444,
+ "referenced_widgets": [
+ "fe3544500fe3490ea8160ad8f97dd83b",
+ "1436400ef9f9436b893b2648595470ed",
+ "a5629e7f756241799849d12b80ea9d2f",
+ "76f58839f7234aa4aaff371e3ebeef0f",
+ "760803fdf5714bd286fde5bfe6d19028",
+ "7501478ab7dd4bcaa9f86deb1cf34389",
+ "743ba854952e487c93f0627dee07bc23",
+ "232877e6cf6b4d7192fc4f53cc66a517",
+ "155509672c344ad5882c70df24152a02",
+ "cb3c554ebacd430aa3f8d87d7ccbba30",
+ "66a090103ac74ea98af40ad3645c1ba8"
+ ]
+ },
+ "id": "5yJo1jKlpEC8",
+ "outputId": "4f8dfa01-5030-4e1f-b3cf-ee60dad979d7"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Computing widget examples: 0%| | 0/1 [00:00, ?example/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "fe3544500fe3490ea8160ad8f97dd83b"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [40/40 00:27, Epoch 10/10]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Step \n",
+ " Training Loss \n",
+ " Validation Loss \n",
+ " Cosine Accuracy@1 \n",
+ " Cosine Accuracy@3 \n",
+ " Cosine Accuracy@5 \n",
+ " Cosine Accuracy@10 \n",
+ " Cosine Precision@1 \n",
+ " Cosine Precision@3 \n",
+ " Cosine Precision@5 \n",
+ " Cosine Precision@10 \n",
+ " Cosine Recall@1 \n",
+ " Cosine Recall@3 \n",
+ " Cosine Recall@5 \n",
+ " Cosine Recall@10 \n",
+ " Cosine Ndcg@10 \n",
+ " Cosine Mrr@10 \n",
+ " Cosine Map@100 \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " 4 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 8 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 12 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 16 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 20 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 24 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 28 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 32 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 36 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ " 40 \n",
+ " No log \n",
+ " No log \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 0.333333 \n",
+ " 0.200000 \n",
+ " 0.100000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " 1.000000 \n",
+ " \n",
+ " \n",
+ "
"
+ ]
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "import datasets\n",
+ "\n",
+ "warmup_steps = int(len(loader) * EPOCHS * 0.1)\n",
+ "\n",
+ "model.fit(\n",
+ " train_objectives=[(loader, train_loss)],\n",
+ " epochs=EPOCHS,\n",
+ " warmup_steps=warmup_steps,\n",
+ " output_path='finetuned_arctic_ft',\n",
+ " show_progress_bar=True,\n",
+ " evaluator=evaluator,\n",
+ " evaluation_steps=50\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17,
+ "referenced_widgets": [
+ "cfa4b24acf1448658fce192d13fb85c7",
+ "d1517b337e1945dab7d74747a532490e",
+ "96998de625454d9f931147916b68111b",
+ "50fff0e983184c2698e9283bcf15f6dd",
+ "67dd35582b944946b86c9b2068590541",
+ "45f142dbf1094059950d919db640f26d",
+ "fbf244b28188469cbcb9ea45606a5ed9",
+ "9565038f3a5e4ba9a75068ec1202e6ce",
+ "c1ca1e0b8d674958b6bcca0f7ab233af",
+ "a5e4c1fa2bb24065bf24b554af13a262",
+ "e7c075f1cd114e21a98bd327c40bf031",
+ "6eba55470c4d4853a9c5c0973a1ee824",
+ "e63de24055d24199b01aac0110764fb6",
+ "b613cff44000413cac23d92700d0f655",
+ "0df923ccb7cd4033acdd31a3f8c314dd",
+ "91ecbd26c4044e8390df9ea10f39f4aa",
+ "cd55bad731bb447e8240a9bb75ae37c7",
+ "fa7d463c6c024650987153f678953d46",
+ "faa9359e3a3d4af0bdc50e1a846c5b78",
+ "ecfb24224b454baabeae4dfb3c9faa22"
+ ]
+ },
+ "id": "1g9JJiNTpEC8",
+ "outputId": "7b2cc05e-5f7e-4576-b562-a463df90fcba"
+ },
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "VBox(children=(HTML(value='
Copy a token from your Hugging Face\ntokens page and paste it below. Immediately click login after copying\nyour token or it might be stored in plain text in this notebook file. "
+ }
+ },
+ "96998de625454d9f931147916b68111b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "PasswordModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "PasswordModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "PasswordView",
+ "continuous_update": true,
+ "description": "Token:",
+ "description_tooltip": null,
+ "disabled": false,
+ "layout": "IPY_MODEL_a5e4c1fa2bb24065bf24b554af13a262",
+ "placeholder": "",
+ "style": "IPY_MODEL_e7c075f1cd114e21a98bd327c40bf031",
+ "value": ""
+ }
+ },
+ "50fff0e983184c2698e9283bcf15f6dd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "CheckboxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "CheckboxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "CheckboxView",
+ "description": "Add token as git credential?",
+ "description_tooltip": null,
+ "disabled": false,
+ "indent": true,
+ "layout": "IPY_MODEL_6eba55470c4d4853a9c5c0973a1ee824",
+ "style": "IPY_MODEL_e63de24055d24199b01aac0110764fb6",
+ "value": true
+ }
+ },
+ "67dd35582b944946b86c9b2068590541": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ButtonModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ButtonView",
+ "button_style": "",
+ "description": "Login",
+ "disabled": false,
+ "icon": "",
+ "layout": "IPY_MODEL_b613cff44000413cac23d92700d0f655",
+ "style": "IPY_MODEL_0df923ccb7cd4033acdd31a3f8c314dd",
+ "tooltip": ""
+ }
+ },
+ "45f142dbf1094059950d919db640f26d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_91ecbd26c4044e8390df9ea10f39f4aa",
+ "placeholder": "",
+ "style": "IPY_MODEL_cd55bad731bb447e8240a9bb75ae37c7",
+ "value": "\nPro Tip: If you don't already have one, you can create a dedicated\n'notebooks' token with 'write' access, that you can then easily reuse for all\nnotebooks. "
+ }
+ },
+ "fbf244b28188469cbcb9ea45606a5ed9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": "center",
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": "flex",
+ "flex": null,
+ "flex_flow": "column",
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": "50%"
+ }
+ },
+ "9565038f3a5e4ba9a75068ec1202e6ce": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c1ca1e0b8d674958b6bcca0f7ab233af": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "a5e4c1fa2bb24065bf24b554af13a262": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e7c075f1cd114e21a98bd327c40bf031": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "6eba55470c4d4853a9c5c0973a1ee824": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e63de24055d24199b01aac0110764fb6": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b613cff44000413cac23d92700d0f655": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0df923ccb7cd4033acdd31a3f8c314dd": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ButtonStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "button_color": null,
+ "font_weight": ""
+ }
+ },
+ "91ecbd26c4044e8390df9ea10f39f4aa": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cd55bad731bb447e8240a9bb75ae37c7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "fa7d463c6c024650987153f678953d46": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "LabelModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "LabelModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "LabelView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_faa9359e3a3d4af0bdc50e1a846c5b78",
+ "placeholder": "",
+ "style": "IPY_MODEL_ecfb24224b454baabeae4dfb3c9faa22",
+ "value": "Connecting..."
+ }
+ },
+ "faa9359e3a3d4af0bdc50e1a846c5b78": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "ecfb24224b454baabeae4dfb3c9faa22": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "838a80e21ed14386b752cfcde4411794": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HBoxModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e0be2afe8a394ade988451c70c72365f",
+ "IPY_MODEL_e947bfcd5f564b2fbd5b2e1cf2456b67",
+ "IPY_MODEL_54eb0595ae514931940ddff693480593"
+ ],
+ "layout": "IPY_MODEL_bd0c87b78eab49e6b3b6798514409baf"
+ }
+ },
+ "e0be2afe8a394ade988451c70c72365f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e05d3cf5258042d28a1902ed3572ac06",
+ "placeholder": "",
+ "style": "IPY_MODEL_10e99c32989f4cbc98956b3e79fdce29",
+ "value": "model.safetensors: 100%"
+ }
+ },
+ "e947bfcd5f564b2fbd5b2e1cf2456b67": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "FloatProgressModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9057b3adcfc84ab9be91ae87d6194ece",
+ "max": 1336413848,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_1db227928181429880c22653ba282906",
+ "value": 1336413848
+ }
+ },
+ "54eb0595ae514931940ddff693480593": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "HTMLModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_4a7bee32a1684ea3803c9f1613dc667f",
+ "placeholder": "",
+ "style": "IPY_MODEL_3e7cf2771b4b443781bc4bc0d77d039a",
+ "value": " 1.34G/1.34G [00:27<00:00, 51.1MB/s]"
+ }
+ },
+ "bd0c87b78eab49e6b3b6798514409baf": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e05d3cf5258042d28a1902ed3572ac06": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "10e99c32989f4cbc98956b3e79fdce29": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "9057b3adcfc84ab9be91ae87d6194ece": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1db227928181429880c22653ba282906": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "ProgressStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "4a7bee32a1684ea3803c9f1613dc667f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_name": "LayoutModel",
+ "model_module_version": "1.2.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3e7cf2771b4b443781bc4bc0d77d039a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_name": "DescriptionStyleModel",
+ "model_module_version": "1.5.0",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ }
+ }
}
- ],
- "source": [
- "warmup_steps = int(len(loader) * EPOCHS * 0.1)\n",
- "\n",
- "model.fit(\n",
- " train_objectives=[(loader, train_loss)],\n",
- " epochs=EPOCHS,\n",
- " warmup_steps=warmup_steps,\n",
- " output_path='finetuned_arctic_ft',\n",
- " show_progress_bar=True,\n",
- " evaluator=evaluator,\n",
- " evaluation_steps=50\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from huggingface_hub import notebook_login\n",
- "\n",
- "notebook_login()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "hf_username = \"Rsr2425\"\n",
- "model.push_to_hub(f\"{hf_username}/simplify-embeddings\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": ".venv",
- "language": "python",
- "name": "python3"
},
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.12.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file