Thomas (Tom) Gardos committed on
Commit 4bdb9ef
2 Parent(s): 81a188d 8e1bf6f

Merge pull request #90 from DL4DS/dev2main

Files changed (48)
  1. .flake8 +3 -0
  2. .gitattributes +1 -0
  3. .github/workflows/code_quality_check.yml +33 -0
  4. .gitignore +3 -1
  5. Dockerfile +8 -1
  6. README.md +18 -39
  7. code/.chainlit/config.toml +8 -6
  8. code/__init__.py +0 -1
  9. code/app.py +351 -0
  10. code/chainlit.md +1 -6
  11. code/chainlit_base.py +484 -0
  12. code/main.py +212 -67
  13. code/modules/chat/chat_model_loader.py +2 -9
  14. code/modules/chat/helpers.py +6 -4
  15. code/modules/chat/langchain/__init__.py +0 -0
  16. code/modules/chat/langchain/langchain_rag.py +16 -12
  17. code/modules/chat/langchain/utils.py +12 -34
  18. code/modules/chat/llm_tutor.py +10 -7
  19. code/modules/chat_processor/helpers.py +245 -0
  20. code/modules/chat_processor/literal_ai.py +1 -38
  21. code/modules/config/config.yml +3 -3
  22. code/modules/config/constants.py +14 -3
  23. code/modules/config/project_config.yml +7 -0
  24. code/modules/dataloader/data_loader.py +96 -55
  25. code/modules/dataloader/helpers.py +13 -6
  26. code/modules/dataloader/pdf_readers/gpt.py +27 -19
  27. code/modules/dataloader/pdf_readers/llama.py +24 -23
  28. code/modules/dataloader/webpage_crawler.py +5 -3
  29. code/modules/retriever/helpers.py +0 -1
  30. code/modules/vectorstore/colbert.py +3 -2
  31. code/modules/vectorstore/embedding_model_loader.py +1 -7
  32. code/modules/vectorstore/faiss.py +10 -7
  33. code/modules/vectorstore/raptor.py +1 -4
  34. code/modules/vectorstore/store_manager.py +21 -14
  35. code/public/avatars/{ai-tutor.png → ai_tutor.png} +0 -0
  36. code/public/space.jpg +3 -0
  37. code/public/test.css +0 -19
  38. code/templates/cooldown.html +181 -0
  39. code/templates/dashboard.html +145 -0
  40. code/templates/error.html +95 -0
  41. code/templates/error_404.html +80 -0
  42. code/templates/login.html +132 -0
  43. code/templates/logout.html +21 -0
  44. docs/README.md +0 -51
  45. docs/contribute.md +33 -0
  46. docs/setup.md +127 -0
  47. pyproject.toml +2 -0
  48. requirements.txt +12 -1
.flake8 ADDED
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 88
+extend-ignore = E203, E266, E501, W503
.gitattributes ADDED
@@ -0,0 +1 @@
+*.jpg filter=lfs diff=lfs merge=lfs -text
.github/workflows/code_quality_check.yml ADDED
@@ -0,0 +1,33 @@
+name: Code Quality and Security Checks
+
+on:
+  push:
+    branches: [ main, dev_branch ]
+  pull_request:
+    branches: [ main, dev_branch ]
+
+jobs:
+  code-quality:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install flake8 black bandit
+
+      - name: Run Black
+        run: black --check .
+
+      - name: Run Flake8
+        run: flake8 .
+
+      - name: Run Bandit
+        run: |
+          bandit -r .
.gitignore CHANGED
@@ -165,7 +165,9 @@ cython_debug/
 .ragatouille/*
 */__pycache__/*
 .chainlit/translations/
+code/.chainlit/translations/
 storage/logs/*
 vectorstores/*
 
 */.files/*
+code/storage/models/
Dockerfile CHANGED
@@ -26,6 +26,13 @@ WORKDIR /code/code
 
 RUN --mount=type=secret,id=HUGGINGFACEHUB_API_TOKEN,mode=0444,required=true
 RUN --mount=type=secret,id=OPENAI_API_KEY,mode=0444,required=true
+RUN --mount=type=secret,id=CHAINLIT_URL,mode=0444,required=true
+RUN --mount=type=secret,id=LITERAL_API_URL,mode=0444,required=true
+RUN --mount=type=secret,id=LLAMA_CLOUD_API_KEY,mode=0444,required=true
+RUN --mount=type=secret,id=OAUTH_GOOGLE_CLIENT_ID,mode=0444,required=true
+RUN --mount=type=secret,id=OAUTH_GOOGLE_CLIENT_SECRET,mode=0444,required=true
+RUN --mount=type=secret,id=LITERAL_API_KEY_LOGGING,mode=0444,required=true
+RUN --mount=type=secret,id=CHAINLIT_AUTH_SECRET,mode=0444,required=true
 
 # Default command to run the application
-CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && chainlit run main.py --host 0.0.0.0 --port 7860"]
+CMD ["sh", "-c", "python -m modules.vectorstore.store_manager && uvicorn app:app --host 0.0.0.0 --port 7860"]
README.md CHANGED
@@ -15,10 +15,14 @@
 Hugging Face [Space](https://huggingface.co/spaces/dl4ds/dl4ds_tutor). It is pushed automatically from the `main` branch of this repo by this
 [Actions Workflow](https://github.com/DL4DS/dl4ds_tutor/blob/main/.github/workflows/push_to_hf_space.yml) upon a push to `main`.
 
-A "development" version of the Tutor is running live at [DL4DS Tutor -- Dev](https://dl4ds-tutor-dev.hf.space) from this Hugging Face
+
+A "development" version of the Tutor is running live at [DL4DS Tutor -- Dev](https://dl4ds-tutor-dev.hf.space/) from this Hugging Face
 [Space](https://huggingface.co/spaces/dl4ds/tutor_dev). It is pushed automatically from the `dev_branch` branch of this repo by this
 [Actions Workflow](https://github.com/DL4DS/dl4ds_tutor/blob/dev_branch/.github/workflows/push_to_hf_space_prototype.yml) upon a push to `dev_branch`.
 
+## Setup
+
+Please visit [setup](https://dl4ds.github.io/dl4ds_tutor/guide/setup/) for more information on setting up the project.
+
 ## Running Locally
@@ -34,7 +38,7 @@
 3. **To test Data Loading (Optional)**
    ```bash
    cd code
-   python -m modules.dataloader.data_loader
+   python -m modules.dataloader.data_loader --links "your_pdf_link"
    ```
 
 4. **Create the Vector Database**
@@ -43,47 +47,16 @@
    python -m modules.vectorstore.store_manager
    ```
    - Note: You need to run the above command when you add new data to the `storage/data` directory, or if the `storage/data/urls.txt` file is updated.
-   - Alternatively, you can set `["vectorstore"]["embedd_files"]` to `True` in the `code/modules/config/config.yaml` file, which will embed files from the storage directory every time you run the below chainlit command.
 
-5. **Run the Chainlit App**
+6. **Run the FastAPI App**
    ```bash
-   chainlit run main.py
+   cd code
+   uvicorn app:app --port 7860
    ```
 
-See the [docs](https://github.com/DL4DS/dl4ds_tutor/tree/main/docs) for more information.
-
-## File Structure
-
-```plaintext
-code/
-├── modules
-│   ├── chat            # Contains the chatbot implementation
-│   ├── chat_processor  # Contains the implementation to process and log the conversations
-│   ├── config          # Contains the configuration files
-│   ├── dataloader      # Contains the implementation to load the data from the storage directory
-│   ├── retriever       # Contains the implementation to create the retriever
-│   └── vectorstore     # Contains the implementation to create the vector database
-├── public
-│   ├── logo_dark.png   # Dark theme logo
-│   ├── logo_light.png  # Light theme logo
-│   └── test.css        # Custom CSS file
-└── main.py
-
-
-docs/           # Contains the documentation to the codebase and methods used
-
-storage/
-├── data        # Store files and URLs here
-├── logs        # Logs directory, includes logs on vector DB creation, tutor logs, and chunks logged in JSON files
-└── models      # Local LLMs are loaded from here
-
-vectorstores/   # Stores the created vector databases
-
-.env            # This needs to be created, store the API keys here
-```
-- `code/modules/vectorstore/vectorstore.py`: Instantiates the `VectorStore` class to create the vector database.
-- `code/modules/vectorstore/store_manager.py`: Instantiates the `VectorStoreManager` class to manage the vector database, and all associated methods.
-- `code/modules/retriever/retriever.py`: Instantiates the `Retriever` class to create the retriever.
+## Documentation
+
+Please visit the [docs](https://dl4ds.github.io/dl4ds_tutor/) for more information.
 
 
 ## Docker
@@ -97,4 +70,10 @@ docker run -it --rm -p 8000:8000 dev
 
 ## Contributing
 
-Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the main branch.
+Please create an issue if you have any suggestions or improvements, and start working on it by creating a branch and by making a pull request to the `dev_branch`.
+
+Please visit [contribute](https://dl4ds.github.io/dl4ds_tutor/guide/contribute/) for more information on contributing.
+
+## Future Work
+
+For more information on future work, please visit [roadmap](https://dl4ds.github.io/dl4ds_tutor/guide/readmap/).
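
The run command changes from `chainlit run main.py` to `uvicorn app:app` because the Chainlit tutor is now mounted inside a FastAPI application (see the `code/app.py` diff below). A minimal sketch of that entry point:

```python
# Minimal sketch of the new entry point; mirrors what code/app.py
# (added later in this diff) actually does.
from fastapi import FastAPI
from chainlit.utils import mount_chainlit

app = FastAPI()
# Serve the Chainlit app (main.py) under a sub-path of the FastAPI app.
mount_chainlit(app=app, target="main.py", path="/chainlit_tutor")
# Run with: uvicorn app:app --port 7860
```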
code/.chainlit/config.toml CHANGED
@@ -20,7 +20,7 @@ allow_origins = ["*"]
 
 [features]
 # Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
-unsafe_allow_html = false
+unsafe_allow_html = true
 
 # Process and display mathematical expressions. This can clash with "$" characters in messages.
 latex = true
@@ -49,6 +49,8 @@ auto_tag_thread = true
 # Sample rate of the audio
 sample_rate = 44100
 
+edit_message = true
+
 [UI]
 # Name of the assistant.
 name = "AI Tutor"
@@ -59,11 +61,11 @@ name = "AI Tutor"
 # Large size content are by default collapsed for a cleaner ui
 default_collapse_content = true
 
-# Hide the chain of thought details from the user in the UI.
-hide_cot = true
+# Chain of Thought (CoT) display mode. Can be "hidden", "tool_call" or "full".
+cot = "hidden"
 
 # Link to your github repo. This will add a github button in the UI's header.
-# github = "https://github.com/DL4DS/dl4ds_tutor"
+github = "https://github.com/DL4DS/dl4ds_tutor"
 
 # Specify a CSS file that can be used to customize the user interface.
 # The CSS file can be served from the public directory or via an external link.
@@ -85,7 +87,7 @@
 # custom_build = "./public/build"
 
 [UI.theme]
-default = "dark"
+default = "light"
 #layout = "wide"
 #font_family = "Inter, sans-serif"
 # Override default MUI light theme. (Check theme.ts)
@@ -115,4 +117,4 @@
 #secondary = "#BDBDBD"
 
 [meta]
-generated_by = "1.1.304"
+generated_by = "1.1.402"
code/__init__.py DELETED
@@ -1 +0,0 @@
-from .modules import *
code/app.py ADDED
@@ -0,0 +1,351 @@
+from fastapi import FastAPI, Request, Response, HTTPException
+from fastapi.responses import HTMLResponse, RedirectResponse
+from fastapi.templating import Jinja2Templates
+from google.oauth2 import id_token
+from google.auth.transport import requests as google_requests
+from google_auth_oauthlib.flow import Flow
+from chainlit.utils import mount_chainlit
+import secrets
+import json
+import base64
+from modules.config.constants import (
+    OAUTH_GOOGLE_CLIENT_ID,
+    OAUTH_GOOGLE_CLIENT_SECRET,
+    CHAINLIT_URL,
+    GITHUB_REPO,
+    DOCS_WEBSITE,
+    ALL_TIME_TOKENS_ALLOCATED,
+    TOKENS_LEFT,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from modules.chat_processor.helpers import (
+    get_user_details,
+    get_time,
+    reset_tokens_for_user,
+    check_user_cooldown,
+    update_user_info,
+)
+
+GOOGLE_CLIENT_ID = OAUTH_GOOGLE_CLIENT_ID
+GOOGLE_CLIENT_SECRET = OAUTH_GOOGLE_CLIENT_SECRET
+GOOGLE_REDIRECT_URI = f"{CHAINLIT_URL}/auth/oauth/google/callback"
+
+app = FastAPI()
+app.mount("/public", StaticFiles(directory="public"), name="public")
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Update with appropriate origins
+    allow_methods=["*"],
+    allow_headers=["*"],  # or specify the headers you want to allow
+    expose_headers=["X-User-Info"],  # Expose the custom header
+)
+
+templates = Jinja2Templates(directory="templates")
+session_store = {}
+CHAINLIT_PATH = "/chainlit_tutor"
+
+# only admin is given any additional permissions for now -- no limits on tokens
+USER_ROLES = {
+    "[email protected]": ["instructor", "bu"],
+    "[email protected]": ["admin", "instructor", "bu"],
+    "[email protected]": ["instructor", "bu"],
+    "[email protected]": ["guest"],
+    # Add more users and roles as needed
+}
+
+# Create a Google OAuth flow
+flow = Flow.from_client_config(
+    {
+        "web": {
+            "client_id": GOOGLE_CLIENT_ID,
+            "client_secret": GOOGLE_CLIENT_SECRET,
+            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
+            "token_uri": "https://oauth2.googleapis.com/token",
+            "redirect_uris": [GOOGLE_REDIRECT_URI],
+            "scopes": [
+                "openid",
+                # "https://www.googleapis.com/auth/userinfo.email",
+                # "https://www.googleapis.com/auth/userinfo.profile",
+            ],
+        }
+    },
+    scopes=[
+        "openid",
+        "https://www.googleapis.com/auth/userinfo.email",
+        "https://www.googleapis.com/auth/userinfo.profile",
+    ],
+    redirect_uri=GOOGLE_REDIRECT_URI,
+)
+
+
+def get_user_role(username: str):
+    return USER_ROLES.get(username, ["guest"])  # Default to "guest" role
+
+
+async def get_user_info_from_cookie(request: Request):
+    user_info_encoded = request.cookies.get("X-User-Info")
+    if user_info_encoded:
+        try:
+            user_info_json = base64.b64decode(user_info_encoded).decode()
+            return json.loads(user_info_json)
+        except Exception as e:
+            print(f"Error decoding user info: {e}")
+            return None
+    return None
+
+
+async def del_user_info_from_cookie(request: Request, response: Response):
+    # Delete cookies from the response
+    response.delete_cookie("X-User-Info")
+    response.delete_cookie("session_token")
+    # Get the session token from the request cookies
+    session_token = request.cookies.get("session_token")
+    # Check if the session token exists in the session_store before deleting
+    if session_token and session_token in session_store:
+        del session_store[session_token]
+
+
+def get_user_info(request: Request):
+    session_token = request.cookies.get("session_token")
+    if session_token and session_token in session_store:
+        return session_store[session_token]
+    return None
+
+
+@app.get("/", response_class=HTMLResponse)
+async def login_page(request: Request):
+    user_info = await get_user_info_from_cookie(request)
+    if user_info and user_info.get("google_signed_in"):
+        return RedirectResponse("/post-signin")
+    return templates.TemplateResponse(
+        "login.html",
+        {"request": request, "GITHUB_REPO": GITHUB_REPO, "DOCS_WEBSITE": DOCS_WEBSITE},
+    )
+
+
+# @app.get("/login/guest")
+# async def login_guest():
+#     username = "guest"
+#     session_token = secrets.token_hex(16)
+#     unique_session_id = secrets.token_hex(8)
+#     username = f"{username}_{unique_session_id}"
+#     session_store[session_token] = {
+#         "email": username,
+#         "name": "Guest",
+#         "profile_image": "",
+#         "google_signed_in": False,  # Ensure guest users do not have this flag
+#     }
+#     user_info_json = json.dumps(session_store[session_token])
+#     user_info_encoded = base64.b64encode(user_info_json.encode()).decode()
+
+#     # Set cookies
+#     response = RedirectResponse(url="/post-signin", status_code=303)
+#     response.set_cookie(key="session_token", value=session_token)
+#     response.set_cookie(key="X-User-Info", value=user_info_encoded, httponly=True)
+#     return response
+
+
+@app.get("/login/google")
+async def login_google(request: Request):
+    # Clear any existing session cookies to avoid conflicts with guest sessions
+    response = RedirectResponse(url="/post-signin")
+    response.delete_cookie(key="session_token")
+    response.delete_cookie(key="X-User-Info")
+
+    user_info = await get_user_info_from_cookie(request)
+    # Check if user is already signed in using Google
+    if user_info and user_info.get("google_signed_in"):
+        return RedirectResponse("/post-signin")
+    else:
+        authorization_url, _ = flow.authorization_url(prompt="consent")
+        return RedirectResponse(authorization_url, headers=response.headers)
+
+
+@app.get("/auth/oauth/google/callback")
+async def auth_google(request: Request):
+    try:
+        flow.fetch_token(code=request.query_params.get("code"))
+        credentials = flow.credentials
+        user_info = id_token.verify_oauth2_token(
+            credentials.id_token, google_requests.Request(), GOOGLE_CLIENT_ID
+        )
+
+        email = user_info["email"]
+        name = user_info.get("name", "")
+        profile_image = user_info.get("picture", "")
+        role = get_user_role(email)
+
+        session_token = secrets.token_hex(16)
+        session_store[session_token] = {
+            "email": email,
+            "name": name,
+            "profile_image": profile_image,
+            "google_signed_in": True,  # Set this flag to True for Google-signed users
+        }
+
+        # add literalai user info to session store to be sent to chainlit
+        literalai_user = await get_user_details(email)
+        session_store[session_token]["literalai_info"] = literalai_user.to_dict()
+        session_store[session_token]["literalai_info"]["metadata"]["role"] = role
+
+        user_info_json = json.dumps(session_store[session_token])
+        user_info_encoded = base64.b64encode(user_info_json.encode()).decode()
+
+        # Set cookies
+        response = RedirectResponse(url="/post-signin", status_code=303)
+        response.set_cookie(key="session_token", value=session_token)
+        response.set_cookie(
+            key="X-User-Info", value=user_info_encoded, httponly=True
+        )  # TODO: is the flag httponly=True necessary?
+        return response
+    except Exception as e:
+        print(f"Error during Google OAuth callback: {e}")
+        return RedirectResponse(url="/", status_code=302)
+
+
+@app.get("/cooldown")
+async def cooldown(request: Request):
+    user_info = await get_user_info_from_cookie(request)
+    user_details = await get_user_details(user_info["email"])
+    current_datetime = get_time()
+    cooldown, cooldown_end_time = await check_user_cooldown(
+        user_details, current_datetime
+    )
+    print(f"User in cooldown: {cooldown}")
+    print(f"Cooldown end time: {cooldown_end_time}")
+    if cooldown and "admin" not in get_user_role(user_info["email"]):
+        return templates.TemplateResponse(
+            "cooldown.html",
+            {
+                "request": request,
+                "username": user_info["email"],
+                "role": get_user_role(user_info["email"]),
+                "cooldown_end_time": cooldown_end_time,
+                "tokens_left": user_details.metadata["tokens_left"],
+            },
+        )
+    else:
+        user_details.metadata["in_cooldown"] = False
+        await update_user_info(user_details)
+        await reset_tokens_for_user(user_details)
+        return RedirectResponse("/post-signin")
+
+
+@app.get("/post-signin", response_class=HTMLResponse)
+async def post_signin(request: Request):
+    user_info = await get_user_info_from_cookie(request)
+    if not user_info:
+        user_info = get_user_info(request)
+    user_details = await get_user_details(user_info["email"])
+    current_datetime = get_time()
+    user_details.metadata["last_login"] = current_datetime
+    # if new user, set the number of tries
+    if "tokens_left" not in user_details.metadata:
+        user_details.metadata["tokens_left"] = (
+            TOKENS_LEFT  # set the number of tokens left for the new user
+        )
+    if "last_message_time" not in user_details.metadata:
+        user_details.metadata["last_message_time"] = current_datetime
+    if "all_time_tokens_allocated" not in user_details.metadata:
+        user_details.metadata["all_time_tokens_allocated"] = ALL_TIME_TOKENS_ALLOCATED
+    if "in_cooldown" not in user_details.metadata:
+        user_details.metadata["in_cooldown"] = False
+    await update_user_info(user_details)
+
+    if "last_message_time" in user_details.metadata and "admin" not in get_user_role(
+        user_info["email"]
+    ):
+        cooldown, _ = await check_user_cooldown(user_details, current_datetime)
+        if cooldown:
+            user_details.metadata["in_cooldown"] = True
+            return RedirectResponse("/cooldown")
+        else:
+            user_details.metadata["in_cooldown"] = False
+            await reset_tokens_for_user(user_details)
+
+    if user_info:
+        username = user_info["email"]
+        role = get_user_role(username)
+        jwt_token = request.cookies.get("X-User-Info")
+        return templates.TemplateResponse(
+            "dashboard.html",
+            {
+                "request": request,
+                "username": username,
+                "role": role,
+                "jwt_token": jwt_token,
+                "tokens_left": user_details.metadata["tokens_left"],
+                "all_time_tokens_allocated": user_details.metadata[
+                    "all_time_tokens_allocated"
+                ],
+                "total_tokens_allocated": ALL_TIME_TOKENS_ALLOCATED,
+            },
+        )
+    return RedirectResponse("/")
+
+
+@app.get("/start-tutor")
+@app.post("/start-tutor")
+async def start_tutor(request: Request):
+    user_info = await get_user_info_from_cookie(request)
+    if user_info:
+        user_info_json = json.dumps(user_info)
+        user_info_encoded = base64.b64encode(user_info_json.encode()).decode()
+
+        response = RedirectResponse(CHAINLIT_PATH, status_code=303)
+        response.set_cookie(key="X-User-Info", value=user_info_encoded, httponly=True)
+        return response
+
+    return RedirectResponse(url="/")
+
+
+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+    if exc.status_code == 404:
+        return templates.TemplateResponse(
+            "error_404.html", {"request": request}, status_code=404
+        )
+    return templates.TemplateResponse(
+        "error.html",
+        {"request": request, "error": str(exc)},
+        status_code=exc.status_code,
+    )
+
+
+@app.exception_handler(Exception)
+async def exception_handler(request: Request, exc: Exception):
+    return templates.TemplateResponse(
+        "error.html", {"request": request, "error": str(exc)}, status_code=500
+    )
+
+
+@app.get("/logout", response_class=HTMLResponse)
+async def logout(request: Request, response: Response):
+    await del_user_info_from_cookie(request=request, response=response)
+    response = RedirectResponse(url="/", status_code=302)
+    # Set cookies to empty values and expire them immediately
+    response.set_cookie(key="session_token", value="", expires=0)
+    response.set_cookie(key="X-User-Info", value="", expires=0)
+    return response
+
+
+@app.get("/get-tokens-left")
+async def get_tokens_left(request: Request):
+    try:
+        user_info = await get_user_info_from_cookie(request)
+        user_details = await get_user_details(user_info["email"])
+        await reset_tokens_for_user(user_details)
+        tokens_left = user_details.metadata["tokens_left"]
+        return {"tokens_left": tokens_left}
+    except Exception as e:
+        print(f"Error getting tokens left: {e}")
+        return {"tokens_left": 0}
+
+
+mount_chainlit(app=app, target="main.py", path=CHAINLIT_PATH)
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="127.0.0.1", port=7860)
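
Session identity travels between the FastAPI shell and the mounted Chainlit app as a base64-encoded JSON cookie (`X-User-Info`). A self-contained sketch of that round-trip, matching `get_user_info_from_cookie` above (the email is a placeholder):

```python
# Sketch of the X-User-Info cookie round-trip. Note that base64 is
# encoding, not encryption: the payload is readable by anyone who holds
# the cookie.
import base64
import json

user_info = {"email": "placeholder@example.com", "google_signed_in": True}

# Encode on the way out (what auth_google does when setting the cookie).
encoded = base64.b64encode(json.dumps(user_info).encode()).decode()
# Decode on the way back in (what get_user_info_from_cookie does).
decoded = json.loads(base64.b64decode(encoded).decode())
assert decoded == user_info
```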
code/chainlit.md CHANGED
@@ -1,10 +1,5 @@
 # Welcome to DL4DS Tutor! 🚀🤖
 
-Hi there, this is an LLM chatbot designed to help answer questions on the course content, built using Langchain and Chainlit.
-This is still very much a Work in Progress.
+Hi there, this is an LLM chatbot designed to help answer questions on the course content.
 
 ### --- Please wait while the Tutor loads... ---
-
-## Useful Links 🔗
-
-- **Documentation:** [Chainlit Documentation](https://docs.chainlit.io) 📚
code/chainlit_base.py ADDED
@@ -0,0 +1,484 @@
+import chainlit.data as cl_data
+import asyncio
+import yaml
+from typing import Any, Dict, no_type_check
+import chainlit as cl
+from modules.chat.llm_tutor import LLMTutor
+from modules.chat.helpers import (
+    get_sources,
+    get_history_chat_resume,
+    get_history_setup_llm,
+    get_last_config,
+)
+import copy
+from chainlit.types import ThreadDict
+import time
+from langchain_community.callbacks import get_openai_callback
+
+USER_TIMEOUT = 60_000
+SYSTEM = "System"
+LLM = "AI Tutor"
+AGENT = "Agent"
+YOU = "User"
+ERROR = "Error"
+
+with open("modules/config/config.yml", "r") as f:
+    config = yaml.safe_load(f)
+
+
+# async def setup_data_layer():
+#     """
+#     Set up the data layer for chat logging.
+#     """
+#     if config["chat_logging"]["log_chat"]:
+#         data_layer = CustomLiteralDataLayer(
+#             api_key=LITERAL_API_KEY_LOGGING, server=LITERAL_API_URL
+#         )
+#     else:
+#         data_layer = None
+
+#     return data_layer
+
+
+class Chatbot:
+    def __init__(self, config):
+        """
+        Initialize the Chatbot class.
+        """
+        self.config = config
+
+    async def _load_config(self):
+        """
+        Load the configuration from a YAML file.
+        """
+        with open("modules/config/config.yml", "r") as f:
+            return yaml.safe_load(f)
+
+    @no_type_check
+    async def setup_llm(self):
+        """
+        Set up the LLM with the provided settings. Update the configuration and initialize the LLM tutor.
+
+        #TODO: Clean this up.
+        """
+        start_time = time.time()
+
+        llm_settings = cl.user_session.get("llm_settings", {})
+        (
+            chat_profile,
+            retriever_method,
+            memory_window,
+            llm_style,
+            generate_follow_up,
+            chunking_mode,
+        ) = (
+            llm_settings.get("chat_model"),
+            llm_settings.get("retriever_method"),
+            llm_settings.get("memory_window"),
+            llm_settings.get("llm_style"),
+            llm_settings.get("follow_up_questions"),
+            llm_settings.get("chunking_mode"),
+        )
+
+        chain = cl.user_session.get("chain")
+        memory_list = cl.user_session.get(
+            "memory",
+            (
+                list(chain.store.values())[0].messages
+                if len(chain.store.values()) > 0
+                else []
+            ),
+        )
+        conversation_list = get_history_setup_llm(memory_list)
+
+        old_config = copy.deepcopy(self.config)
+        self.config["vectorstore"]["db_option"] = retriever_method
+        self.config["llm_params"]["memory_window"] = memory_window
+        self.config["llm_params"]["llm_style"] = llm_style
+        self.config["llm_params"]["llm_loader"] = chat_profile
+        self.config["llm_params"]["generate_follow_up"] = generate_follow_up
+        self.config["splitter_options"]["chunking_mode"] = chunking_mode
+
+        self.llm_tutor.update_llm(
+            old_config, self.config
+        )  # update only llm attributes that are changed
+        self.chain = self.llm_tutor.qa_bot(
+            memory=conversation_list,
+        )
+
+        cl.user_session.set("chain", self.chain)
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+
+        print("Time taken to setup LLM: ", time.time() - start_time)
+
+    @no_type_check
+    async def update_llm(self, new_settings: Dict[str, Any]):
+        """
+        Update the LLM settings and reinitialize the LLM with the new settings.
+
+        Args:
+            new_settings (Dict[str, Any]): The new settings to update.
+        """
+        cl.user_session.set("llm_settings", new_settings)
+        await self.inform_llm_settings()
+        await self.setup_llm()
+
+    async def make_llm_settings_widgets(self, config=None):
+        """
+        Create and send the widgets for LLM settings configuration.
+
+        Args:
+            config: The configuration to use for setting up the widgets.
+        """
+        config = config or self.config
+        await cl.ChatSettings(
+            [
+                cl.input_widget.Select(
+                    id="chat_model",
+                    label="Model Name (Default GPT-3)",
+                    values=["local_llm", "gpt-3.5-turbo-1106", "gpt-4", "gpt-4o-mini"],
+                    initial_index=[
+                        "local_llm",
+                        "gpt-3.5-turbo-1106",
+                        "gpt-4",
+                        "gpt-4o-mini",
+                    ].index(config["llm_params"]["llm_loader"]),
+                ),
+                cl.input_widget.Select(
+                    id="retriever_method",
+                    label="Retriever (Default FAISS)",
+                    values=["FAISS", "Chroma", "RAGatouille", "RAPTOR"],
+                    initial_index=["FAISS", "Chroma", "RAGatouille", "RAPTOR"].index(
+                        config["vectorstore"]["db_option"]
+                    ),
+                ),
+                cl.input_widget.Slider(
+                    id="memory_window",
+                    label="Memory Window (Default 3)",
+                    initial=3,
+                    min=0,
+                    max=10,
+                    step=1,
+                ),
+                cl.input_widget.Switch(
+                    id="view_sources", label="View Sources", initial=False
+                ),
+                cl.input_widget.Switch(
+                    id="stream_response",
+                    label="Stream response",
+                    initial=config["llm_params"]["stream"],
+                ),
+                cl.input_widget.Select(
+                    id="chunking_mode",
+                    label="Chunking mode",
+                    values=["fixed", "semantic"],
+                    initial_index=1,
+                ),
+                cl.input_widget.Switch(
+                    id="follow_up_questions",
+                    label="Generate follow up questions",
+                    initial=False,
+                ),
+                cl.input_widget.Select(
+                    id="llm_style",
+                    label="Type of Conversation (Default Normal)",
+                    values=["Normal", "ELI5"],
+                    initial_index=0,
+                ),
+            ]
+        ).send()
+
+    @no_type_check
+    async def inform_llm_settings(self):
+        """
+        Inform the user about the updated LLM settings and display them as a message.
+        """
+        llm_settings: Dict[str, Any] = cl.user_session.get("llm_settings", {})
+        llm_tutor = cl.user_session.get("llm_tutor")
+        settings_dict = {
+            "model": llm_settings.get("chat_model"),
+            "retriever": llm_settings.get("retriever_method"),
+            "memory_window": llm_settings.get("memory_window"),
+            "num_docs_in_db": (
+                len(llm_tutor.vector_db)
+                if llm_tutor and hasattr(llm_tutor, "vector_db")
+                else 0
+            ),
+            "view_sources": llm_settings.get("view_sources"),
+            "follow_up_questions": llm_settings.get("follow_up_questions"),
+        }
+        print("Settings Dict: ", settings_dict)
+        await cl.Message(
+            author=SYSTEM,
+            content="LLM settings have been updated. You can continue with your Query!",
+            # elements=[
+            #     cl.Text(
+            #         name="settings",
+            #         display="side",
+            #         content=json.dumps(settings_dict, indent=4),
+            #         language="json",
+            #     ),
+            # ],
+        ).send()
+
+    async def set_starters(self):
+        """
+        Set starter messages for the chatbot.
+        """
+        # Return Starters only if the chat is new
+
+        try:
+            thread = cl_data._data_layer.get_thread(
+                cl.context.session.thread_id
+            )  # see if the thread has any steps
+            if thread.steps or len(thread.steps) > 0:
+                return None
+        except Exception as e:
+            print(e)
+            return [
+                cl.Starter(
+                    label="recording on CNNs?",
+                    message="Where can I find the recording for the lecture on Transformers?",
+                    icon="/public/adv-screen-recorder-svgrepo-com.svg",
+                ),
+                cl.Starter(
+                    label="where's the slides?",
+                    message="When are the lectures? I can't find the schedule.",
+                    icon="/public/alarmy-svgrepo-com.svg",
+                ),
+                cl.Starter(
+                    label="Due Date?",
+                    message="When is the final project due?",
+                    icon="/public/calendar-samsung-17-svgrepo-com.svg",
+                ),
+                cl.Starter(
+                    label="Explain backprop.",
+                    message="I didn't understand the math behind backprop, could you explain it?",
+                    icon="/public/acastusphoton-svgrepo-com.svg",
+                ),
+            ]
+
+    def rename(self, orig_author: str):
+        """
+        Rename the original author to a more user-friendly name.
+
+        Args:
+            orig_author (str): The original author's name.
+
+        Returns:
+            str: The renamed author.
+        """
+        rename_dict = {"Chatbot": LLM}
+        return rename_dict.get(orig_author, orig_author)
+
+    async def start(self, config=None):
+        """
+        Start the chatbot, initialize settings widgets,
+        and display and load previous conversation if chat logging is enabled.
+        """
+
+        start_time = time.time()
+
+        self.config = (
+            await self._load_config() if config is None else config
+        )  # Reload the configuration on chat resume
+
+        await self.make_llm_settings_widgets(self.config)  # Reload the settings widgets
+
+        user = cl.user_session.get("user")
+
+        # TODO: remove self.user with cl.user_session.get("user")
+        try:
+            self.user = {
+                "user_id": user.identifier,
+                "session_id": cl.context.session.thread_id,
+            }
+        except Exception as e:
+            print(e)
+            self.user = {
+                "user_id": "guest",
+                "session_id": cl.context.session.thread_id,
+            }
+
+        memory = cl.user_session.get("memory", [])
+        self.llm_tutor = LLMTutor(self.config, user=self.user)
+
+        self.chain = self.llm_tutor.qa_bot(
+            memory=memory,
+        )
+        self.question_generator = self.llm_tutor.question_generator
+        cl.user_session.set("llm_tutor", self.llm_tutor)
+        cl.user_session.set("chain", self.chain)
+
+        print("Time taken to start LLM: ", time.time() - start_time)
+
+    async def stream_response(self, response):
+        """
+        Stream the response from the LLM.
+
+        Args:
+            response: The response from the LLM.
+        """
+        msg = cl.Message(content="")
+        await msg.send()
+
+        output = {}
+        for chunk in response:
+            if "answer" in chunk:
+                await msg.stream_token(chunk["answer"])
+
+            for key in chunk:
+                if key not in output:
+                    output[key] = chunk[key]
+                else:
+                    output[key] += chunk[key]
+        return output
+
+    async def main(self, message):
+        """
+        Process and Display the Conversation.
+
+        Args:
+            message: The incoming chat message.
+        """
+
+        start_time = time.time()
+
+        chain = cl.user_session.get("chain")
+        token_count = 0  # initialize token count
+        if not chain:
+            await self.start()  # start the chatbot if the chain is not present
+            chain = cl.user_session.get("chain")
+
+        # update user info with last message time
+        llm_settings = cl.user_session.get("llm_settings", {})
+        view_sources = llm_settings.get("view_sources", False)
+        stream = llm_settings.get("stream_response", False)
+        stream = False  # Fix streaming
+        user_query_dict = {"input": message.content}
+        # Define the base configuration
+        cb = cl.AsyncLangchainCallbackHandler()
+        chain_config = {
+            "configurable": {
+                "user_id": self.user["user_id"],
+                "conversation_id": self.user["session_id"],
+                "memory_window": self.config["llm_params"]["memory_window"],
+            },
+            "callbacks": (
+                [cb]
+                if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                else None
+            ),
+        }
+
+        with get_openai_callback() as token_count_cb:
+            if stream:
+                res = chain.stream(user_query=user_query_dict, config=chain_config)
+                res = await self.stream_response(res)
+            else:
+                res = await chain.invoke(
+                    user_query=user_query_dict,
+                    config=chain_config,
+                )
+        token_count += token_count_cb.total_tokens
+
+        answer = res.get("answer", res.get("result"))
+
+        answer_with_sources, source_elements, sources_dict = get_sources(
+            res, answer, stream=stream, view_sources=view_sources
+        )
+        answer_with_sources = answer_with_sources.replace("$$", "$")
+
+        print("Time taken to process the message: ", time.time() - start_time)
+
+        actions = []
+
+        if self.config["llm_params"]["generate_follow_up"]:
+            start_time = time.time()
+            cb_follow_up = cl.AsyncLangchainCallbackHandler()
+            config = {
+                "callbacks": (
+                    [cb_follow_up]
+                    if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                    else None
+                )
+            }
+            with get_openai_callback() as token_count_cb:
+                list_of_questions = await self.question_generator.generate_questions(
+                    query=user_query_dict["input"],
+                    response=answer,
+                    chat_history=res.get("chat_history"),
+                    context=res.get("context"),
+                    config=config,
+                )

+            token_count += token_count_cb.total_tokens
+
+            for question in list_of_questions:
+                actions.append(
+                    cl.Action(
+                        name="follow up question",
+                        value="example_value",
+                        description=question,
+                        label=question,
+                    )
+                )
+
+            print("Time taken to generate questions: ", time.time() - start_time)
+        print("Total Tokens Used: ", token_count)
+
+        await cl.Message(
+            content=answer_with_sources,
+            elements=source_elements,
+            author=LLM,
+            actions=actions,
+            metadata=self.config,
+        ).send()
+
+    async def on_chat_resume(self, thread: ThreadDict):
+        thread_config = None
+        steps = thread["steps"]
+        k = self.config["llm_params"][
+            "memory_window"
+        ]  # on resume, always use the default memory window
+        conversation_list = get_history_chat_resume(steps, k, SYSTEM, LLM)
+        thread_config = get_last_config(
+            steps
+        )  # TODO: Returns None for now - which causes config to be reloaded with default values
+        cl.user_session.set("memory", conversation_list)
+        await self.start(config=thread_config)
+
+    async def on_follow_up(self, action: cl.Action):
+        user = cl.user_session.get("user")
+        message = await cl.Message(
+            content=action.description,
+            type="user_message",
+            author=user.identifier,
+        ).send()
+        async with cl.Step(
+            name="on_follow_up", type="run", parent_id=message.id
+        ) as step:
+            await self.main(message)
+            step.output = message.content
+
+
+chatbot = Chatbot(config=config)
+
+
+async def start_app():
+    # cl_data._data_layer = await setup_data_layer()
+    # chatbot.literal_client = cl_data._data_layer.client if cl_data._data_layer else None
+    cl.set_starters(chatbot.set_starters)
+    cl.author_rename(chatbot.rename)
+    cl.on_chat_start(chatbot.start)
+    cl.on_chat_resume(chatbot.on_chat_resume)
+    cl.on_message(chatbot.main)
+    cl.on_settings_update(chatbot.update_llm)
+    cl.action_callback("follow up question")(chatbot.on_follow_up)
+
+
+loop = asyncio.get_event_loop()
+if loop.is_running():
+    asyncio.ensure_future(start_app())
+else:
+    asyncio.run(start_app())
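
Both `chainlit_base.py` and the updated `main.py` meter usage with LangChain's `get_openai_callback`, which tallies tokens for OpenAI calls made inside its context manager. A minimal usage sketch (the actual chain invocation is elided):

```python
# Minimal sketch of the token-metering pattern used in main() above.
from langchain_community.callbacks import get_openai_callback

token_count = 0
with get_openai_callback() as token_count_cb:
    pass  # invoke the chain / question generator here
# total_tokens remains readable after the context exits.
token_count += token_count_cb.total_tokens
print("Total Tokens Used: ", token_count)
```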
code/main.py CHANGED
@@ -1,15 +1,12 @@
 import chainlit.data as cl_data
 import asyncio
 from modules.config.constants import (
-    LLAMA_PATH,
     LITERAL_API_KEY_LOGGING,
     LITERAL_API_URL,
 )
 from modules.chat_processor.literal_ai import CustomLiteralDataLayer
-
 import json
 import yaml
-import os
 from typing import Any, Dict, no_type_check
 import chainlit as cl
 from modules.chat.llm_tutor import LLMTutor
@@ -19,17 +16,27 @@ from modules.chat.helpers import (
     get_history_setup_llm,
     get_last_config,
 )
+from modules.chat_processor.helpers import (
+    update_user_info,
+    get_time,
+    check_user_cooldown,
+    reset_tokens_for_user,
+    get_user_details,
+)
 import copy
 from typing import Optional
 from chainlit.types import ThreadDict
 import time
+import base64
+from langchain_community.callbacks import get_openai_callback
+from datetime import datetime, timezone
 
 USER_TIMEOUT = 60_000
-SYSTEM = "System 🖥️"
-LLM = "LLM 🧠"
-AGENT = "Agent <>"
-YOU = "You 😃"
-ERROR = "Error 🚫"
+SYSTEM = "System"
+LLM = "AI Tutor"
+AGENT = "Agent"
+YOU = "User"
+ERROR = "Error"
 
 with open("modules/config/config.yml", "r") as f:
     config = yaml.safe_load(f)
@@ -49,6 +56,24 @@ async def setup_data_layer():
     return data_layer
 
 
+async def update_user_from_chainlit(user, token_count=0):
+    if "admin" not in user.metadata["role"]:
+        user.metadata["tokens_left"] = user.metadata["tokens_left"] - token_count
+        user.metadata["all_time_tokens_allocated"] = (
+            user.metadata["all_time_tokens_allocated"] - token_count
+        )
+        user.metadata["tokens_left_at_last_message"] = user.metadata[
+            "tokens_left"
+        ]  # tokens_left will keep regenerating outside of chainlit
+    user.metadata["last_message_time"] = get_time()
+    await update_user_info(user)
+
+    tokens_left = user.metadata["tokens_left"]
+    if tokens_left < 0:
+        tokens_left = 0
+    return tokens_left
+
+
 class Chatbot:
     def __init__(self, config):
         """
@@ -73,7 +98,14 @@ class Chatbot:
         start_time = time.time()
 
         llm_settings = cl.user_session.get("llm_settings", {})
-        chat_profile, retriever_method, memory_window, llm_style, generate_follow_up, chunking_mode = (
+        (
+            chat_profile,
+            retriever_method,
+            memory_window,
+            llm_style,
+            generate_follow_up,
+            chunking_mode,
+        ) = (
             llm_settings.get("chat_model"),
             llm_settings.get("retriever_method"),
             llm_settings.get("memory_window"),
@@ -106,15 +138,8 @@ class Chatbot:
         )  # update only llm attributes that are changed
         self.chain = self.llm_tutor.qa_bot(
             memory=conversation_list,
-            callbacks=(
-                [cl.LangchainCallbackHandler()]
-                if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
-                else None
-            ),
         )
 
-        tags = [chat_profile, self.config["vectorstore"]["db_option"]]
-
         cl.user_session.set("chain", self.chain)
         cl.user_session.set("llm_tutor", self.llm_tutor)
 
@@ -180,7 +205,7 @@ class Chatbot:
             cl.input_widget.Select(
                 id="chunking_mode",
                 label="Chunking mode",
-                values=['fixed', 'semantic'],
+                values=["fixed", "semantic"],
                 initial_index=1,
             ),
             cl.input_widget.Switch(
@@ -216,17 +241,18 @@ class Chatbot:
             "view_sources": llm_settings.get("view_sources"),
             "follow_up_questions": llm_settings.get("follow_up_questions"),
         }
+        print("Settings Dict: ", settings_dict)
         await cl.Message(
             author=SYSTEM,
             content="LLM settings have been updated. You can continue with your Query!",
-            elements=[
-                cl.Text(
-                    name="settings",
-                    display="side",
-                    content=json.dumps(settings_dict, indent=4),
-                    language="json",
-                ),
-            ],
+            # elements=[
+            #     cl.Text(
+            #         name="settings",
+            #         display="side",
+            #         content=json.dumps(settings_dict, indent=4),
+            #         language="json",
+            #     ),
+            # ],
         ).send()
@@ -241,7 +267,8 @@ class Chatbot:
             )  # see if the thread has any steps
             if thread.steps or len(thread.steps) > 0:
                 return None
-        except:
+        except Exception as e:
+            print(e)
             return [
                 cl.Starter(
                     label="recording on CNNs?",
@@ -275,7 +302,7 @@ class Chatbot:
         Returns:
            str: The renamed author.
         """
-        rename_dict = {"Chatbot": "AI Tutor"}
+        rename_dict = {"Chatbot": LLM}
         return rename_dict.get(orig_author, orig_author)
 
     async def start(self, config=None):
@@ -292,25 +319,26 @@ class Chatbot:
 
         await self.make_llm_settings_widgets(self.config)  # Reload the settings widgets
 
-        await self.make_llm_settings_widgets(self.config)
         user = cl.user_session.get("user")
-        self.user = {
-            "user_id": user.identifier,
-            "session_id": cl.context.session.thread_id,
-        }
 
-        memory = cl.user_session.get("memory", [])
+        # TODO: remove self.user with cl.user_session.get("user")
+        try:
+            self.user = {
+                "user_id": user.identifier,
+                "session_id": cl.context.session.thread_id,
+            }
+        except Exception as e:
+            print(e)
+            self.user = {
+                "user_id": "guest",
+                "session_id": cl.context.session.thread_id,
+            }
 
-        cl.user_session.set("user", self.user)
+        memory = cl.user_session.get("memory", [])
         self.llm_tutor = LLMTutor(self.config, user=self.user)
 
         self.chain = self.llm_tutor.qa_bot(
             memory=memory,
-            callbacks=(
-                [cl.LangchainCallbackHandler()]
-                if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
-                else None
-            ),
         )
         self.question_generator = self.llm_tutor.question_generator
         cl.user_session.set("llm_tutor", self.llm_tutor)
@@ -351,29 +379,98 @@ class Chatbot:
         start_time = time.time()
 
         chain = cl.user_session.get("chain")
+        token_count = 0  # initialize token count
+        if not chain:
+            await self.start()  # start the chatbot if the chain is not present
+            chain = cl.user_session.get("chain")
+
+        # update user info with last message time
+        user = cl.user_session.get("user")
+        await reset_tokens_for_user(user)
+        updated_user = await get_user_details(user.identifier)
+        user.metadata = updated_user.metadata
+        cl.user_session.set("user", user)
+
+        print("\n\n User Tokens Left: ", user.metadata["tokens_left"])
+
+        # see if user has token credits left
+        # if not, return message saying they have run out of tokens
+        if user.metadata["tokens_left"] <= 0 and "admin" not in user.metadata["role"]:
+            current_datetime = get_time()
+            cooldown, cooldown_end_time = await check_user_cooldown(
+                user, current_datetime
+            )
+            if cooldown:
+                # get time left in cooldown
+                # convert both to datetime objects
+                cooldown_end_time = datetime.fromisoformat(cooldown_end_time).replace(
+                    tzinfo=timezone.utc
+                )
+                current_datetime = datetime.fromisoformat(current_datetime).replace(
+                    tzinfo=timezone.utc
+                )
+                cooldown_time_left = cooldown_end_time - current_datetime
+                # Get the total seconds
+                total_seconds = int(cooldown_time_left.total_seconds())
+                # Calculate hours, minutes, and seconds
+                hours, remainder = divmod(total_seconds, 3600)
+                minutes, seconds = divmod(remainder, 60)
+                # Format the time as 00 hrs 00 mins 00 secs
+                formatted_time = f"{hours:02} hrs {minutes:02} mins {seconds:02} secs"
+                await cl.Message(
+                    content=(
+                        "Ah, seems like you have run out of tokens...Click "
+                        '<a href="/cooldown" style="color: #0000CD; text-decoration: none;" target="_self">here</a> for more info. Please come back after {}'.format(
+                            formatted_time
+                        )
+                    ),
+                    author=SYSTEM,
+                ).send()
+                user.metadata["in_cooldown"] = True
+                await update_user_info(user)
+                return
+            else:
+                await cl.Message(
+                    content=(
+                        "Ah, seems like you don't have any tokens left...Please wait while we regenerate your tokens. Click "
+                        '<a href="/cooldown" style="color: #0000CD; text-decoration: none;" target="_self">here</a> to view your token credits.'
+                    ),
+                    author=SYSTEM,
+                ).send()
+                return
+
+        user.metadata["in_cooldown"] = False
 
         llm_settings = cl.user_session.get("llm_settings", {})
         view_sources = llm_settings.get("view_sources", False)
         stream = llm_settings.get("stream_response", False)
-        steam = False  # Fix streaming
+        stream = False  # Fix streaming
         user_query_dict = {"input": message.content}
         # Define the base configuration
+        cb = cl.AsyncLangchainCallbackHandler()
         chain_config = {
             "configurable": {
                 "user_id": self.user["user_id"],
                 "conversation_id": self.user["session_id"],
                 "memory_window": self.config["llm_params"]["memory_window"],
-            }
+            },
+            "callbacks": (
+                [cb]
+                if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                else None
+            ),
         }
 
-        if stream:
-            res = chain.stream(user_query=user_query_dict, config=chain_config)
-            res = await self.stream_response(res)
-        else:
-            res = await chain.invoke(
-                user_query=user_query_dict,
-                config=chain_config,
-            )
+        with get_openai_callback() as token_count_cb:
+            if stream:
+                res = chain.stream(user_query=user_query_dict, config=chain_config)
+                res = await self.stream_response(res)
+            else:
+                res = await chain.invoke(
+                    user_query=user_query_dict,
+                    config=chain_config,
+                )
+        token_count += token_count_cb.total_tokens
 
         answer = res.get("answer", res.get("result"))
 
@@ -388,15 +485,26 @@ class Chatbot:
 
         if self.config["llm_params"]["generate_follow_up"]:
             start_time = time.time()
-            list_of_questions = self.question_generator.generate_questions(
-                query=user_query_dict["input"],
-                response=answer,
-                chat_history=res.get("chat_history"),
-                context=res.get("context"),
-            )
+            cb_follow_up = cl.AsyncLangchainCallbackHandler()
+            config = {
+                "callbacks": (
+                    [cb_follow_up]
+                    if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
+                    else None
+                )
+            }
+            with get_openai_callback() as token_count_cb:
+                list_of_questions = await self.question_generator.generate_questions(
+                    query=user_query_dict["input"],
+                    response=answer,
+                    chat_history=res.get("chat_history"),
+                    context=res.get("context"),
+                    config=config,
+                )
 
-            for question in list_of_questions:
+            token_count += token_count_cb.total_tokens
+
+            for question in list_of_questions:
                 actions.append(
                     cl.Action(
                         name="follow up question",
@@ -408,6 +516,15 @@ class Chatbot:
 
             print("Time taken to generate questions: ", time.time() - start_time)
 
+        # # update user info with token count
+        tokens_left = await update_user_from_chainlit(user, token_count)
+
+        answer_with_sources += (
+            '\n\n<footer><span style="font-size: 0.8em; text-align: right; display: block;">Tokens Left: '
+            + str(tokens_left)
+            + "</span></footer>\n"
+        )
+
         await cl.Message(
             content=answer_with_sources,
             elements=source_elements,
@@ -429,22 +546,46 @@ class Chatbot:
         cl.user_session.set("memory", conversation_list)
         await self.start(config=thread_config)
 
-    @cl.oauth_callback
-    def auth_callback(
-        provider_id: str,
-        token: str,
-        raw_user_data: Dict[str, str],
-        default_user: cl.User,
-    ) -> Optional[cl.User]:
-        return default_user
+    @cl.header_auth_callback
+    def header_auth_callback(headers: dict) -> Optional[cl.User]:
+        print("\n\n\nI am here\n\n\n")
+        # try:  # TODO: Add try-except block after testing
+        # TODO: Implement to get the user information from the headers (not the cookie)
+        cookie = headers.get("cookie")  # gets back a str
+        # Create a dictionary from the pairs
+        cookie_dict = {}
+        for pair in cookie.split("; "):
+            key, value = pair.split("=", 1)
 
     async def on_follow_up(self, action: cl.Action):
         message = await cl.Message(
            content=action.description,
            type="user_message",
-            author=self.user["user_id"],
        ).send()
-        await self.main(message)
 
 
 chatbot = Chatbot(config=config)
@@ -462,4 +603,8 @@ async def start_app():
     cl.action_callback("follow up question")(chatbot.on_follow_up)
 
 
-asyncio.run(start_app())
559
+ # Strip surrounding quotes if present
560
+ cookie_dict[key] = value.strip('"')
561
+
562
+ decoded_user_info = base64.b64decode(
563
+ cookie_dict.get("X-User-Info", "")
564
+ ).decode()
565
+ decoded_user_info = json.loads(decoded_user_info)
566
+
567
+ print(
568
+ f"\n\n USER ROLE: {decoded_user_info['literalai_info']['metadata']['role']} \n\n"
569
+ )
570
+
571
+ return cl.User(
572
+ id=decoded_user_info["literalai_info"]["id"],
573
+ identifier=decoded_user_info["literalai_info"]["identifier"],
574
+ metadata=decoded_user_info["literalai_info"]["metadata"],
575
+ )
576
 
577
  async def on_follow_up(self, action: cl.Action):
578
+ user = cl.user_session.get("user")
579
  message = await cl.Message(
580
  content=action.description,
581
  type="user_message",
582
+ author=user.identifier,
583
  ).send()
584
+ async with cl.Step(
585
+ name="on_follow_up", type="run", parent_id=message.id
586
+ ) as step:
587
+ await self.main(message)
588
+ step.output = message.content
589
 
590
 
591
  chatbot = Chatbot(config=config)
 
603
  cl.action_callback("follow up question")(chatbot.on_follow_up)
604
 
605
 
606
+ loop = asyncio.get_event_loop()
607
+ if loop.is_running():
608
+ asyncio.ensure_future(start_app())
609
+ else:
610
+ asyncio.run(start_app())
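Reviewer note: the `header_auth_callback` above assumes the hosting proxy sets an `X-User-Info` cookie whose value is base64-encoded JSON carrying a `literalai_info` object. A minimal sketch of the round trip the callback expects — the identifiers below are placeholders, not real accounts:

    import base64
    import json

    # Hypothetical payload with the keys header_auth_callback reads.
    user_info = {
        "literalai_info": {
            "id": "user-123",
            "identifier": "student@example.com",
            "metadata": {"role": "student", "tokens_left": 2000},
        }
    }

    # What the proxy would store in the X-User-Info cookie ...
    cookie_value = base64.b64encode(json.dumps(user_info).encode()).decode()

    # ... and what the callback recovers from it.
    decoded = json.loads(base64.b64decode(cookie_value).decode())
    assert decoded["literalai_info"]["metadata"]["role"] == "student"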
code/modules/chat/chat_model_loader.py CHANGED
@@ -1,15 +1,8 @@
 from langchain_openai import ChatOpenAI
-from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
-from transformers import AutoTokenizer, TextStreamer
 from langchain_community.llms import LlamaCpp
-import torch
-import transformers
 import os
 from pathlib import Path
 from huggingface_hub import hf_hub_download
-from langchain.callbacks.manager import CallbackManager
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from modules.config.constants import LLAMA_PATH


 class ChatModelLoader:
@@ -35,10 +28,10 @@ class ChatModelLoader:
         elif self.config["llm_params"]["llm_loader"] == "local_llm":
             n_batch = 512  # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.
             model_path = self._verify_model_cache(
-                self.config["llm_params"]["local_llm_params"]["model"]
+                self.config["llm_params"]["local_llm_params"]["model_path"]
             )
             llm = LlamaCpp(
-                model_path=LLAMA_PATH,
+                model_path=model_path,
                 n_batch=n_batch,
                 n_ctx=2048,
                 f16_kv=True,
code/modules/chat/helpers.py CHANGED
@@ -42,7 +42,6 @@ def get_sources(res, answer, stream=True, view_sources=False):
     full_answer += answer

     if view_sources:
-
         # Then, display the sources
         # check if the answer has sources
         if len(source_dict) == 0:
@@ -51,7 +50,6 @@ def get_sources(res, answer, stream=True, view_sources=False):
         else:
             full_answer += "\n\n**Sources:**\n"
             for idx, (url_name, source_data) in enumerate(source_dict.items()):
-
                 full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"

                 name = f"Source {idx + 1} Text\n"
@@ -110,6 +108,7 @@ def get_prompt(config, prompt_type):
     return prompts["openai"]["rephrase_prompt"]


+# TODO: Do this better
 def get_history_chat_resume(steps, k, SYSTEM, LLM):
     conversation_list = []
     count = 0
@@ -119,14 +118,17 @@ def get_history_chat_resume(steps, k, SYSTEM, LLM):
             conversation_list.append(
                 {"type": "user_message", "content": step["output"]}
             )
+            count += 1
         elif step["type"] == "assistant_message":
             if step["name"] == LLM:
                 conversation_list.append(
                     {"type": "ai_message", "content": step["output"]}
                 )
+            count += 1
         else:
-            raise ValueError("Invalid message type")
-        count += 1
+            pass
+            # raise ValueError("Invalid message type")
+            # count += 1
         if count >= 2 * k:  # 2 * k to account for both user and assistant messages
             break
     conversation_list = conversation_list[::-1]
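Reviewer note: with the moved `count += 1`, only messages that are actually appended advance the counter, so `2 * k` now bounds kept messages rather than scanned steps. A standalone sketch of the same windowing over hypothetical step dicts:

    steps = [  # newest-first, as consumed by the loop
        {"type": "assistant_message", "name": "AI Tutor", "output": "Pooling ..."},
        {"type": "user_message", "name": "student", "output": "And pooling?"},
        {"type": "run", "name": "retriever", "output": "..."},  # ignored, no count
        {"type": "assistant_message", "name": "AI Tutor", "output": "A CNN is ..."},
        {"type": "user_message", "name": "student", "output": "What is a CNN?"},
    ]

    k = 1  # keep the last k user/assistant exchanges
    conversation_list, count = [], 0
    for step in steps:
        if step["type"] in ("user_message", "assistant_message"):
            conversation_list.append(step["output"])
            count += 1
        if count >= 2 * k:
            break
    print(conversation_list[::-1])  # ['And pooling?', 'Pooling ...']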
code/modules/chat/langchain/__init__.py ADDED
File without changes
code/modules/chat/langchain/langchain_rag.py CHANGED
@@ -1,20 +1,24 @@
 from langchain_core.prompts import ChatPromptTemplate

-from modules.chat.langchain.utils import *
-from langchain.memory import ChatMessageHistory
+# from modules.chat.langchain.utils import
+from langchain_community.chat_message_histories import ChatMessageHistory
 from modules.chat.base import BaseRAG
 from langchain_core.prompts import PromptTemplate
-from langchain.memory import (
-    ConversationBufferWindowMemory,
-    ConversationSummaryBufferMemory,
+from langchain.memory import ConversationBufferWindowMemory
+from langchain_core.runnables.utils import ConfigurableFieldSpec
+from .utils import (
+    CustomConversationalRetrievalChain,
+    create_history_aware_retriever,
+    create_stuff_documents_chain,
+    create_retrieval_chain,
+    return_questions,
+    CustomRunnableWithHistory,
+    BaseChatMessageHistory,
+    InMemoryHistory,
 )

-import chainlit as cl
-from langchain_community.chat_models import ChatOpenAI
-

 class Langchain_RAG_V1(BaseRAG):
-
     def __init__(
         self,
         llm,
@@ -95,8 +99,8 @@ class QuestionGenerator:
     def __init__(self):
         pass

-    def generate_questions(self, query, response, chat_history, context):
-        questions = return_questions(query, response, chat_history, context)
+    def generate_questions(self, query, response, chat_history, context, config):
+        questions = return_questions(query, response, chat_history, context, config)
         return questions

@@ -199,7 +203,7 @@ class Langchain_RAG_V2(BaseRAG):
                     is_shared=True,
                 ),
             ],
-        )
+        ).with_config(run_name="Langchain_RAG_V2")

         if callbacks is not None:
             self.rag_chain = self.rag_chain.with_config(callbacks=callbacks)
code/modules/chat/langchain/utils.py CHANGED
@@ -1,56 +1,31 @@
 from typing import Any, Dict, List, Union, Tuple, Optional
-from langchain_core.messages import (
-    BaseMessage,
-    AIMessage,
-    FunctionMessage,
-    HumanMessage,
-)
-
 from langchain_core.prompts.base import BasePromptTemplate, format_document
-from langchain_core.prompts.chat import MessagesPlaceholder
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.output_parsers.base import BaseOutputParser
 from langchain_core.retrievers import BaseRetriever, RetrieverOutput
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.runnables import Runnable, RunnableBranch, RunnablePassthrough
 from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain_core.runnables.utils import ConfigurableFieldSpec
 from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain.chains.combine_documents.base import (
     DEFAULT_DOCUMENT_PROMPT,
     DEFAULT_DOCUMENT_SEPARATOR,
     DOCUMENTS_KEY,
-    BaseCombineDocumentsChain,
     _validate_prompt,
 )
-from langchain.chains.llm import LLMChain
-from langchain_core.callbacks import Callbacks
-from langchain_core.documents import Document
-
-
-CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
-
 from langchain_core.runnables.config import RunnableConfig
-from langchain_core.messages import BaseMessage
-
-
-from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_community.chat_models import ChatOpenAI
-
-from langchain.chains import RetrievalQA, ConversationalRetrievalChain
-from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
-
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from langchain.chains import ConversationalRetrievalChain
 from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
 import inspect
-from langchain.chains.conversational_retrieval.base import _get_chat_history
 from langchain_core.messages import BaseMessage

+CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]

-class CustomConversationalRetrievalChain(ConversationalRetrievalChain):

+class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
     def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
         _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
         buffer = ""
@@ -163,7 +138,6 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):


 class CustomRunnableWithHistory(RunnableWithMessageHistory):
-
     def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
         _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
         buffer = ""
@@ -304,8 +278,8 @@ def create_retrieval_chain(
     return retrieval_chain


-def return_questions(query, response, chat_history_str, context):
-
+# TODO: Remove Hard-coded values
+async def return_questions(query, response, chat_history_str, context, config):
     system = (
         "You are someone that suggests a question based on the student's input and chat history. "
         "Generate a question that is relevant to the student's input and chat history. "
@@ -322,18 +296,22 @@
     prompt = ChatPromptTemplate.from_messages(
         [
             ("system", system),
-            ("human", "{chat_history_str}, {context}, {query}, {response}"),
+            # ("human", "{chat_history_str}, {context}, {query}, {response}"),
         ]
     )
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    question_generator = prompt | llm | StrOutputParser()
-   new_questions = question_generator.invoke(
+   question_generator = question_generator.with_config(
+       run_name="follow_up_question_generator"
+   )
+   new_questions = await question_generator.ainvoke(
        {
            "chat_history_str": chat_history_str,
            "context": context,
            "query": query,
            "response": response,
-       }
+       },
+       config=config,
    )

    list_of_questions = new_questions.split("...")
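Reviewer note: `return_questions` is now a coroutine, so any direct caller has to await it (the tutor does this via `QuestionGenerator.generate_questions`, which `main` awaits). A minimal usage sketch with placeholder arguments — this hits the OpenAI API, so it assumes `OPENAI_API_KEY` is set:

    import asyncio

    async def demo():
        questions = await return_questions(
            query="What is backpropagation?",
            response="Backpropagation computes gradients ...",
            chat_history_str="",
            context="",
            config={"callbacks": None},
        )
        print(questions)

    asyncio.run(demo())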
code/modules/chat/llm_tutor.py CHANGED
@@ -3,7 +3,6 @@ from modules.chat.chat_model_loader import ChatModelLoader
 from modules.vectorstore.store_manager import VectorStoreManager
 from modules.retriever.retriever import Retriever
 from modules.chat.langchain.langchain_rag import (
-    Langchain_RAG_V1,
     Langchain_RAG_V2,
     QuestionGenerator,
 )
@@ -28,9 +27,11 @@ class LLMTutor:
         self.rephrase_prompt = get_prompt(
             config, "rephrase"
         )  # Initialize rephrase_prompt
-        if self.config["vectorstore"]["embedd_files"]:
-            self.vector_db.create_database()
-            self.vector_db.save_database()
+
+        # TODO: Removed this functionality for now, don't know if we need it
+        # if self.config["vectorstore"]["embedd_files"]:
+        #     self.vector_db.create_database()
+        #     self.vector_db.save_database()

     def update_llm(self, old_config, new_config):
         """
@@ -48,9 +49,11 @@ class LLMTutor:
         self.vector_db = VectorStoreManager(
             self.config, logger=self.logger
         ).load_database()  # Reinitialize VectorStoreManager if vectorstore changes
-        if self.config["vectorstore"]["embedd_files"]:
-            self.vector_db.create_database()
-            self.vector_db.save_database()
+
+        # TODO: Removed this functionality for now, don't know if we need it
+        # if self.config["vectorstore"]["embedd_files"]:
+        #     self.vector_db.create_database()
+        #     self.vector_db.save_database()

         if "llm_params.llm_style" in changes:
             self.qa_prompt = get_prompt(
code/modules/chat_processor/helpers.py ADDED
@@ -0,0 +1,245 @@
+import os
+from literalai import AsyncLiteralClient
+from datetime import datetime, timedelta, timezone
+from modules.config.constants import COOLDOWN_TIME, TOKENS_LEFT, REGEN_TIME
+from typing_extensions import TypedDict
+import tiktoken
+from typing import Any, Generic, List, Literal, Optional, TypeVar, Union
+
+Field = TypeVar("Field")
+Operators = TypeVar("Operators")
+Value = TypeVar("Value")
+
+BOOLEAN_OPERATORS = Literal["is", "nis"]
+STRING_OPERATORS = Literal["eq", "neq", "ilike", "nilike"]
+NUMBER_OPERATORS = Literal["eq", "neq", "gt", "gte", "lt", "lte"]
+STRING_LIST_OPERATORS = Literal["in", "nin"]
+DATETIME_OPERATORS = Literal["gte", "lte", "gt", "lt"]
+
+OPERATORS = Union[
+    BOOLEAN_OPERATORS,
+    STRING_OPERATORS,
+    NUMBER_OPERATORS,
+    STRING_LIST_OPERATORS,
+    DATETIME_OPERATORS,
+]
+
+
+class Filter(Generic[Field], TypedDict, total=False):
+    field: Field
+    operator: OPERATORS
+    value: Any
+    path: Optional[str]
+
+
+class OrderBy(Generic[Field], TypedDict):
+    column: Field
+    direction: Literal["ASC", "DESC"]
+
+
+threads_filterable_fields = Literal[
+    "id",
+    "createdAt",
+    "name",
+    "stepType",
+    "stepName",
+    "stepOutput",
+    "metadata",
+    "tokenCount",
+    "tags",
+    "participantId",
+    "participantIdentifiers",
+    "scoreValue",
+    "duration",
+]
+threads_orderable_fields = Literal["createdAt", "tokenCount"]
+threads_filters = List[Filter[threads_filterable_fields]]
+threads_order_by = OrderBy[threads_orderable_fields]
+
+steps_filterable_fields = Literal[
+    "id",
+    "name",
+    "input",
+    "output",
+    "participantIdentifier",
+    "startTime",
+    "endTime",
+    "metadata",
+    "parentId",
+    "threadId",
+    "error",
+    "tags",
+]
+steps_orderable_fields = Literal["createdAt"]
+steps_filters = List[Filter[steps_filterable_fields]]
+steps_order_by = OrderBy[steps_orderable_fields]
+
+users_filterable_fields = Literal[
+    "id",
+    "createdAt",
+    "identifier",
+    "lastEngaged",
+    "threadCount",
+    "tokenCount",
+    "metadata",
+]
+users_filters = List[Filter[users_filterable_fields]]
+
+scores_filterable_fields = Literal[
+    "id",
+    "createdAt",
+    "participant",
+    "name",
+    "tags",
+    "value",
+    "type",
+    "comment",
+]
+scores_orderable_fields = Literal["createdAt"]
+scores_filters = List[Filter[scores_filterable_fields]]
+scores_order_by = OrderBy[scores_orderable_fields]
+
+generation_filterable_fields = Literal[
+    "id",
+    "createdAt",
+    "model",
+    "duration",
+    "promptLineage",
+    "promptVersion",
+    "tags",
+    "score",
+    "participant",
+    "tokenCount",
+    "error",
+]
+generation_orderable_fields = Literal[
+    "createdAt",
+    "tokenCount",
+    "model",
+    "provider",
+    "participant",
+    "duration",
+]
+generations_filters = List[Filter[generation_filterable_fields]]
+generations_order_by = OrderBy[generation_orderable_fields]
+
+literal_client = AsyncLiteralClient(api_key=os.getenv("LITERAL_API_KEY_LOGGING"))
+
+
+# For consistency, use dictionary for user_info
+def convert_to_dict(user_info):
+    # if already a dictionary, return as is
+    if isinstance(user_info, dict):
+        return user_info
+    if hasattr(user_info, "__dict__"):
+        user_info = user_info.__dict__
+    return user_info
+
+
+def get_time():
+    return datetime.now(timezone.utc).isoformat()
+
+
+async def get_user_details(user_email_id):
+    user_info = await literal_client.api.get_or_create_user(identifier=user_email_id)
+    return user_info
+
+
+async def update_user_info(user_info):
+    # if object type, convert to dictionary
+    user_info = convert_to_dict(user_info)
+    await literal_client.api.update_user(
+        id=user_info["id"],
+        identifier=user_info["identifier"],
+        metadata=user_info["metadata"],
+    )
+
+
+async def check_user_cooldown(user_info, current_time):
+    # # Check if no tokens left
+    tokens_left = user_info.metadata.get("tokens_left", 0)
+    if tokens_left > 0 and not user_info.metadata.get("in_cooldown", False):
+        return False, None
+
+    user_info = convert_to_dict(user_info)
+    last_message_time_str = user_info["metadata"].get("last_message_time")
+
+    # Convert from ISO format string to datetime object and ensure UTC timezone
+    last_message_time = datetime.fromisoformat(last_message_time_str).replace(
+        tzinfo=timezone.utc
+    )
+    current_time = datetime.fromisoformat(current_time).replace(tzinfo=timezone.utc)
+
+    # Calculate the elapsed time
+    elapsed_time = current_time - last_message_time
+    elapsed_time_in_seconds = elapsed_time.total_seconds()
+
+    # Calculate when the cooldown period ends
+    cooldown_end_time = last_message_time + timedelta(seconds=COOLDOWN_TIME)
+    cooldown_end_time_iso = cooldown_end_time.isoformat()
+
+    # Debug: Print the cooldown end time
+    print(f"Cooldown end time (ISO): {cooldown_end_time_iso}")
+
+    # Check if the user is still in cooldown
+    if elapsed_time_in_seconds < COOLDOWN_TIME:
+        return True, cooldown_end_time_iso  # Return in ISO 8601 format
+
+    user_info["metadata"]["in_cooldown"] = False
+    # If not in cooldown, regenerate tokens
+    await reset_tokens_for_user(user_info)
+
+    return False, None
+
+
+async def reset_tokens_for_user(user_info):
+    user_info = convert_to_dict(user_info)
+    last_message_time_str = user_info["metadata"].get("last_message_time")
+
+    last_message_time = datetime.fromisoformat(last_message_time_str).replace(
+        tzinfo=timezone.utc
+    )
+    current_time = datetime.fromisoformat(get_time()).replace(tzinfo=timezone.utc)
+
+    # Calculate the elapsed time since the last message
+    elapsed_time_in_seconds = (current_time - last_message_time).total_seconds()
+
+    # Current token count (can be negative)
+    current_tokens = user_info["metadata"].get("tokens_left_at_last_message", 0)
+    current_tokens = min(current_tokens, TOKENS_LEFT)
+
+    # Maximum tokens that can be regenerated
+    max_tokens = user_info["metadata"].get("max_tokens", TOKENS_LEFT)
+
+    # Calculate how many tokens should have been regenerated proportionally
+    if current_tokens < max_tokens:
+        # Calculate the regeneration rate per second based on REGEN_TIME for full regeneration
+        regeneration_rate_per_second = max_tokens / REGEN_TIME
+
+        # Calculate how many tokens should have been regenerated based on the elapsed time
+        tokens_to_regenerate = int(
+            elapsed_time_in_seconds * regeneration_rate_per_second
+        )
+
+        # Ensure the new token count does not exceed max_tokens
+        new_token_count = min(current_tokens + tokens_to_regenerate, max_tokens)
+
+        print(
+            f"\n\n Adding {tokens_to_regenerate} tokens to the user, Time elapsed: {elapsed_time_in_seconds} seconds, Tokens after regeneration: {new_token_count}, Tokens before: {current_tokens} \n\n"
+        )
+
+        # Update the user's token count
+        user_info["metadata"]["tokens_left"] = new_token_count
+
+        await update_user_info(user_info)
+
+
+async def get_thread_step_info(thread_id):
+    step = await literal_client.api.get_step(thread_id)
+    return step
+
+
+def get_num_tokens(text, model):
+    encoding = tiktoken.encoding_for_model(model)
+    tokens = encoding.encode(text)
+    return len(tokens)
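Reviewer note: regeneration is linear in elapsed time. With the defaults in `constants.py` (`TOKENS_LEFT = 2000`, `REGEN_TIME = 180`) the rate is 2000 / 180 ≈ 11.1 tokens per second, so a fully drained budget refills in three minutes. A worked sketch of the same formula:

    max_tokens = 2000  # TOKENS_LEFT
    regen_time = 180   # REGEN_TIME: seconds for a full refill

    rate = max_tokens / regen_time  # ~11.1 tokens per second
    for elapsed in (30, 90, 180, 600):
        current = 0  # tokens_left_at_last_message
        new_count = min(current + int(elapsed * rate), max_tokens)
        print(f"{elapsed:>4}s elapsed -> {new_count} tokens")
    # 30s -> 333, 90s -> 1000, 180s -> 2000, 600s -> 2000 (capped)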
code/modules/chat_processor/literal_ai.py CHANGED
@@ -1,44 +1,7 @@
-from chainlit.data import ChainlitDataLayer, queue_until_user_message
+from chainlit.data import ChainlitDataLayer


 # update custom methods here (Ref: https://github.com/Chainlit/chainlit/blob/4b533cd53173bcc24abe4341a7108f0070d60099/backend/chainlit/data/__init__.py)
 class CustomLiteralDataLayer(ChainlitDataLayer):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-
-    @queue_until_user_message()
-    async def create_step(self, step_dict: "StepDict"):
-        metadata = dict(
-            step_dict.get("metadata", {}),
-            **{
-                "waitForAnswer": step_dict.get("waitForAnswer"),
-                "language": step_dict.get("language"),
-                "showInput": step_dict.get("showInput"),
-            },
-        )
-
-        step: LiteralStepDict = {
-            "createdAt": step_dict.get("createdAt"),
-            "startTime": step_dict.get("start"),
-            "endTime": step_dict.get("end"),
-            "generation": step_dict.get("generation"),
-            "id": step_dict.get("id"),
-            "parentId": step_dict.get("parentId"),
-            "name": step_dict.get("name"),
-            "threadId": step_dict.get("threadId"),
-            "type": step_dict.get("type"),
-            "tags": step_dict.get("tags"),
-            "metadata": metadata,
-        }
-        if step_dict.get("input"):
-            step["input"] = {"content": step_dict.get("input")}
-        if step_dict.get("output"):
-            step["output"] = {"content": step_dict.get("output")}
-        if step_dict.get("isError"):
-            step["error"] = step_dict.get("output")
-
-        # print("\n\n\n")
-        # print("Step: ", step)
-        # print("\n\n\n")
-
-        await self.client.api.send_steps([step])
code/modules/config/config.yml CHANGED
@@ -4,7 +4,7 @@ device: 'cpu' # str [cuda, cpu]

 vectorstore:
   load_from_HF: True # bool
-  embedd_files: False # bool
+  reparse_files: True # bool
   data_path: '../storage/data' # str
   url_file_path: '../storage/data/urls.txt' # str
   expand_urls: True # bool
@@ -37,14 +37,14 @@ llm_params:
     temperature: 0.7 # float
     repo_id: 'TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF' # HuggingFace repo id
     filename: 'tinyllama-1.1b-chat-v1.0.Q5_0.gguf' # Specific name of gguf file in the repo
-    pdf_reader: 'pymupdf' # str [llama, pymupdf, gpt]
+    model_path: 'storage/models/tinyllama-1.1b-chat-v1.0.Q5_0.gguf' # Path to the model file
   stream: False # bool
   pdf_reader: 'gpt' # str [llama, pymupdf, gpt]

 chat_logging:
   log_chat: True # bool
   platform: 'literalai'
-  callbacks: False # bool
+  callbacks: True # bool

 splitter_options:
   use_splitter: True # bool
code/modules/config/constants.py CHANGED
@@ -3,6 +3,15 @@ import os

 load_dotenv()

+TIMEOUT = 60
+COOLDOWN_TIME = 60
+REGEN_TIME = 180
+TOKENS_LEFT = 2000
+ALL_TIME_TOKENS_ALLOCATED = 1000000
+
+GITHUB_REPO = "https://github.com/DL4DS/dl4ds_tutor"
+DOCS_WEBSITE = "https://dl4ds.github.io/dl4ds_tutor/"
+
 # API Keys - Loaded from the .env file

 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -10,14 +19,16 @@ LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 LITERAL_API_KEY_LOGGING = os.getenv("LITERAL_API_KEY_LOGGING")
 LITERAL_API_URL = os.getenv("LITERAL_API_URL")
+CHAINLIT_URL = os.getenv("CHAINLIT_URL")

 OAUTH_GOOGLE_CLIENT_ID = os.getenv("OAUTH_GOOGLE_CLIENT_ID")
 OAUTH_GOOGLE_CLIENT_SECRET = os.getenv("OAUTH_GOOGLE_CLIENT_SECRET")

-opening_message = f"Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
+opening_message = "Hey, What Can I Help You With?\n\nYou can me ask me questions about the course logistics, course content, about the final project, or anything else!"
+chat_end_message = (
+    "I hope I was able to help you. If you have any more questions, feel free to ask!"
+)

 # Model Paths

 LLAMA_PATH = "../storage/models/tinyllama"
-
-RETRIEVER_HF_PATHS = {"RAGatouille": "XThomasBU/Colbert_Index"}
code/modules/config/project_config.yml ADDED
@@ -0,0 +1,7 @@
+retriever:
+  retriever_hf_paths:
+    RAGatouille: "XThomasBU/Colbert_Index"
+
+metadata:
+  metadata_links: ["https://dl4ds.github.io/sp2024/lectures/", "https://dl4ds.github.io/sp2024/schedule/"]
+  slide_base_link: "https://dl4ds.github.io"
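Reviewer note: this file is meant to be merged into the main config at load time — the `__main__` block of `data_loader.py` below does it with `config.update(project_config)`. A minimal sketch, assuming both files are read from `code/modules/config/`:

    import yaml

    with open("code/modules/config/config.yml") as f:
        config = yaml.safe_load(f)
    with open("code/modules/config/project_config.yml") as f:
        project_config = yaml.safe_load(f)

    # Shallow merge: the retriever and metadata sections are added
    # alongside the main config's existing keys.
    config.update(project_config)
    print(config["metadata"]["slide_base_link"])  # https://dl4ds.github.io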
code/modules/dataloader/data_loader.py CHANGED
@@ -3,40 +3,26 @@ import re
 import requests
 import pysrt
 from langchain_community.document_loaders import (
-    PyMuPDFLoader,
     Docx2txtLoader,
     YoutubeLoader,
-    WebBaseLoader,
     TextLoader,
 )
-from langchain_community.document_loaders import UnstructuredMarkdownLoader
-from llama_parse import LlamaParse
 from langchain.schema import Document
 import logging
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain_openai.embeddings import OpenAIEmbeddings
-from ragatouille import RAGPretrainedModel
-from langchain.chains import LLMChain
-from langchain_community.llms import OpenAI
-from langchain import PromptTemplate
 import json
 from concurrent.futures import ThreadPoolExecutor
 from urllib.parse import urljoin
 import html2text
 import bs4
-import tempfile
 import PyPDF2
 from modules.dataloader.pdf_readers.base import PDFReader
 from modules.dataloader.pdf_readers.llama import LlamaParser
 from modules.dataloader.pdf_readers.gpt import GPTParser
-
-try:
-    from modules.dataloader.helpers import get_metadata, download_pdf_from_url
-    from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
-except:
-    from dataloader.helpers import get_metadata, download_pdf_from_url
-    from config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
+from modules.dataloader.helpers import get_metadata
+from modules.config.constants import TIMEOUT

 logger = logging.getLogger(__name__)
 BASE_DIR = os.getcwd()
@@ -47,7 +33,7 @@ class HTMLReader:
         pass

     def read_url(self, url):
-        response = requests.get(url)
+        response = requests.get(url, timeout=TIMEOUT)
         if response.status_code == 200:
             return response.text
         else:
@@ -65,11 +51,13 @@
                 href = href.replace("http", "https")

             absolute_url = urljoin(base_url, href)
-            link['href'] = absolute_url
+            link["href"] = absolute_url

-            resp = requests.head(absolute_url)
+            resp = requests.head(absolute_url, timeout=TIMEOUT)
             if resp.status_code != 200:
-                logger.warning(f"Link {absolute_url} is broken. Status code: {resp.status_code}")
+                logger.warning(
+                    f"Link {absolute_url} is broken. Status code: {resp.status_code}"
+                )

         return str(soup)
@@ -85,6 +73,7 @@
         else:
             return None

+
 class FileReader:
     def __init__(self, logger, kind):
         self.logger = logger
@@ -96,7 +85,9 @@
         else:
             self.pdf_reader = PDFReader()
         self.web_reader = HTMLReader()
-        self.logger.info(f"Initialized FileReader with {kind} PDF reader and HTML reader")
+        self.logger.info(
+            f"Initialized FileReader with {kind} PDF reader and HTML reader"
+        )

     def extract_text_from_pdf(self, pdf_path):
         text = ""
@@ -137,7 +128,7 @@
         return [Document(page_content=self.web_reader.read_html(url))]

     def read_tex_from_url(self, tex_url):
-        response = requests.get(tex_url)
+        response = requests.get(tex_url, timeout=TIMEOUT)
         if response.status_code == 200:
             return [Document(page_content=response.text)]
         else:
@@ -154,17 +145,20 @@ class ChunkProcessor:
         self.document_metadata = {}
         self.document_chunks_full = []

-        if not config['vectorstore']['embedd_files']:
+        # TODO: Fix when reparse_files is False
+        if not config["vectorstore"]["reparse_files"]:
             self.load_document_data()

         if config["splitter_options"]["use_splitter"]:
             if config["splitter_options"]["chunking_mode"] == "fixed":
                 if config["splitter_options"]["split_by_token"]:
-                    self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-                        chunk_size=config["splitter_options"]["chunk_size"],
-                        chunk_overlap=config["splitter_options"]["chunk_overlap"],
-                        separators=config["splitter_options"]["chunk_separators"],
-                        disallowed_special=(),
+                    self.splitter = (
+                        RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+                            chunk_size=config["splitter_options"]["chunk_size"],
+                            chunk_overlap=config["splitter_options"]["chunk_overlap"],
+                            separators=config["splitter_options"]["chunk_separators"],
+                            disallowed_special=(),
+                        )
                     )
                 else:
                     self.splitter = RecursiveCharacterTextSplitter(
@@ -175,8 +169,7 @@
                     )
             else:
                 self.splitter = SemanticChunker(
-                    OpenAIEmbeddings(),
-                    breakpoint_threshold_type="percentile"
+                    OpenAIEmbeddings(), breakpoint_threshold_type="percentile"
                 )

         else:
@@ -203,7 +196,10 @@
     ):
         # TODO: Clear up this pipeline of re-adding metadata
         documents = [Document(page_content=documents, source=source, page=page)]
-        if file_type == "pdf" and self.config["splitter_options"]["chunking_mode"] == "fixed":
+        if (
+            file_type == "pdf"
+            and self.config["splitter_options"]["chunking_mode"] == "fixed"
+        ):
            document_chunks = documents
        else:
            document_chunks = self.splitter.split_documents(documents)
@@ -226,9 +222,22 @@
    def chunk_docs(self, file_reader, uploaded_files, weblinks):
        addl_metadata = get_metadata(
-           "https://dl4ds.github.io/sp2024/lectures/",
-           "https://dl4ds.github.io/sp2024/schedule/",
+           *self.config["metadata"]["metadata_links"], self.config
        )  # For any additional metadata
+
+       # remove already processed files if reparse_files is False
+       if not self.config["vectorstore"]["reparse_files"]:
+           total_documents = len(uploaded_files) + len(weblinks)
+           uploaded_files = [
+               file_path
+               for file_path in uploaded_files
+               if file_path not in self.document_data
+           ]
+           weblinks = [link for link in weblinks if link not in self.document_data]
+           print(
+               f"Total documents to process: {total_documents}, Documents already processed: {total_documents - len(uploaded_files) - len(weblinks)}"
+           )
+
        with ThreadPoolExecutor() as executor:
            executor.map(
                self.process_file,
@@ -298,6 +307,7 @@
        self.document_metadata[file_path] = file_metadata

    def process_file(self, file_path, file_index, file_reader, addl_metadata):
+       print(f"Processing file {file_index + 1} : {file_path}")
        file_name = os.path.basename(file_path)

        file_type = file_name.split(".")[-1]
@@ -314,10 +324,12 @@
            return

        try:
-
            if file_path in self.document_data:
                self.logger.warning(f"File {file_name} already processed")
-               documents = [Document(page_content=content) for content in self.document_data[file_path].values()]
+               documents = [
+                   Document(page_content=content)
+                   for content in self.document_data[file_path].values()
+               ]
            else:
                documents = read_methods[file_type](file_path)
@@ -370,22 +382,31 @@
            json.dump(self.document_metadata, json_file, indent=4)

    def load_document_data(self):
-       with open(
-           f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r"
-       ) as json_file:
-           self.document_data = json.load(json_file)
-       with open(
-           f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
-       ) as json_file:
-           self.document_metadata = json.load(json_file)
-       self.logger.info(
-           f"Loaded document content from {self.config['log_chunk_dir']}/docs/doc_content.json. Total documents: {len(self.document_data)}"
-       )
+       try:
+           with open(
+               f"{self.config['log_chunk_dir']}/docs/doc_content.json", "r"
+           ) as json_file:
+               self.document_data = json.load(json_file)
+           with open(
+               f"{self.config['log_chunk_dir']}/metadata/doc_metadata.json", "r"
+           ) as json_file:
+               self.document_metadata = json.load(json_file)
+           self.logger.info(
+               f"Loaded document content from {self.config['log_chunk_dir']}/docs/doc_content.json. Total documents: {len(self.document_data)}"
+           )
+       except FileNotFoundError:
+           self.logger.warning(
+               f"Document content not found in {self.config['log_chunk_dir']}/docs/doc_content.json"
+           )
+           self.document_data = {}
+           self.document_metadata = {}


 class DataLoader:
    def __init__(self, config, logger=None):
-       self.file_reader = FileReader(logger=logger, kind=config["llm_params"]["pdf_reader"])
+       self.file_reader = FileReader(
+           logger=logger, kind=config["llm_params"]["pdf_reader"]
+       )
        self.chunk_processor = ChunkProcessor(config, logger=logger)

    def get_chunks(self, uploaded_files, weblinks):
@@ -396,6 +417,15 @@

 if __name__ == "__main__":
    import yaml
+   import argparse
+
+   parser = argparse.ArgumentParser(description="Process some links.")
+   parser.add_argument(
+       "--links", nargs="+", required=True, help="List of links to process."
+   )
+
+   args = parser.parse_args()
+   links_to_process = args.links

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
@@ -403,19 +433,30 @@
    with open("../code/modules/config/config.yml", "r") as f:
        config = yaml.safe_load(f)

-   STORAGE_DIR = os.path.join(BASE_DIR, config['vectorstore']["data_path"])
+   with open("../code/modules/config/project_config.yml", "r") as f:
+       project_config = yaml.safe_load(f)
+
+   # Combine project config with the main config
+   config.update(project_config)
+
+   STORAGE_DIR = os.path.join(BASE_DIR, config["vectorstore"]["data_path"])
    uploaded_files = [
-       os.path.join(STORAGE_DIR, file) for file in os.listdir(STORAGE_DIR) if file != "urls.txt"
+       os.path.join(STORAGE_DIR, file)
+       for file in os.listdir(STORAGE_DIR)
+       if file != "urls.txt"
    ]

    data_loader = DataLoader(config, logger=logger)
-   document_chunks, document_names, documents, document_metadata = (
-       data_loader.get_chunks(
-           ["https://dl4ds.github.io/sp2024/static_files/lectures/05_loss_functions_v2.pdf"],
-           [],
-       )
+   # Just for testing
+   (
+       document_chunks,
+       document_names,
+       documents,
+       document_metadata,
+   ) = data_loader.get_chunks(
+       links_to_process,
+       [],
    )

    print(document_names[:5])
    print(len(document_chunks))
code/modules/dataloader/helpers.py CHANGED
@@ -2,6 +2,8 @@ import requests
 from bs4 import BeautifulSoup
 from urllib.parse import urlparse
 import tempfile
+from modules.config.constants import TIMEOUT
+

 def get_urls_from_file(file_path: str):
     """
@@ -19,18 +21,19 @@ def get_base_url(url):
     return base_url


-def get_metadata(lectures_url, schedule_url):
+### THIS FUNCTION IS NOT GENERALIZABLE.. IT IS SPECIFIC TO THE COURSE WEBSITE ###
+def get_metadata(lectures_url, schedule_url, config):
     """
     Function to get the lecture metadata from the lectures and schedule URLs.
     """
     lecture_metadata = {}

     # Get the main lectures page content
-    r_lectures = requests.get(lectures_url)
+    r_lectures = requests.get(lectures_url, timeout=TIMEOUT)
     soup_lectures = BeautifulSoup(r_lectures.text, "html.parser")

     # Get the main schedule page content
-    r_schedule = requests.get(schedule_url)
+    r_schedule = requests.get(schedule_url, timeout=TIMEOUT)
     soup_schedule = BeautifulSoup(r_schedule.text, "html.parser")

     # Find all lecture blocks
@@ -48,7 +51,9 @@ def get_metadata(lectures_url, schedule_url):
         slides_link_tag = description_div.find("a", title="Download slides")
         slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
         slides_link = (
-            f"https://dl4ds.github.io{slides_link}" if slides_link else None
+            f"{config['metadata']['slide_base_link']}{slides_link}"
+            if slides_link
+            else None
         )
         if slides_link:
             date_mapping[slides_link] = date
@@ -68,7 +73,9 @@
         slides_link_tag = block.find("a", title="Download slides")
         slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
         slides_link = (
-            f"https://dl4ds.github.io{slides_link}" if slides_link else None
+            f"{config['metadata']['slide_base_link']}{slides_link}"
+            if slides_link
+            else None
         )

         # Extract the link to the lecture recording
@@ -118,7 +125,7 @@ def download_pdf_from_url(pdf_url):
     Returns:
         str: The local file path of the downloaded PDF file.
     """
-    response = requests.get(pdf_url)
+    response = requests.get(pdf_url, timeout=TIMEOUT)
     if response.status_code == 200:
         with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
             temp_file.write(response.content)
code/modules/dataloader/pdf_readers/gpt.py CHANGED
@@ -6,6 +6,7 @@ from io import BytesIO
6
  from openai import OpenAI
7
  from pdf2image import convert_from_path
8
  from langchain.schema import Document
 
9
 
10
 
11
  class GPTParser:
@@ -19,9 +20,9 @@ class GPTParser:
19
  self.api_key = os.getenv("OPENAI_API_KEY")
20
  self.prompt = """
21
  The provided documents are images of PDFs of lecture slides of deep learning material.
22
- They contain LaTeX equations, images, and text.
23
  The goal is to extract the text, images and equations from the slides and convert everything to markdown format. Some of the equations may be complicated.
24
- The markdown should be clean and easy to read, and any math equation should be converted to LaTeX, between $$.
25
  For images, give a description and if you can, a source. Separate each page with '---'.
26
  Just respond with the markdown. Do not include page numbers or any other metadata. Do not try to provide titles. Strictly the content.
27
  """
@@ -31,36 +32,45 @@ class GPTParser:
31
 
32
  encoded_images = [self.encode_image(image) for image in images]
33
 
34
- chunks = [encoded_images[i:i + 5] for i in range(0, len(encoded_images), 5)]
35
 
36
  headers = {
37
  "Content-Type": "application/json",
38
- "Authorization": f"Bearer {self.api_key}"
39
  }
40
 
41
  output = ""
42
  for chunk_num, chunk in enumerate(chunks):
43
- content = [{"type": "image_url", "image_url": {
44
- "url": f"data:image/jpeg;base64,{image}"}} for image in chunk]
 
 
 
 
 
45
 
46
  content.insert(0, {"type": "text", "text": self.prompt})
47
 
48
  payload = {
49
  "model": "gpt-4o-mini",
50
- "messages": [
51
- {
52
- "role": "user",
53
- "content": content
54
- }
55
- ],
56
  }
57
 
58
  response = requests.post(
59
- "https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
 
 
 
 
60
 
61
  resp = response.json()
62
 
63
- chunk_output = resp['choices'][0]['message']['content'].replace("```", "").replace("markdown", "").replace("````", "")
 
 
 
 
 
64
 
65
  output += chunk_output + "\n---\n"
66
 
@@ -68,14 +78,12 @@ class GPTParser:
68
  output = [doc for doc in output if doc.strip() != ""]
69
 
70
  documents = [
71
- Document(
72
- page_content=page,
73
- metadata={"source": pdf_path, "page": i}
74
- ) for i, page in enumerate(output)
75
  ]
76
  return documents
77
 
78
  def encode_image(self, image):
79
  buffered = BytesIO()
80
  image.save(buffered, format="JPEG")
81
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
6
  from openai import OpenAI
7
  from pdf2image import convert_from_path
8
  from langchain.schema import Document
9
+ from modules.config.constants import TIMEOUT
10
 
11
 
12
  class GPTParser:
 
20
  self.api_key = os.getenv("OPENAI_API_KEY")
21
  self.prompt = """
22
  The provided documents are images of PDFs of lecture slides of deep learning material.
23
+ They contain LaTeX equations, images, and text.
24
  The goal is to extract the text, images and equations from the slides and convert everything to markdown format. Some of the equations may be complicated.
25
+ The markdown should be clean and easy to read, and any math equation should be converted to LaTeX, between $$.
26
  For images, give a description and if you can, a source. Separate each page with '---'.
27
  Just respond with the markdown. Do not include page numbers or any other metadata. Do not try to provide titles. Strictly the content.
28
  """
 
32
 
33
  encoded_images = [self.encode_image(image) for image in images]
34
 
35
+ chunks = [encoded_images[i : i + 5] for i in range(0, len(encoded_images), 5)]
36
 
37
  headers = {
38
  "Content-Type": "application/json",
39
+ "Authorization": f"Bearer {self.api_key}",
40
  }
41
 
42
  output = ""
43
  for chunk_num, chunk in enumerate(chunks):
44
+ content = [
45
+ {
46
+ "type": "image_url",
47
+ "image_url": {"url": f"data:image/jpeg;base64,{image}"},
48
+ }
49
+ for image in chunk
50
+ ]
51
 
52
  content.insert(0, {"type": "text", "text": self.prompt})
53
 
54
  payload = {
55
  "model": "gpt-4o-mini",
56
+ "messages": [{"role": "user", "content": content}],
 
 
 
 
 
57
  }
58
 
59
  response = requests.post(
60
+ "https://api.openai.com/v1/chat/completions",
61
+ headers=headers,
62
+ json=payload,
63
+ timeout=TIMEOUT,
64
+ )
65
 
66
  resp = response.json()
67
 
68
+ chunk_output = (
69
+ resp["choices"][0]["message"]["content"]
70
+ .replace("```", "")
71
+ .replace("markdown", "")
72
+ .replace("````", "")
73
+ )
74
 
75
  output += chunk_output + "\n---\n"
76
 
 
78
  output = [doc for doc in output if doc.strip() != ""]
79
 
80
  documents = [
81
+ Document(page_content=page, metadata={"source": pdf_path, "page": i})
82
+ for i, page in enumerate(output)
 
 
83
  ]
84
  return documents
85
 
86
  def encode_image(self, image):
87
  buffered = BytesIO()
88
  image.save(buffered, format="JPEG")
89
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
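
Note: the `TIMEOUT` constant imported at the top of this file lives in `modules/config/constants.py`, whose body is not part of this section. A minimal sketch of what it is assumed to look like (the value is hypothetical):

```python
# modules/config/constants.py (sketch -- the actual value is not shown here)
# Shared timeout, in seconds, for all outbound HTTP requests.
TIMEOUT = 60  # hypothetical value
```

The same constant is reused by `llama.py` and `webpage_crawler.py` below, so every `requests` call in the data loaders fails fast instead of hanging indefinitely.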
code/modules/dataloader/pdf_readers/llama.py CHANGED
@@ -2,19 +2,18 @@ import os
2
  import requests
3
  from llama_parse import LlamaParse
4
  from langchain.schema import Document
5
- from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY
6
  from modules.dataloader.helpers import download_pdf_from_url
7
 
8
 
9
-
10
  class LlamaParser:
11
  def __init__(self):
12
  self.GPT_API_KEY = OPENAI_API_KEY
13
  self.LLAMA_CLOUD_API_KEY = LLAMA_CLOUD_API_KEY
14
  self.parse_url = "https://api.cloud.llamaindex.ai/api/parsing/upload"
15
  self.headers = {
16
- 'Accept': 'application/json',
17
- 'Authorization': f'Bearer {LLAMA_CLOUD_API_KEY}'
18
  }
19
  self.parser = LlamaParse(
20
  api_key=LLAMA_CLOUD_API_KEY,
@@ -23,7 +22,7 @@ class LlamaParser:
23
  language="en",
24
  gpt4o_mode=False,
25
  # gpt4o_api_key=OPENAI_API_KEY,
26
- parsing_instruction="The provided documents are PDFs of lecture slides of deep learning material. They contain LaTeX equations, images, and text. The goal is to extract the text, images and equations from the slides. The markdown should be clean and easy to read, and any math equation should be converted to LaTeX format, between $ signs. For images, if you can, give a description and a source."
27
  )
28
 
29
  def parse(self, pdf_path):
@@ -38,10 +37,8 @@ class LlamaParser:
38
  pages = [page.strip() for page in pages]
39
 
40
  documents = [
41
- Document(
42
- page_content=page,
43
- metadata={"source": pdf_path, "page": i}
44
- ) for i, page in enumerate(pages)
45
  ]
46
 
47
  return documents
@@ -53,20 +50,30 @@ class LlamaParser:
53
  }
54
 
55
  files = [
56
- ('file', ('file', requests.get(pdf_url).content, 'application/octet-stream'))
57
  ]
58
 
59
  response = requests.request(
60
- "POST", self.parse_url, headers=self.headers, data=payload, files=files)
 
61
 
62
- return response.json()['id'], response.json()['status']
63
 
64
  async def get_result(self, job_id):
65
- url = f"https://api.cloud.llamaindex.ai/api/parsing/job/{job_id}/result/markdown"
 
 
66
 
67
  response = requests.request("GET", url, headers=self.headers, data={})
68
 
69
- return response.json()['markdown']
70
 
71
  async def _parse(self, pdf_path):
72
  job_id, status = self.make_request(pdf_path)
@@ -78,15 +85,9 @@ class LlamaParser:
78
 
79
  result = await self.get_result(job_id)
80
 
81
- documents = [
82
- Document(
83
- page_content=result,
84
- metadata={"source": pdf_path}
85
- )
86
- ]
87
 
88
  return documents
89
 
90
- async def _parse(self, pdf_path):
91
- return await self._parse(pdf_path)
92
-
 
2
  import requests
3
  from llama_parse import LlamaParse
4
  from langchain.schema import Document
5
+ from modules.config.constants import OPENAI_API_KEY, LLAMA_CLOUD_API_KEY, TIMEOUT
6
  from modules.dataloader.helpers import download_pdf_from_url
7
 
8
 
 
9
  class LlamaParser:
10
  def __init__(self):
11
  self.GPT_API_KEY = OPENAI_API_KEY
12
  self.LLAMA_CLOUD_API_KEY = LLAMA_CLOUD_API_KEY
13
  self.parse_url = "https://api.cloud.llamaindex.ai/api/parsing/upload"
14
  self.headers = {
15
+ "Accept": "application/json",
16
+ "Authorization": f"Bearer {LLAMA_CLOUD_API_KEY}",
17
  }
18
  self.parser = LlamaParse(
19
  api_key=LLAMA_CLOUD_API_KEY,
 
22
  language="en",
23
  gpt4o_mode=False,
24
  # gpt4o_api_key=OPENAI_API_KEY,
25
+ parsing_instruction="The provided documents are PDFs of lecture slides of deep learning material. They contain LaTeX equations, images, and text. The goal is to extract the text, images and equations from the slides. The markdown should be clean and easy to read, and any math equation should be converted to LaTeX format, between $ signs. For images, if you can, give a description and a source.",
26
  )
27
 
28
  def parse(self, pdf_path):
 
37
  pages = [page.strip() for page in pages]
38
 
39
  documents = [
40
+ Document(page_content=page, metadata={"source": pdf_path, "page": i})
41
+ for i, page in enumerate(pages)
 
 
42
  ]
43
 
44
  return documents
 
50
  }
51
 
52
  files = [
53
+ (
54
+ "file",
55
+ (
56
+ "file",
57
+ requests.get(pdf_url, timeout=TIMEOUT).content,
58
+ "application/octet-stream",
59
+ ),
60
+ )
61
  ]
62
 
63
  response = requests.request(
64
+ "POST", self.parse_url, headers=self.headers, data=payload, files=files
65
+ )
66
 
67
+ return response.json()["id"], response.json()["status"]
68
 
69
  async def get_result(self, job_id):
70
+ url = (
71
+ f"https://api.cloud.llamaindex.ai/api/parsing/job/{job_id}/result/markdown"
72
+ )
73
 
74
  response = requests.request("GET", url, headers=self.headers, data={})
75
 
76
+ return response.json()["markdown"]
77
 
78
  async def _parse(self, pdf_path):
79
  job_id, status = self.make_request(pdf_path)
 
85
 
86
  result = await self.get_result(job_id)
87
 
88
+ documents = [Document(page_content=result, metadata={"source": pdf_path})]
89
 
90
  return documents
91
 
92
+ # async def _parse(self, pdf_path):
93
+ # return await self._parse(pdf_path)
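
For reference, a minimal usage sketch of the parser above; the file path is hypothetical, and the synchronous `parse(pdf_path)` entry point is the one shown in the hunk above:

```python
# Hypothetical usage -- the path is illustrative, not taken from the repo
parser = LlamaParser()
documents = parser.parse("storage/data/lecture01.pdf")
for doc in documents:
    print(doc.metadata["source"], len(doc.page_content))
```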
 
code/modules/dataloader/webpage_crawler.py CHANGED
@@ -3,7 +3,9 @@ from aiohttp import ClientSession
3
  import asyncio
4
  import requests
5
  from bs4 import BeautifulSoup
6
- from urllib.parse import urlparse, urljoin, urldefrag
 
 
7
 
8
  class WebpageCrawler:
9
  def __init__(self):
@@ -18,7 +20,7 @@ class WebpageCrawler:
18
 
19
  def url_exists(self, url: str) -> bool:
20
  try:
21
- response = requests.head(url)
22
  return response.status_code == 200
23
  except requests.ConnectionError:
24
  return False
@@ -88,7 +90,7 @@ class WebpageCrawler:
88
 
89
  def is_webpage(self, url: str) -> bool:
90
  try:
91
- response = requests.head(url, allow_redirects=True)
92
  content_type = response.headers.get("Content-Type", "").lower()
93
  return "text/html" in content_type
94
  except requests.RequestException:
 
3
  import asyncio
4
  import requests
5
  from bs4 import BeautifulSoup
6
+ from urllib.parse import urljoin, urldefrag
7
+ from modules.config.constants import TIMEOUT
8
+
9
 
10
  class WebpageCrawler:
11
  def __init__(self):
 
20
 
21
  def url_exists(self, url: str) -> bool:
22
  try:
23
+ response = requests.head(url, timeout=TIMEOUT)
24
  return response.status_code == 200
25
  except requests.ConnectionError:
26
  return False
 
90
 
91
  def is_webpage(self, url: str) -> bool:
92
  try:
93
+ response = requests.head(url, allow_redirects=True, timeout=TIMEOUT)
94
  content_type = response.headers.get("Content-Type", "").lower()
95
  return "text/html" in content_type
96
  except requests.RequestException:
code/modules/retriever/helpers.py CHANGED
@@ -6,7 +6,6 @@ from typing import List
6
 
7
 
8
  class VectorStoreRetrieverScore(VectorStoreRetriever):
9
-
10
  # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
11
  def _get_relevant_documents(
12
  self, query: str, *, run_manager: CallbackManagerForRetrieverRun
 
6
 
7
 
8
  class VectorStoreRetrieverScore(VectorStoreRetriever):
 
9
  # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
10
  def _get_relevant_documents(
11
  self, query: str, *, run_manager: CallbackManagerForRetrieverRun
code/modules/vectorstore/colbert.py CHANGED
@@ -1,9 +1,9 @@
1
  from ragatouille import RAGPretrainedModel
2
  from modules.vectorstore.base import VectorStoreBase
3
  from langchain_core.retrievers import BaseRetriever
4
- from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun, Callbacks
5
  from langchain_core.documents import Document
6
- from typing import Any, List, Optional, Sequence
7
  import os
8
  import json
9
 
@@ -85,6 +85,7 @@ class ColbertVectorStore(VectorStoreBase):
85
  document_ids=document_names,
86
  document_metadatas=document_metadata,
87
  )
 
88
  self.colbert.set_document_count(len(document_names))
89
 
90
  def load_database(self):
 
1
  from ragatouille import RAGPretrainedModel
2
  from modules.vectorstore.base import VectorStoreBase
3
  from langchain_core.retrievers import BaseRetriever
4
+ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
5
  from langchain_core.documents import Document
6
+ from typing import Any, List
7
  import os
8
  import json
9
 
 
85
  document_ids=document_names,
86
  document_metadatas=document_metadata,
87
  )
88
+ print(f"Index created at {index_path}")
89
  self.colbert.set_document_count(len(document_names))
90
 
91
  def load_database(self):
code/modules/vectorstore/embedding_model_loader.py CHANGED
@@ -1,9 +1,6 @@
1
  from langchain_community.embeddings import OpenAIEmbeddings
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
- from langchain_community.embeddings import LlamaCppEmbeddings
4
-
5
- from modules.config.constants import *
6
- import os
7
 
8
 
9
  class EmbeddingModelLoader:
@@ -28,8 +25,5 @@ class EmbeddingModelLoader:
28
  "trust_remote_code": True,
29
  },
30
  )
31
- # embedding_model = LlamaCppEmbeddings(
32
- # model_path=os.path.abspath("storage/llama-7b.ggmlv3.q4_0.bin")
33
- # )
34
 
35
  return embedding_model
 
1
  from langchain_community.embeddings import OpenAIEmbeddings
2
  from langchain_community.embeddings import HuggingFaceEmbeddings
3
+ from modules.config.constants import OPENAI_API_KEY, HUGGINGFACE_TOKEN
4
 
5
 
6
  class EmbeddingModelLoader:
 
25
  "trust_remote_code": True,
26
  },
27
  )
28
 
29
  return embedding_model
code/modules/vectorstore/faiss.py CHANGED
@@ -14,10 +14,15 @@ class FaissVectorStore(VectorStoreBase):
14
  def __init__(self, config):
15
  self.config = config
16
  self._init_vector_db()
17
- self.local_path = os.path.join(self.config["vectorstore"]["db_path"],
18
- "db_" + self.config["vectorstore"]["db_option"]
19
- + "_" + self.config["vectorstore"]["model"]
20
- + "_" + config["splitter_options"]["chunking_mode"])
 
 
 
 
 
21
 
22
  def _init_vector_db(self):
23
  self.faiss = FAISS(
@@ -28,9 +33,7 @@ class FaissVectorStore(VectorStoreBase):
28
  self.vectorstore = self.faiss.from_documents(
29
  documents=document_chunks, embedding=embedding_model
30
  )
31
- self.vectorstore.save_local(
32
- self.local_path
33
- )
34
 
35
  def load_database(self, embedding_model):
36
  self.vectorstore = self.faiss.load_local(
 
14
  def __init__(self, config):
15
  self.config = config
16
  self._init_vector_db()
17
+ self.local_path = os.path.join(
18
+ self.config["vectorstore"]["db_path"],
19
+ "db_"
20
+ + self.config["vectorstore"]["db_option"]
21
+ + "_"
22
+ + self.config["vectorstore"]["model"]
23
+ + "_"
24
+ + config["splitter_options"]["chunking_mode"],
25
+ )
26
 
27
  def _init_vector_db(self):
28
  self.faiss = FAISS(
 
33
  self.vectorstore = self.faiss.from_documents(
34
  documents=document_chunks, embedding=embedding_model
35
  )
36
+ self.vectorstore.save_local(self.local_path)
 
 
37
 
38
  def load_database(self, embedding_model):
39
  self.vectorstore = self.faiss.load_local(
code/modules/vectorstore/raptor.py CHANGED
@@ -317,13 +317,10 @@ class RAPTORVectoreStore(VectorStoreBase):
317
  print(f"--Generated {len(all_clusters)} clusters--")
318
 
319
  # Summarization
320
- template = """Here is content from the course DS598: Deep Learning for Data Science.
321
-
322
  The content may be from a webpage about the course, or lecture content, or any other relevant information.
323
  If the content is in bullet points (from pdf lecture slides), you can summarize the bullet points.
324
-
325
  Give a detailed summary of the content below.
326
-
327
  Documentation:
328
  {context}
329
  """
 
317
  print(f"--Generated {len(all_clusters)} clusters--")
318
 
319
  # Summarization
320
+ template = """Here is content from the course DS598: Deep Learning for Data Science.
 
321
  The content may be from a webpage about the course, or lecture content, or any other relevant information.
322
  If the content is in bullet points (from pdf lecture slides), you can summarize the bullet points.
 
323
  Give a detailed summary of the content below.
 
324
  Documentation:
325
  {context}
326
  """
code/modules/vectorstore/store_manager.py CHANGED
@@ -1,9 +1,7 @@
1
  from modules.vectorstore.vectorstore import VectorStore
2
- from modules.vectorstore.helpers import *
3
  from modules.dataloader.webpage_crawler import WebpageCrawler
4
  from modules.dataloader.data_loader import DataLoader
5
- from modules.dataloader.helpers import *
6
- from modules.config.constants import RETRIEVER_HF_PATHS
7
  from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
8
  import logging
9
  import os
@@ -49,7 +47,6 @@ class VectorStoreManager:
49
  return logger
50
 
51
  def load_files(self):
52
-
53
  files = os.listdir(self.config["vectorstore"]["data_path"])
54
  files = [
55
  os.path.join(self.config["vectorstore"]["data_path"], file)
@@ -71,7 +68,6 @@ class VectorStoreManager:
71
  return files, urls
72
 
73
  def create_embedding_model(self):
74
-
75
  self.logger.info("Creating embedding function")
76
  embedding_model_loader = EmbeddingModelLoader(self.config)
77
  embedding_model = embedding_model_loader.load_embedding_model()
@@ -102,7 +98,6 @@ class VectorStoreManager:
102
  )
103
 
104
  def create_database(self):
105
-
106
  start_time = time.time() # Start time for creating database
107
  data_loader = DataLoader(self.config, self.logger)
108
  self.logger.info("Loading data")
@@ -112,12 +107,15 @@ class VectorStoreManager:
112
  self.logger.info(f"Number of webpages: {len(webpages)}")
113
  if f"{self.config['vectorstore']['url_file_path']}" in files:
114
  files.remove(f"{self.config['vectorstore']['url_file_path']}") # cleanup
115
- document_chunks, document_names, documents, document_metadata = (
116
- data_loader.get_chunks(files, webpages)
117
- )
118
  num_documents = len(document_chunks)
119
  self.logger.info(f"Number of documents in the DB: {num_documents}")
120
- metadata_keys = list(document_metadata[0].keys())
121
  self.logger.info(f"Metadata keys: {metadata_keys}")
122
  self.logger.info("Completed loading data")
123
  self.initialize_database(
@@ -130,7 +128,6 @@ class VectorStoreManager:
130
  )
131
 
132
  def load_database(self):
133
-
134
  start_time = time.time() # Start time for loading database
135
  if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
136
  self.embedding_model = self.create_embedding_model()
@@ -170,13 +167,23 @@ if __name__ == "__main__":
170
 
171
  with open("modules/config/config.yml", "r") as f:
172
  config = yaml.safe_load(f)
173
  print(config)
174
  print(f"Trying to create database with config: {config}")
175
  vector_db = VectorStoreManager(config)
176
  if config["vectorstore"]["load_from_HF"]:
177
- if config["vectorstore"]["db_option"] in RETRIEVER_HF_PATHS:
 
 
 
178
  vector_db.load_from_HF(
179
- HF_PATH=RETRIEVER_HF_PATHS[config["vectorstore"]["db_option"]]
 
 
180
  )
181
  else:
182
  # print(f"HF_PATH not available for {config['vectorstore']['db_option']}")
@@ -189,7 +196,7 @@ if __name__ == "__main__":
189
  vector_db.create_database()
190
  print("Created database")
191
 
192
- print(f"Trying to load the database")
193
  vector_db = VectorStoreManager(config)
194
  vector_db.load_database()
195
  print("Loaded database")
 
1
  from modules.vectorstore.vectorstore import VectorStore
2
+ from modules.dataloader.helpers import get_urls_from_file
3
  from modules.dataloader.webpage_crawler import WebpageCrawler
4
  from modules.dataloader.data_loader import DataLoader
 
 
5
  from modules.vectorstore.embedding_model_loader import EmbeddingModelLoader
6
  import logging
7
  import os
 
47
  return logger
48
 
49
  def load_files(self):
 
50
  files = os.listdir(self.config["vectorstore"]["data_path"])
51
  files = [
52
  os.path.join(self.config["vectorstore"]["data_path"], file)
 
68
  return files, urls
69
 
70
  def create_embedding_model(self):
 
71
  self.logger.info("Creating embedding function")
72
  embedding_model_loader = EmbeddingModelLoader(self.config)
73
  embedding_model = embedding_model_loader.load_embedding_model()
 
98
  )
99
 
100
  def create_database(self):
 
101
  start_time = time.time() # Start time for creating database
102
  data_loader = DataLoader(self.config, self.logger)
103
  self.logger.info("Loading data")
 
107
  self.logger.info(f"Number of webpages: {len(webpages)}")
108
  if f"{self.config['vectorstore']['url_file_path']}" in files:
109
  files.remove(f"{self.config['vectorstore']['url_file_path']}") # cleanup
110
+ (
111
+ document_chunks,
112
+ document_names,
113
+ documents,
114
+ document_metadata,
115
+ ) = data_loader.get_chunks(files, webpages)
116
  num_documents = len(document_chunks)
117
  self.logger.info(f"Number of documents in the DB: {num_documents}")
118
+ metadata_keys = list(document_metadata[0].keys()) if document_metadata else []
119
  self.logger.info(f"Metadata keys: {metadata_keys}")
120
  self.logger.info("Completed loading data")
121
  self.initialize_database(
 
128
  )
129
 
130
  def load_database(self):
 
131
  start_time = time.time() # Start time for loading database
132
  if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
133
  self.embedding_model = self.create_embedding_model()
 
167
 
168
  with open("modules/config/config.yml", "r") as f:
169
  config = yaml.safe_load(f)
170
+ with open("modules/config/project_config.yml", "r") as f:
171
+ project_config = yaml.safe_load(f)
172
+
173
+ # combine the two configs
174
+ config.update(project_config)
175
  print(config)
176
  print(f"Trying to create database with config: {config}")
177
  vector_db = VectorStoreManager(config)
178
  if config["vectorstore"]["load_from_HF"]:
179
+ if (
180
+ config["vectorstore"]["db_option"]
181
+ in config["retriever"]["retriever_hf_paths"]
182
+ ):
183
  vector_db.load_from_HF(
184
+ HF_PATH=config["retriever"]["retriever_hf_paths"][
185
+ config["vectorstore"]["db_option"]
186
+ ]
187
  )
188
  else:
189
  # print(f"HF_PATH not available for {config['vectorstore']['db_option']}")
 
196
  vector_db.create_database()
197
  print("Created database")
198
 
199
+ print("Trying to load the database")
200
  vector_db = VectorStoreManager(config)
201
  vector_db.load_database()
202
  print("Loaded database")
code/public/avatars/{ai-tutor.png → ai_tutor.png} RENAMED
File without changes
code/public/space.jpg ADDED

Git LFS Details

  • SHA256: 9ed3f8e7fd9790c394bae59cd0e315742af862ed833e9f42906f36f140abbb07
  • Pointer size: 132 Bytes
  • Size of remote file: 2.68 MB
code/public/test.css CHANGED
@@ -13,10 +13,6 @@ a[href*='https://github.com/Chainlit/chainlit'] {
13
  border-radius: 50%; /* Maintain circular shape */
14
  }
15
 
16
- /* Hide the default image */
17
- .MuiAvatar-root.MuiAvatar-circular.css-m2icte .MuiAvatar-img.css-1hy9t21 {
18
- display: none;
19
- }
20
 
21
  .MuiAvatar-root.MuiAvatar-circular.css-v72an7 {
22
  background-image: url('/public/avatars/ai-tutor.png'); /* Replace with your custom image URL */
@@ -26,18 +22,3 @@ a[href*='https://github.com/Chainlit/chainlit'] {
26
  height: 40px; /* Ensure the dimensions match the original */
27
  border-radius: 50%; /* Maintain circular shape */
28
  }
29
-
30
- /* Hide the default image */
31
- .MuiAvatar-root.MuiAvatar-circular.css-v72an7 .MuiAvatar-img.css-1hy9t21 {
32
- display: none;
33
- }
34
-
35
- /* Hide the new chat button
36
- #new-chat-button {
37
- display: none;
38
- } */
39
-
40
- /* Hide the open sidebar button
41
- #open-sidebar-button {
42
- display: none;
43
- } */
 
13
  border-radius: 50%; /* Maintain circular shape */
14
  }
15
 
16
 
17
  .MuiAvatar-root.MuiAvatar-circular.css-v72an7 {
18
  background-image: url('/public/avatars/ai-tutor.png'); /* Replace with your custom image URL */
 
22
  height: 40px; /* Ensure the dimensions match the original */
23
  border-radius: 50%; /* Maintain circular shape */
24
  }
code/templates/cooldown.html ADDED
@@ -0,0 +1,181 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Cooldown Period | Terrier Tutor</title>
7
+ <style>
8
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
9
+
10
+ body, html {
11
+ margin: 0;
12
+ padding: 0;
13
+ font-family: 'Inter', sans-serif;
14
+ background-color: #f7f7f7;
15
+ background-image: url('https://www.transparenttextures.com/patterns/cubes.png');
16
+ background-repeat: repeat;
17
+ display: flex;
18
+ align-items: center;
19
+ justify-content: center;
20
+ height: 100vh;
21
+ color: #333;
22
+ }
23
+
24
+ .container {
25
+ background: rgba(255, 255, 255, 0.9);
26
+ border: 1px solid #ddd;
27
+ border-radius: 8px;
28
+ width: 100%;
29
+ max-width: 400px;
30
+ padding: 50px;
31
+ box-sizing: border-box;
32
+ text-align: center;
33
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
34
+ backdrop-filter: blur(10px);
35
+ -webkit-backdrop-filter: blur(10px);
36
+ }
37
+
38
+ .avatar {
39
+ width: 90px;
40
+ height: 90px;
41
+ border-radius: 50%;
42
+ margin-bottom: 25px;
43
+ border: 2px solid #ddd;
44
+ }
45
+
46
+ .container h1 {
47
+ margin-bottom: 15px;
48
+ font-size: 24px;
49
+ font-weight: 600;
50
+ color: #1a1a1a;
51
+ }
52
+
53
+ .container p {
54
+ font-size: 16px;
55
+ color: #4a4a4a;
56
+ margin-bottom: 30px;
57
+ line-height: 1.5;
58
+ }
59
+
60
+ .cooldown-message {
61
+ font-size: 16px;
62
+ color: #333;
63
+ margin-bottom: 30px;
64
+ }
65
+
66
+ .tokens-left {
67
+ font-size: 14px;
68
+ color: #333;
69
+ margin-bottom: 30px;
70
+ font-weight: 600;
71
+ }
72
+
73
+ .button {
74
+ padding: 12px 0;
75
+ margin: 12px 0;
76
+ font-size: 14px;
77
+ border-radius: 6px;
78
+ cursor: pointer;
79
+ width: 100%;
80
+ border: 1px solid #4285F4;
81
+ background-color: #fff;
82
+ color: #4285F4;
83
+ transition: background-color 0.3s ease, border-color 0.3s ease;
84
+ display: none;
85
+ }
86
+
87
+ .button.start-tutor {
88
+ display: none;
89
+ }
90
+
91
+ .button:hover {
92
+ background-color: #e0e0e0;
93
+ border-color: #357ae8;
94
+ }
95
+
96
+ .sign-out-button {
97
+ border: 1px solid #FF4C4C;
98
+ background-color: #fff;
99
+ color: #FF4C4C;
100
+ display: block;
101
+ }
102
+
103
+ .sign-out-button:hover {
104
+ background-color: #ffe6e6;
105
+ border-color: #e04343;
106
+ color: #e04343;
107
+ }
108
+
109
+ #countdown {
110
+ font-size: 14px;
111
+ color: #555;
112
+ margin-bottom: 20px;
113
+ }
114
+
115
+ .footer {
116
+ font-size: 12px;
117
+ color: #777;
118
+ margin-top: 20px;
119
+ }
120
+ </style>
121
+ </head>
122
+ <body>
123
+ <div class="container">
124
+ <img src="/public/avatars/ai_tutor.png" alt="AI Tutor Avatar" class="avatar">
125
+ <h1>Hello, {{ username }}</h1>
126
+ <p>It seems like you need to wait a bit before starting a new session.</p>
127
+ <p class="cooldown-message">Time remaining until the cooldown period ends:</p>
128
+ <p id="countdown"></p>
129
+ <p class="tokens-left">Tokens Left: <span id="tokensLeft">{{ tokens_left }}</span></p>
130
+ <button id="startTutorBtn" class="button start-tutor" onclick="startTutor()">Start AI Tutor</button>
131
+ <form action="/logout" method="get">
132
+ <button type="submit" class="button sign-out-button">Sign Out</button>
133
+ </form>
134
+ <div class="footer">Reload the page to update token stats</div>
135
+ </div>
136
+ <script>
137
+ function startCountdown(endTime) {
138
+ const countdownElement = document.getElementById('countdown');
139
+ const startTutorBtn = document.getElementById('startTutorBtn');
140
+ const endTimeDate = new Date(endTime);
141
+
142
+ function updateCountdown() {
143
+ const now = new Date();
144
+ const timeLeft = endTimeDate.getTime() - now.getTime();
145
+
146
+ if (timeLeft <= 0) {
147
+ countdownElement.textContent = "Cooldown period has ended.";
148
+ startTutorBtn.style.display = "block";
149
+ } else {
150
+ const hours = Math.floor(timeLeft / 1000 / 60 / 60);
151
+ const minutes = Math.floor((timeLeft / 1000 / 60) % 60);
152
+ const seconds = Math.floor((timeLeft / 1000) % 60);
153
+ countdownElement.textContent = `${hours}h ${minutes}m ${seconds}s`;
154
+ }
155
+ }
156
+
157
+ updateCountdown();
158
+ setInterval(updateCountdown, 1000);
159
+ }
160
+
161
+ function startTutor() {
162
+ window.location.href = "/start-tutor";
163
+ }
164
+
165
+ function updateTokensLeft() {
166
+ fetch('/get-tokens-left')
167
+ .then(response => response.json())
168
+ .then(data => {
169
+ document.getElementById('tokensLeft').textContent = data.tokens_left;
170
+ })
171
+ .catch(error => console.error('Error fetching tokens:', error));
172
+ }
173
+
174
+ // Start the countdown
175
+ startCountdown("{{ cooldown_end_time }}");
176
+
177
+ // Update tokens left when the page loads
178
+ updateTokensLeft();
179
+ </script>
180
+ </body>
181
+ </html>
code/templates/dashboard.html ADDED
@@ -0,0 +1,145 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Dashboard | Terrier Tutor</title>
7
+ <style>
8
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
9
+
10
+ body, html {
11
+ margin: 0;
12
+ padding: 0;
13
+ font-family: 'Inter', sans-serif;
14
+ background-color: #f7f7f7; /* Light gray background */
15
+ background-image: url('https://www.transparenttextures.com/patterns/cubes.png'); /* Subtle geometric pattern */
16
+ background-repeat: repeat;
17
+ display: flex;
18
+ align-items: center;
19
+ justify-content: center;
20
+ height: 100vh;
21
+ color: #333;
22
+ }
23
+
24
+ .container {
25
+ background: rgba(255, 255, 255, 0.9);
26
+ border: 1px solid #ddd;
27
+ border-radius: 8px;
28
+ width: 100%;
29
+ max-width: 400px;
30
+ padding: 40px;
31
+ box-sizing: border-box;
32
+ text-align: center;
33
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
34
+ backdrop-filter: blur(10px);
35
+ -webkit-backdrop-filter: blur(10px);
36
+ }
37
+
38
+ .avatar {
39
+ width: 90px;
40
+ height: 90px;
41
+ border-radius: 50%;
42
+ margin-bottom: 20px;
43
+ border: 2px solid #ddd;
44
+ }
45
+
46
+ .container h1 {
47
+ margin-bottom: 20px;
48
+ font-size: 26px;
49
+ font-weight: 600;
50
+ color: #1a1a1a;
51
+ }
52
+
53
+ .container p {
54
+ font-size: 15px;
55
+ color: #4a4a4a;
56
+ margin-bottom: 25px;
57
+ line-height: 1.5;
58
+ }
59
+
60
+ .tokens-left {
61
+ font-size: 17px;
62
+ color: #333;
63
+ margin-bottom: 10px;
64
+ font-weight: 600;
65
+ }
66
+
67
+ .all-time-tokens {
68
+ font-size: 14px; /* Reduced font size */
69
+ color: #555;
70
+ margin-bottom: 30px;
71
+ font-weight: 500;
72
+ white-space: nowrap; /* Prevents breaking to a new line */
73
+ }
74
+
75
+ .button {
76
+ padding: 12px 0;
77
+ margin: 12px 0;
78
+ font-size: 15px;
79
+ border-radius: 6px;
80
+ cursor: pointer;
81
+ width: 100%;
82
+ border: 1px solid #4285F4; /* Button border color */
83
+ background-color: #fff; /* Button background color */
84
+ color: #4285F4; /* Button text color */
85
+ transition: background-color 0.3s ease, border-color 0.3s ease;
86
+ }
87
+
88
+ .button:hover {
89
+ background-color: #e0e0e0;
90
+ border-color: #357ae8; /* Darker blue for hover */
91
+ }
92
+
93
+ .start-button {
94
+ border: 1px solid #4285F4;
95
+ color: #4285F4;
96
+ background-color: #fff;
97
+ }
98
+
99
+ .start-button:hover {
100
+ background-color: #e0f0ff; /* Light blue on hover */
101
+ border-color: #357ae8; /* Darker blue for hover */
102
+ color: #357ae8; /* Blue text on hover */
103
+ }
104
+
105
+ .sign-out-button {
106
+ border: 1px solid #FF4C4C;
107
+ background-color: #fff;
108
+ color: #FF4C4C;
109
+ }
110
+
111
+ .sign-out-button:hover {
112
+ background-color: #ffe6e6; /* Light red on hover */
113
+ border-color: #e04343; /* Darker red for hover */
114
+ color: #e04343; /* Red text on hover */
115
+ }
116
+
117
+ .footer {
118
+ font-size: 12px;
119
+ color: #777;
120
+ margin-top: 25px;
121
+ }
122
+ </style>
123
+ </head>
124
+ <body>
125
+ <div class="container">
126
+ <img src="/public/avatars/ai_tutor.png" alt="AI Tutor Avatar" class="avatar">
127
+ <h1>Welcome, {{ username }}</h1>
128
+ <p>Ready to start your AI tutoring session?</p>
129
+ <p class="tokens-left">Tokens Left: {{ tokens_left }}</p>
130
+ <p class="all-time-tokens">All-Time Tokens Allocated: {{ all_time_tokens_allocated }} / {{ total_tokens_allocated }}</p>
131
+ <form action="/start-tutor" method="post">
132
+ <button type="submit" class="button start-button">Start AI Tutor</button>
133
+ </form>
134
+ <form action="/logout" method="get">
135
+ <button type="submit" class="button sign-out-button">Sign Out</button>
136
+ </form>
137
+ <div class="footer">Reload the page to update token stats</div>
138
+ </div>
139
+ <script>
140
+ let token = "{{ jwt_token }}";
141
+ console.log("Token: ", token);
142
+ localStorage.setItem('token', token);
143
+ </script>
144
+ </body>
145
+ </html>
code/templates/error.html ADDED
@@ -0,0 +1,95 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Error | Terrier Tutor</title>
7
+ <style>
8
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
9
+
10
+ body, html {
11
+ margin: 0;
12
+ padding: 0;
13
+ font-family: 'Inter', sans-serif;
14
+ background-color: #f7f7f7; /* Light gray background */
15
+ background-image: url('https://www.transparenttextures.com/patterns/cubes.png'); /* Subtle geometric pattern */
16
+ background-repeat: repeat;
17
+ display: flex;
18
+ align-items: center;
19
+ justify-content: center;
20
+ height: 100vh;
21
+ color: #333;
22
+ }
23
+
24
+ .container {
25
+ background: rgba(255, 255, 255, 0.9);
26
+ border: 1px solid #ddd;
27
+ border-radius: 8px;
28
+ width: 100%;
29
+ max-width: 400px;
30
+ padding: 50px;
31
+ box-sizing: border-box;
32
+ text-align: center;
33
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
34
+ backdrop-filter: blur(10px);
35
+ -webkit-backdrop-filter: blur(10px);
36
+ }
37
+
38
+ .container h1 {
39
+ margin-bottom: 20px;
40
+ font-size: 26px;
41
+ font-weight: 600;
42
+ color: #1a1a1a;
43
+ }
44
+
45
+ .container p {
46
+ font-size: 18px;
47
+ color: #4a4a4a;
48
+ margin-bottom: 35px;
49
+ line-height: 1.5;
50
+ }
51
+
52
+ .button {
53
+ padding: 14px 0;
54
+ margin: 12px 0;
55
+ font-size: 16px;
56
+ border-radius: 6px;
57
+ cursor: pointer;
58
+ width: 100%;
59
+ border: 1px solid #ccc;
60
+ background-color: #007BFF;
61
+ color: #fff;
62
+ transition: background-color 0.3s ease, border-color 0.3s ease;
63
+ }
64
+
65
+ .button:hover {
66
+ background-color: #0056b3;
67
+ border-color: #0056b3;
68
+ }
69
+
70
+ .error-box {
71
+ background-color: #2d2d2d;
72
+ color: #fff;
73
+ padding: 10px;
74
+ margin-top: 20px;
75
+ font-family: 'Courier New', Courier, monospace;
76
+ text-align: left;
77
+ overflow-x: auto;
78
+ white-space: pre-wrap;
79
+ border-radius: 5px;
80
+ }
81
+ </style>
82
+ </head>
83
+ <body>
84
+ <div class="container">
85
+ <h1>Oops! Something went wrong...</h1>
86
+ <p>An unexpected error occurred. The details are below:</p>
87
+ <div class="error-box">
88
+ <code>{{ error }}</code>
89
+ </div>
90
+ <form action="/" method="get">
91
+ <button type="submit" class="button">Return to Home</button>
92
+ </form>
93
+ </div>
94
+ </body>
95
+ </html>
code/templates/error_404.html ADDED
@@ -0,0 +1,80 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>404 - Not Found</title>
7
+ <style>
8
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
9
+
10
+ body, html {
11
+ margin: 0;
12
+ padding: 0;
13
+ font-family: 'Inter', sans-serif;
14
+ background-color: #f7f7f7; /* Light gray background */
15
+ background-image: url('https://www.transparenttextures.com/patterns/cubes.png'); /* Subtle geometric pattern */
16
+ background-repeat: repeat;
17
+ display: flex;
18
+ align-items: center;
19
+ justify-content: center;
20
+ height: 100vh;
21
+ color: #333;
22
+ }
23
+
24
+ .container {
25
+ background: rgba(255, 255, 255, 0.9);
26
+ border: 1px solid #ddd;
27
+ border-radius: 8px;
28
+ width: 100%;
29
+ max-width: 400px;
30
+ padding: 50px;
31
+ box-sizing: border-box;
32
+ text-align: center;
33
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
34
+ backdrop-filter: blur(10px);
35
+ -webkit-backdrop-filter: blur(10px);
36
+ }
37
+
38
+ .container h1 {
39
+ margin-bottom: 20px;
40
+ font-size: 26px;
41
+ font-weight: 600;
42
+ color: #1a1a1a;
43
+ }
44
+
45
+ .container p {
46
+ font-size: 18px;
47
+ color: #4a4a4a;
48
+ margin-bottom: 35px;
49
+ line-height: 1.5;
50
+ }
51
+
52
+ .button {
53
+ padding: 14px 0;
54
+ margin: 12px 0;
55
+ font-size: 16px;
56
+ border-radius: 6px;
57
+ cursor: pointer;
58
+ width: 100%;
59
+ border: 1px solid #ccc;
60
+ background-color: #007BFF;
61
+ color: #fff;
62
+ transition: background-color 0.3s ease, border-color 0.3s ease;
63
+ }
64
+
65
+ .button:hover {
66
+ background-color: #0056b3;
67
+ border-color: #0056b3;
68
+ }
69
+ </style>
70
+ </head>
71
+ <body>
72
+ <div class="container">
73
+ <h1>You have ventured into the abyss...</h1>
74
+ <p>To get back to reality, click the button below.</p>
75
+ <form action="/" method="get">
76
+ <button type="submit" class="button">Return to Home</button>
77
+ </form>
78
+ </div>
79
+ </body>
80
+ </html>
code/templates/login.html ADDED
@@ -0,0 +1,132 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Login | Terrier Tutor</title>
7
+ <style>
8
+ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
9
+
10
+ body, html {
11
+ margin: 0;
12
+ padding: 0;
13
+ font-family: 'Inter', sans-serif;
14
+ background-color: #f7f7f7; /* Light gray background */
15
+ background-image: url('https://www.transparenttextures.com/patterns/cubes.png'); /* Subtle geometric pattern */
16
+ background-repeat: repeat;
17
+ display: flex;
18
+ align-items: center;
19
+ justify-content: center;
20
+ height: 100vh;
21
+ color: #333;
22
+ }
23
+
24
+ .container {
25
+ background: rgba(255, 255, 255, 0.9);
26
+ border: 1px solid #ddd;
27
+ border-radius: 8px;
28
+ width: 100%;
29
+ max-width: 400px;
30
+ padding: 50px;
31
+ box-sizing: border-box;
32
+ text-align: center;
33
+ box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
34
+ backdrop-filter: blur(10px);
35
+ -webkit-backdrop-filter: blur(10px);
36
+ }
37
+
38
+ .avatar {
39
+ width: 90px;
40
+ height: 90px;
41
+ border-radius: 50%;
42
+ margin-bottom: 25px;
43
+ border: 2px solid #ddd;
44
+ }
45
+
46
+ .container h1 {
47
+ margin-bottom: 15px;
48
+ font-size: 24px;
49
+ font-weight: 600;
50
+ color: #1a1a1a;
51
+ }
52
+
53
+ .container p {
54
+ font-size: 16px;
55
+ color: #4a4a4a;
56
+ margin-bottom: 30px;
57
+ line-height: 1.5;
58
+ }
59
+
60
+ .button {
61
+ padding: 12px 0;
62
+ margin: 12px 0;
63
+ font-size: 14px;
64
+ border-radius: 6px;
65
+ cursor: pointer;
66
+ width: 100%;
67
+ border: 1px solid #4285F4; /* Google button border color */
68
+ background-color: #fff; /* Guest button color */
69
+ color: #4285F4; /* Google button text color */
70
+ transition: background-color 0.3s ease, border-color 0.3s ease;
71
+ }
72
+
73
+ .button:hover {
74
+ background-color: #e0f0ff; /* Light blue on hover */
75
+ border-color: #357ae8; /* Darker blue for hover */
76
+ color: #357ae8; /* Blue text on hover */
77
+ }
78
+
79
+ .footer {
80
+ margin-top: 40px;
81
+ font-size: 15px;
82
+ color: #666;
83
+ text-align: center; /* Center the text in the footer */
84
+ }
85
+
86
+ .footer a {
87
+ color: #333;
88
+ text-decoration: none;
89
+ font-weight: 500;
90
+ display: inline-flex;
91
+ align-items: center;
92
+ justify-content: center; /* Center the content of the links */
93
+ transition: color 0.3s ease;
94
+ margin-bottom: 8px;
95
+ width: 100%; /* Make the link block level */
96
+ }
97
+
98
+ .footer a:hover {
99
+ color: #000;
100
+ }
101
+
102
+ .footer svg {
103
+ margin-right: 8px;
104
+ fill: currentColor;
105
+ }
106
+ </style>
107
+ </head>
108
+ <body>
109
+ <div class="container">
110
+ <img src="/public/avatars/ai_tutor.png" alt="AI Tutor Avatar" class="avatar">
111
+ <h1>Terrier Tutor</h1>
112
+ <p>Welcome to the DS598 AI Tutor. Please sign in to continue.</p>
113
+ <form action="/login/google" method="get">
114
+ <button type="submit" class="button">Sign in with Google</button>
115
+ </form>
116
+ <div class="footer">
117
+ <a href="{{ GITHUB_REPO }}" target="_blank">
118
+ <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24">
119
+ <path d="M12 .5C5.596.5.5 5.596.5 12c0 5.098 3.292 9.414 7.852 10.94.574.105.775-.249.775-.553 0-.272-.01-1.008-.015-1.98-3.194.694-3.87-1.544-3.87-1.544-.521-1.324-1.273-1.676-1.273-1.676-1.04-.714.079-.7.079-.7 1.148.08 1.75 1.181 1.75 1.181 1.022 1.752 2.683 1.246 3.34.954.104-.74.4-1.246.73-1.533-2.551-.292-5.234-1.276-5.234-5.675 0-1.253.447-2.277 1.181-3.079-.12-.293-.51-1.47.113-3.063 0 0 .96-.307 3.15 1.174.913-.255 1.892-.383 2.867-.388.975.005 1.954.133 2.868.388 2.188-1.481 3.147-1.174 3.147-1.174.624 1.593.233 2.77.114 3.063.735.802 1.18 1.826 1.18 3.079 0 4.407-2.688 5.38-5.248 5.668.413.354.782 1.049.782 2.113 0 1.526-.014 2.757-.014 3.132 0 .307.198.662.783.553C20.21 21.411 23.5 17.096 23.5 12c0-6.404-5.096-11.5-11.5-11.5z"/>
120
+ </svg>
121
+ View on GitHub
122
+ </a>
123
+ <a href="{{ DOCS_WEBSITE }}" target="_blank">
124
+ <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24">
125
+ <path d="M19 2H8c-1.103 0-2 .897-2 2v16c0 1.103.897 2 2 2h12c1.103 0 2-.897 2-2V7l-5-5zm0 2l.001 4H14V4h5zm-1 14H9V4h4v6h6v8zM7 4H6v16c0 1.654 1.346 3 3 3h9v-2H9c-.551 0-1-.449-1-1V4z"/>
126
+ </svg>
127
+ View Docs
128
+ </a>
129
+ </div>
130
+ </div>
131
+ </body>
132
+ </html>
code/templates/logout.html ADDED
@@ -0,0 +1,21 @@
1
+ <!DOCTYPE html>
2
+ <html>
3
+ <head>
4
+ <title>Logout</title>
5
+ <script>
6
+ window.onload = function() {
7
+ fetch('/chainlit_tutor/logout', {
8
+ method: 'POST',
9
+ credentials: 'include' // Ensure cookies are sent
10
+ }).then(() => {
11
+ window.location.href = '/';
12
+ }).catch(error => {
13
+ console.error('Logout failed:', error);
14
+ });
15
+ };
16
+ </script>
17
+ </head>
18
+ <body>
19
+ <p>Logging out... If you are not redirected, <a href="/">click here</a>.</p>
20
+ </body>
21
+ </html>
docs/README.md DELETED
@@ -1,51 +0,0 @@
1
- # Documentation
2
-
3
- ## File Structure:
4
- - `docs/` - Documentation files
5
- - `code/` - Code files
6
- - `storage/` - Storage files
7
- - `vectorstores/` - Vector Databases
8
- - `.env` - Environment Variables
9
- - `Dockerfile` - Dockerfile for Hugging Face
10
- - `.chainlit` - Chainlit Configuration
11
- - `chainlit.md` - Chainlit README
12
- - `README.md` - Repository README
13
- - `.gitignore` - Gitignore file
14
- - `requirements.txt` - Python Requirements
15
- - `.gitattributes` - Gitattributes file
16
-
17
- ## Code Structure
18
-
19
- - `code/main.py` - Main Chainlit App
20
- - `code/config.yaml` - Configuration File to set Embedding related, Vector Database related, and Chat Model related parameters.
21
- - `code/modules/vector_db.py` - Vector Database Creation
22
- - `code/modules/chat_model_loader.py` - Chat Model Loader (Creates the Chat Model)
23
- - `code/modules/constants.py` - Constants (Loads the Environment Variables, Prompts, Model Paths, etc.)
24
- - `code/modules/data_loader.py` - Loads and Chunks the Data
25
- - `code/modules/embedding_model.py` - Creates the Embedding Model to Embed the Data
26
- - `code/modules/llm_tutor.py` - Creates the RAG LLM Tutor
27
- - The Function `qa_bot()` loads the vector database and the chat model, and sets the prompt to pass to the chat model.
28
- - `code/modules/helpers.py` - Helper Functions
29
-
30
- ## Storage and Vectorstores
31
-
32
- - `storage/data/` - Data Storage (Put your pdf files under this directory, and urls in the urls.txt file)
33
- - `storage/models/` - Model Storage (Put your local LLMs under this directory)
34
-
35
- - `vectorstores/` - Vector Databases (Stores the Vector Databases generated from `code/modules/vector_db.py`)
36
-
37
-
38
- ## Useful Configurations
39
- set these in `code/config.yaml`:
40
- * ``["embedding_options"]["embedd_files"]`` - If set to True, embeds the files from the storage directory everytime you run the chainlit command. If set to False, uses the stored vector database.
41
- * ``["embedding_options"]["expand_urls"]`` - If set to True, gets and reads the data from all the links under the url provided. If set to False, only reads the data in the url provided.
42
- * ``["embedding_options"]["search_top_k"]`` - Number of sources that the retriever returns
43
- * ``["llm_params]["use_history"]`` - Whether to use history in the prompt or not
44
- * ``["llm_params]["memory_window"]`` - Number of interactions to keep a track of in the history
45
-
46
-
47
- ## LlamaCpp
48
- * https://python.langchain.com/docs/integrations/llms/llamacpp
49
-
50
- ## Hugging Face Models
51
- * Download the ``.gguf`` files for your Local LLM from Hugging Face (Example: https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF)
docs/contribute.md ADDED
@@ -0,0 +1,33 @@
1
+ 💡 **Please ensure formatting, linting, and security checks pass before submitting a pull request**
2
+
3
+ ## Code Formatting
4
+
5
+ The codebase is formatted using [black](https://github.com/psf/black)
6
+
7
+ To format the codebase, run the following command:
8
+
9
+ ```bash
10
+ black .
11
+ ```
12
+
13
+ Please ensure that the code is formatted before submitting a pull request.
14
+
15
+ ## Linting
16
+
17
+ The codebase is linted using [flake8](https://flake8.pycqa.org/en/latest/)
18
+
19
+ To view the linting errors, run the following command:
20
+
21
+ ```bash
22
+ flake8 .
23
+ ```
24
+
25
+ ## Security and Vulnerabilities
26
+
27
+ The codebase is scanned for security vulnerabilities using [bandit](https://github.com/PyCQA/bandit)
28
+
29
+ To scan the codebase for security vulnerabilities, run the following command:
30
+
31
+ ```bash
32
+ bandit -r .
33
+ ```
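
To run all three checks in one pass before opening a pull request (the same tools as above, chained):

```bash
black --check . && flake8 . && bandit -r .
```

`black --check` reports files that would be reformatted without modifying them.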
docs/setup.md ADDED
@@ -0,0 +1,127 @@
1
+ # Initial Setup
2
+
3
+ ⚠️ **Create the .env file inside the `code/` directory.**
4
+
5
+ ## Python Environment
6
+
7
+ Python Version: 3.11
8
+
9
+ Create a virtual environment and install the required packages:
10
+
11
+ ```bash
12
+ conda create -n ai_tutor python=3.11
13
+ conda activate ai_tutor
14
+ pip install -r requirements.txt
15
+ ```
16
+
17
+ ## Code Formatting
18
+
19
+ The codebase is formatted using [black](https://github.com/psf/black), and if making changes to the codebase, ensure that the code is formatted before submitting a pull request. More instructions can be found in `docs/contribute.md`.
20
+
21
+ ## Google OAuth 2.0 Client ID and Secret
22
+
23
+ To set up the Google OAuth 2.0 Client ID and Secret, follow these steps:
24
+
25
+ 1. Go to the [Google Cloud Console](https://console.cloud.google.com/apis/credentials).
26
+ 2. Create a new project or select an existing one.
27
+ 3. Navigate to the "Credentials" page.
28
+ 4. Click on "Create Credentials" and select "OAuth 2.0 Client ID".
29
+ 5. Configure the OAuth consent screen if you haven't already.
30
+ 6. Choose "Web application" as the application type.
31
+ 7. Configure the redirect URIs as needed.
32
+ 8. Copy the generated `Client ID` and `Client Secret`.
33
+
34
+ Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
35
+
36
+ ```bash
37
+ OAUTH_GOOGLE_CLIENT_ID=<your_client_id>
38
+ OAUTH_GOOGLE_CLIENT_SECRET=<your_client_secret>
39
+ ```
40
+
41
+ ## Literal AI API Key
42
+
43
+ To obtain the Literal AI API key:
44
+
45
+ 1. Sign up or log in to [Literal AI](https://cloud.getliteral.ai/).
46
+ 2. Navigate to the API Keys section under your account settings.
47
+ 3. Create a new API key if necessary and copy it.
48
+
49
+ Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
50
+
51
+ ```bash
52
+ LITERAL_API_KEY_LOGGING=<your_api_key>
53
+ LITERAL_API_URL=https://cloud.getliteral.ai
54
+ ```
55
+
56
+ ## LlamaCloud API Key
57
+
58
+ To obtain the LlamaCloud API Key:
59
+
60
+ 1. Go to [LlamaCloud](https://cloud.llamaindex.ai/).
61
+ 2. Sign up or log in to your account.
62
+ 3. Navigate to the API section and generate a new API key if necessary.
63
+
64
+ Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
65
+
66
+ ```bash
67
+ LLAMA_CLOUD_API_KEY=<your_api_key>
68
+ ```
69
+
70
+ ## Hugging Face Access Token
71
+
72
+ To obtain your Hugging Face access token:
73
+
74
+ 1. Go to [Hugging Face settings](https://huggingface.co/settings/tokens).
75
+ 2. Log in or create an account.
76
+ 3. Generate a new token or use an existing one.
77
+
78
+ Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
79
+
80
+ ```bash
81
+ HUGGINGFACE_TOKEN=<your-huggingface-token>
82
+ ```
83
+
84
+ ## Chainlit Authentication Secret
85
+
86
+ You must provide a JWT secret in the environment to use authentication. Run `chainlit create-secret` to generate one.
87
+
88
+ ```bash
89
+ chainlit create-secret
90
+ ```
91
+
92
+ Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
93
+
94
+ ```bash
95
+ CHAINLIT_AUTH_SECRET=<your_jwt_secret>
96
+ CHAINLIT_URL=<your_chainlit_url> # Example: CHAINLIT_URL=http://localhost:8000
97
+ ```
98
+
99
+ ## OpenAI API Key
100
+
101
+ Set the following in the .env file (if running locally) or in secrets (if running on Hugging Face Spaces):
102
+
103
+ ```bash
104
+ OPENAI_API_KEY=<your_openai_api_key>
105
+ ```
106
+
107
+ ## In a Nutshell
108
+
109
+ Your .env file (secrets in HuggingFace) should look like this:
110
+
111
+ ```bash
112
+ CHAINLIT_AUTH_SECRET=<your_jwt_secret>
113
+ OPENAI_API_KEY=<your_openai_api_key>
114
+ HUGGINGFACE_TOKEN=<your-huggingface-token>
115
+ LITERAL_API_KEY_LOGGING=<your_api_key>
116
+ LITERAL_API_URL=<https://cloud.getliteral.ai>
117
+ OAUTH_GOOGLE_CLIENT_ID=<your_client_id>
118
+ OAUTH_GOOGLE_CLIENT_SECRET=<your_client_secret>
119
+ LLAMA_CLOUD_API_KEY=<your_api_key>
120
+ CHAINLIT_URL=<your_chainlit_url>
121
+ ```
122
+
123
+
124
+ # Configuration
125
+
126
+ The configuration file `code/modules/config/config.yml` contains the parameters that control the behaviour of your app.
127
+ The configuration file `code/modules/config/project_config.yml` contains project-specific parameters.
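
Both are plain YAML files that are merged at startup; a minimal sketch of the pattern, mirroring the `__main__` block of `code/modules/vectorstore/store_manager.py`:

```python
import yaml

# Load the app-wide config and the project-specific config, then merge them.
# dict.update() is a shallow merge, so top-level keys in project_config.yml
# override the same keys in config.yml.
with open("modules/config/config.yml", "r") as f:
    config = yaml.safe_load(f)
with open("modules/config/project_config.yml", "r") as f:
    project_config = yaml.safe_load(f)
config.update(project_config)
```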
pyproject.toml ADDED
@@ -0,0 +1,2 @@
1
+ [tool.black]
2
+ line-length = 88
requirements.txt CHANGED
@@ -22,4 +22,15 @@ umap-learn
22
  llama-cpp-python
23
  pymupdf
24
  websockets
25
- langchain-openai
22
  llama-cpp-python
23
  pymupdf
24
  websockets
25
+ langchain-openai
26
+ langchain-experimental
27
+ html2text
28
+ PyPDF2
29
+ pdf2image
30
+ black
31
+ flake8
32
+ bandit
33
+ fastapi
34
+ google-auth
35
+ google-auth-oauthlib
36
+ Jinja2