File size: 3,515 Bytes
5f94711
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fZnNN5kHlpVa"
      },
      "outputs": [],
      "source": [
        "from google.colab import drive\n",
        "drive.mount('/content/drive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3FNd407Xmmmf"
      },
      "outputs": [],
      "source": [
        "%cd /content/files"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JUhd_t-7pDYY"
      },
      "outputs": [],
      "source": [
        "!pip install -U bitsandbytes\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "puAjgEBAlYN0"
      },
      "outputs": [],
      "source": [
        "# Install required packages\n",
        "!pip install torch transformers safetensors accelerate\n",
        "\n",
        "import torch\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "import json\n",
        "\n",
        "class ChatBot:\n",
        "    \"\"\"Minimal wrapper around a local Hugging Face causal language model.\"\"\"\n",
        "\n",
        "    def __init__(self, model_path):\n",
        "        \"\"\"Load the tokenizer and model from `model_path`.\n",
        "\n",
        "        Args:\n",
        "            model_path: Directory (or hub id) containing the tokenizer and model files.\n",
        "        \"\"\"\n",
        "        self.tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
        "        # device_map=\"auto\" lets accelerate place the weights on the available device(s);\n",
        "        # float16 halves weight memory compared with float32.\n",
        "        self.model = AutoModelForCausalLM.from_pretrained(\n",
        "            model_path,\n",
        "            device_map=\"auto\",\n",
        "            torch_dtype=torch.float16\n",
        "        )\n",
        "\n",
        "    def chat(self, user_input, max_length=1000):\n",
        "        \"\"\"Generate a sampled reply to `user_input`.\n",
        "\n",
        "        Args:\n",
        "            user_input: Raw prompt text; it is tokenized as-is (no chat template is applied).\n",
        "            max_length: Cap on the TOTAL sequence length (prompt tokens + reply tokens),\n",
        "                forwarded unchanged to `generate`.\n",
        "\n",
        "        Returns:\n",
        "            Only the newly generated text (the prompt is not echoed back), with special\n",
        "            tokens removed and surrounding whitespace stripped.\n",
        "        \"\"\"\n",
        "        inputs = self.tokenizer(user_input, return_tensors=\"pt\").to(self.model.device)\n",
        "        prompt_len = inputs[\"input_ids\"].shape[-1]\n",
        "\n",
        "        # Generate response\n",
        "        with torch.no_grad():\n",
        "            outputs = self.model.generate(\n",
        "                **inputs,\n",
        "                max_length=max_length,\n",
        "                num_return_sequences=1,\n",
        "                temperature=0.7,\n",
        "                do_sample=True,\n",
        "                pad_token_id=self.tokenizer.eos_token_id\n",
        "            )\n",
        "\n",
        "        # For a causal LM, generate() returns prompt + continuation in one sequence;\n",
        "        # slice off the prompt so the reply does not repeat the user's message.\n",
        "        new_tokens = outputs[0][prompt_len:]\n",
        "        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()\n",
        "\n",
        "# Simple chat interface\n",
        "def start_chat():\n",
        "    \"\"\"Run an interactive prompt/reply loop until the user types 'quit'.\"\"\"\n",
        "    print(\"Initializing chatbot... This may take a few minutes.\")\n",
        "    chatbot = ChatBot(\"/content/files\")  # Use the directory containing model files\n",
        "\n",
        "    print(\"\\nChat initialized! Type 'quit' to exit.\")\n",
        "    while True:\n",
        "        try:\n",
        "            user_input = input(\"\\nYou: \")\n",
        "        except (EOFError, KeyboardInterrupt):\n",
        "            # End of input or an interrupt ends the session cleanly instead of raising.\n",
        "            break\n",
        "        if user_input.strip().lower() == 'quit':\n",
        "            break\n",
        "        if not user_input.strip():\n",
        "            # Don't spend a generate() call on a blank prompt.\n",
        "            continue\n",
        "\n",
        "        response = chatbot.chat(user_input)\n",
        "        print(f\"\\nBot: {response}\")\n",
        "\n",
        "# Start the chat\n",
        "if __name__ == \"__main__\":\n",
        "    start_chat()"
      ]
    }
  ]
}