Eric Guan committed
Commit de080b7 · 1 Parent(s): 7490a78

Initial Commit

.gitignore ADDED
@@ -0,0 +1,2 @@
+ venv
+ .env
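
The .env file itself is untracked; based on rag_model.py below, it is expected to define MONGO_URI and HF_TOKEN. A minimal sketch to verify both variables load, assuming .env sits in the project root:

# Sketch: confirm the two variables rag_model.py reads are present in .env.
import os
from dotenv import load_dotenv

load_dotenv()  # loads variables from the local .env file
for var in ("MONGO_URI", "HF_TOKEN"):
    assert os.getenv(var), f"{var} is missing from .env"
print("Both environment variables loaded.")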
__pycache__/passwords.cpython-311.pyc ADDED
Binary file (363 Bytes).

__pycache__/passwords.cpython-312.pyc ADDED
Binary file (357 Bytes).

__pycache__/rag_model.cpython-311.pyc ADDED
Binary file (4.33 kB).

__pycache__/rag_model.cpython-312.pyc ADDED
Binary file (3.66 kB).
app.py ADDED
@@ -0,0 +1,72 @@
+ import streamlit as st
+ from transformers import pipeline
+ from PIL import Image
+ from rag_model import rag_chain
+
+
+ @st.cache_resource
+ def load_image_model():
+     return pipeline("image-classification", model="Heem2/wound-image-classification")
+
+ classifier = load_image_model()
+
+ st.title("FirstAid-AI")
+
+ # Initialize the chat history
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ # Upload an image of a wound
+ file = st.file_uploader("Upload an image of your wound")
+
+ # Reset the chat history if no file is uploaded
+ if file is None:
+     st.session_state.messages = []
+
+ if file is not None:
+     # Display the image and the predictions side by side
+     col1, col2 = st.columns(2)
+     image = Image.open(file)
+     col1.image(image, use_container_width=True)
+
+     # Classify the image
+     predictions = classifier(image)
+     detected_wound = predictions[0]["label"]
+     col2.header("Detected Wound")
+     for p in predictions:
+         col2.subheader(f"{p['label']}: {round(p['score'] * 100, 1)}%")
+
+     # Seed the chat with initial treatment advice for the detected wound
+     if not st.session_state.messages:
+         initial_query = f"Provide treatment advice for a {detected_wound} wound"
+         initial_response = rag_chain.invoke(initial_query)
+         st.session_state.messages.append({"role": "assistant", "content": initial_response})
+
+     # Display the chat messages from the history
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+ # Accept user input only while an image is uploaded
+ if file is not None and (prompt := st.chat_input("Ask a follow-up question or continue the conversation:")):
+     # Display the user message in the chat
+     with st.chat_message("user"):
+         st.markdown(prompt)
+     # Add the user message to the chat history
+     st.session_state.messages.append({"role": "user", "content": prompt})
+
+     # Prepare the conversation history for rag_chain
+     conversation_history = "\n".join(
+         f"{message['role']}: {message['content']}" for message in st.session_state.messages
+     )
+
+     # Generate a response from rag_chain
+     query = f"Context:\n{conversation_history}\n\nAssistant, respond to the user's latest query: {prompt}"
+     response = rag_chain.invoke(query)
+
+     # Display the assistant response in a chat message container
+     with st.chat_message("assistant"):
+         st.markdown(response)
+
+     # Add the assistant response to the chat history
+     st.session_state.messages.append({"role": "assistant", "content": response})
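
For reference, a transformers image-classification pipeline returns a list of label/score dicts sorted by descending score, which is why predictions[0]["label"] is the top prediction above. A sketch of the shape (the labels here are invented, not the model's actual classes):

# Hypothetical shape of classifier(image); labels invented for illustration.
predictions = [
    {"label": "Abrasion", "score": 0.91},
    {"label": "Burn", "score": 0.05},
]
detected_wound = predictions[0]["label"]  # top-1 label, as used in app.py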
rag_model.py ADDED
@@ -0,0 +1,132 @@
+ # Import libraries.
+
+ # File loading and environment variables.
+ import os
+ from dotenv import load_dotenv
+
+ # HuggingFace LLM.
+ from huggingface_hub import InferenceClient
+
+ # LangChain.
+ from langchain.prompts import PromptTemplate
+ from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
+ from langchain_community.vectorstores import MongoDBAtlasVectorSearch
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+
+ # MongoDB.
+ from pymongo import MongoClient
+
+ # Function type hints.
+ from typing import Dict, Any
+
+ # Streamlit.
+ import streamlit as st
+
+ # Load the environment variables from the .env file.
+ load_dotenv()
+ MONGO_URI = os.getenv("MONGO_URI")
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # Set up the vector store and the MongoDB Atlas connection.
+ DB_NAME = "ericguan04"
+ COLLECTION_NAME = "first_aid_intents"
+ vector_search_index = "vector_index"
+
+ @st.cache_resource
+ def get_mongodb_collection():
+     # Connect to the MongoDB Atlas cluster using the connection string.
+     cluster = MongoClient(MONGO_URI)
+     # Return the specific collection in the database.
+     return cluster[DB_NAME][COLLECTION_NAME]
+
+ MONGODB_COLLECTION = get_mongodb_collection()
+
+ @st.cache_resource
+ def load_embedding_model():
+     return HuggingFaceInferenceAPIEmbeddings(
+         api_key=HF_TOKEN, model_name="sentence-transformers/all-mpnet-base-v2"
+     )
+
+ embedding_model = load_embedding_model()
+
+ vector_search = MongoDBAtlasVectorSearch.from_connection_string(
+     connection_string=MONGO_URI,
+     namespace=f"{DB_NAME}.{COLLECTION_NAME}",
+     embedding=embedding_model,
+     index_name=vector_search_index,
+ )
+
+ # k limits the search to the k most relevant documents.
+ k = 10
+
+ # score_threshold keeps only documents with a relevance score above 0.80;
+ # it only takes effect with the "similarity_score_threshold" search type.
+ score_threshold = 0.80
+
+ # Build the retriever.
+ retriever_1 = vector_search.as_retriever(
+     search_type="similarity_score_threshold",  # similarity, mmr, or similarity_score_threshold. https://api.python.langchain.com/en/latest/vectorstores/langchain_core.vectorstores.VectorStore.html#langchain_core.vectorstores.VectorStore.as_retriever
+     search_kwargs={"k": k, "score_threshold": score_threshold},
+ )
+
+
+ # Initialize the Hugging Face inference client.
+ hf_client = InferenceClient(api_key=HF_TOKEN)
+
+ # Define the prompt template.
+ prompt = PromptTemplate.from_template(
+     """Use the following pieces of context to answer the question at the end.
+
+ START OF CONTEXT:
+ {context}
+ END OF CONTEXT:
+
+ START OF QUESTION:
+ {question}
+ END OF QUESTION:
+
+ If you do not know the answer, just say that you do not know.
+ NEVER assume things.
+ """
+ )
+
+
+ # Format the retrieved documents before inserting them into the prompt template.
+ def format_docs(docs):
+     return "\n\n".join(doc.page_content for doc in docs)
+
+ def generate_response(input_dict: Dict[str, Any]) -> str:
+     # Fill the template with the retrieved context and the question.
+     formatted_prompt = prompt.format(**input_dict)
+
+     # Query the LLM, passing the grounded prompt as the system message.
+     response = hf_client.chat.completions.create(
+         model="Qwen/Qwen2.5-1.5B-Instruct",
+         messages=[
+             {"role": "system", "content": formatted_prompt},
+             {"role": "user", "content": input_dict["question"]},
+         ],
+         max_tokens=1000,
+         temperature=0.2,
+     )
+
+     return response.choices[0].message.content
+
+
+ # Build the chain with retriever_1.
+ rag_chain = (
+     {
+         "context": retriever_1 | RunnableLambda(format_docs),
+         "question": RunnablePassthrough()
+     }
+     | RunnableLambda(generate_response)
+ )
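
Since rag_chain is an ordinary LangChain runnable, it can be exercised outside Streamlit. A minimal smoke-test sketch, assuming the .env variables are set and the vector_index search index exists on the Atlas collection (the question string is invented for illustration):

# Sketch: invoke the chain directly for a quick end-to-end check.
from rag_model import rag_chain

answer = rag_chain.invoke("How should I treat a minor burn?")  # hypothetical query
print(answer)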
requirements.txt ADDED
@@ -0,0 +1,87 @@
+ aiohappyeyeballs==2.4.3
+ aiohttp==3.11.5
+ aiosignal==1.3.1
+ altair==5.4.1
+ annotated-types==0.7.0
+ anyio==4.6.2.post1
+ attrs==24.2.0
+ blinker==1.9.0
+ cachetools==5.5.0
+ certifi==2024.8.30
+ charset-normalizer==3.4.0
+ click==8.1.7
+ dataclasses-json==0.6.7
+ dnspython==2.7.0
+ filelock==3.16.1
+ frozenlist==1.5.0
+ fsspec==2024.10.0
+ gitdb==4.0.11
+ GitPython==3.1.43
+ h11==0.14.0
+ httpcore==1.0.7
+ httpx==0.27.2
+ httpx-sse==0.4.0
+ huggingface-hub==0.26.2
+ idna==3.10
+ Jinja2==3.1.4
+ jsonpatch==1.33
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2024.10.1
+ langchain==0.3.7
+ langchain-community==0.3.7
+ langchain-core==0.3.19
+ langchain-text-splitters==0.3.2
+ langsmith==0.1.143
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ marshmallow==3.23.1
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.1.0
+ mypy-extensions==1.0.0
+ narwhals==1.14.1
+ networkx==3.4.2
+ numpy==1.26.4
+ orjson==3.10.11
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.0.0
+ propcache==0.2.0
+ protobuf==5.28.3
+ pyarrow==18.0.0
+ pydantic==2.9.2
+ pydantic-settings==2.6.1
+ pydantic_core==2.23.4
+ pydeck==0.9.1
+ Pygments==2.18.0
+ pymongo==4.10.1
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ pytz==2024.2
+ PyYAML==6.0.2
+ referencing==0.35.1
+ regex==2024.11.6
+ requests==2.32.3
+ requests-toolbelt==1.0.0
+ rich==13.9.4
+ rpds-py==0.21.0
+ safetensors==0.4.5
+ six==1.16.0
+ smmap==5.0.1
+ sniffio==1.3.1
+ SQLAlchemy==2.0.35
+ streamlit==1.40.1
+ sympy==1.13.1
+ tenacity==9.0.0
+ tokenizers==0.20.3
+ toml==0.10.2
+ torch==2.5.1
+ tornado==6.4.1
+ tqdm==4.67.0
+ transformers==4.46.3
+ typing-inspect==0.9.0
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ urllib3==2.2.3
+ yarl==1.17.2
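
With the venv from .gitignore active, the pins above can be installed with pip install -r requirements.txt; torch and transformers are the heavyweight entries, pulled in for the local image-classification pipeline. The app itself is presumably launched with streamlit run app.py.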