Update rag_utils.py
Browse files- rag_utils.py +160 -128
rag_utils.py
CHANGED
@@ -1,136 +1,168 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
3 |
import logging
|
4 |
|
5 |
# Configure logging
|
6 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
"
|
26 |
-
|
27 |
-
"
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
"
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
"
|
79 |
-
|
80 |
-
"
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
"
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
}
|
116 |
-
]
|
117 |
-
}
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
for damage_type, cases in KNOWLEDGE_BASE.items():
|
131 |
-
for case in cases:
|
132 |
-
for key in required_keys:
|
133 |
-
if key not in case:
|
134 |
-
logging.error(f"Missing required field '{key}' in {damage_type}")
|
135 |
-
raise ValueError(f"Missing required field '{key}' in {damage_type}")
|
136 |
-
logging.info("Knowledge base validation passed.")
|
|
|
1 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
2 |
+
from langchain.vectorstores import FAISS
|
3 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
4 |
+
from langchain.chains import RetrievalQA
|
5 |
+
from langchain.prompts import PromptTemplate
|
6 |
+
from langchain.llms import HuggingFaceHub
|
7 |
+
import os
|
8 |
import logging
|
9 |
|
10 |
# Configure logging
|
11 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
12 |
|
13 |
+
class RAGSystem:
|
14 |
+
def __init__(self):
|
15 |
+
try:
|
16 |
+
self.embeddings = HuggingFaceEmbeddings(
|
17 |
+
model_name="sentence-transformers/all-mpnet-base-v2"
|
18 |
+
)
|
19 |
+
self.vector_store = None
|
20 |
+
self.text_splitter = RecursiveCharacterTextSplitter(
|
21 |
+
chunk_size=500,
|
22 |
+
chunk_overlap=50
|
23 |
+
)
|
24 |
+
# Initialize HuggingFace model for text generation
|
25 |
+
self.llm = HuggingFaceHub(
|
26 |
+
repo_id="google/flan-t5-large",
|
27 |
+
task="text-generation",
|
28 |
+
model_kwargs={"temperature": 0.7, "max_length": 512}
|
29 |
+
)
|
30 |
+
logging.info("RAG system initialized successfully.")
|
31 |
+
except Exception as e:
|
32 |
+
logging.error(f"Failed to initialize RAG system: {str(e)}")
|
33 |
+
raise e
|
34 |
+
|
35 |
+
def initialize_knowledge_base(self, knowledge_base):
|
36 |
+
"""Initialize vector store with enhanced construction knowledge"""
|
37 |
+
try:
|
38 |
+
documents = []
|
39 |
+
# Validate knowledge base
|
40 |
+
self._validate_knowledge_base(knowledge_base)
|
41 |
+
|
42 |
+
# Add expert insights and case studies
|
43 |
+
expert_insights = self._generate_expert_insights(knowledge_base)
|
44 |
+
case_studies = self._generate_case_studies()
|
45 |
+
|
46 |
+
for damage_type, cases in knowledge_base.items():
|
47 |
+
for case in cases:
|
48 |
+
# Combine basic info with expert insights
|
49 |
+
relevant_insight = expert_insights.get(damage_type, "")
|
50 |
+
relevant_cases = case_studies.get(damage_type, "")
|
51 |
+
|
52 |
+
doc_text = f"""
|
53 |
+
Damage Type: {damage_type}
|
54 |
+
Severity: {case['severity']}
|
55 |
+
Description: {case['description']}
|
56 |
+
Technical Details: {case['description']}
|
57 |
+
Expert Insight: {relevant_insight}
|
58 |
+
Case Studies: {relevant_cases}
|
59 |
+
Repair Methods: {', '.join(case['repair_method'])}
|
60 |
+
Cost Considerations: {case['estimated_cost']}
|
61 |
+
Implementation Timeline: {case['timeframe']}
|
62 |
+
Location Specifics: {case['location']}
|
63 |
+
Required Expertise Level: {case['required_expertise']}
|
64 |
+
Emergency Protocol: {case['immediate_action']}
|
65 |
+
Preventive Measures: {case['prevention']}
|
66 |
+
Long-term Implications: Analysis of long-term structural integrity impact
|
67 |
+
Environmental Factors: Consideration of environmental conditions
|
68 |
+
"""
|
69 |
+
documents.append(doc_text)
|
70 |
+
|
71 |
+
splits = self.text_splitter.create_documents(documents)
|
72 |
+
self.vector_store = FAISS.from_documents(splits, self.embeddings)
|
73 |
+
|
74 |
+
# Initialize QA chain
|
75 |
+
self.qa_chain = RetrievalQA.from_chain_type(
|
76 |
+
llm=self.llm,
|
77 |
+
chain_type="stuff",
|
78 |
+
retriever=self.vector_store.as_retriever(),
|
79 |
+
chain_type_kwargs={
|
80 |
+
"prompt": self._get_qa_prompt()
|
81 |
+
}
|
82 |
+
)
|
83 |
+
logging.info("Knowledge base initialized successfully.")
|
84 |
+
except Exception as e:
|
85 |
+
logging.error(f"Failed to initialize knowledge base: {str(e)}")
|
86 |
+
raise e
|
87 |
+
|
88 |
+
def _validate_knowledge_base(self, knowledge_base):
|
89 |
+
"""Validate the structure of the knowledge base."""
|
90 |
+
required_keys = ['severity', 'description', 'repair_method', 'estimated_cost', 'timeframe', 'location', 'required_expertise', 'immediate_action', 'prevention']
|
91 |
+
for damage_type, cases in knowledge_base.items():
|
92 |
+
for case in cases:
|
93 |
+
for key in required_keys:
|
94 |
+
if key not in case:
|
95 |
+
raise ValueError(f"Missing required field '{key}' in {damage_type}")
|
96 |
+
logging.info("Knowledge base validation passed.")
|
97 |
+
|
98 |
+
def _get_qa_prompt(self):
|
99 |
+
"""Create a custom prompt template for the QA chain"""
|
100 |
+
template = """
|
101 |
+
Context: {context}
|
102 |
+
|
103 |
+
Question: {question}
|
104 |
+
|
105 |
+
Provide a detailed analysis considering:
|
106 |
+
1. Technical aspects
|
107 |
+
2. Safety implications
|
108 |
+
3. Cost-benefit analysis
|
109 |
+
4. Long-term considerations
|
110 |
+
5. Best practices and recommendations
|
111 |
+
|
112 |
+
Answer:
|
113 |
+
"""
|
114 |
+
return PromptTemplate(
|
115 |
+
template=template,
|
116 |
+
input_variables=["context", "question"]
|
117 |
+
)
|
118 |
+
|
119 |
+
def _generate_expert_insights(self, knowledge_base):
|
120 |
+
"""Generate expert insights for each damage type"""
|
121 |
+
insights = {}
|
122 |
+
for damage_type in knowledge_base.keys():
|
123 |
+
insights[damage_type] = f"Expert analysis for {damage_type} including latest research findings and industry best practices."
|
124 |
+
return insights
|
125 |
+
|
126 |
+
def _generate_case_studies(self):
|
127 |
+
"""Generate relevant case studies for each damage type"""
|
128 |
+
return {
|
129 |
+
"spalling": "Case studies of successful spalling repairs in similar structures",
|
130 |
+
"reinforcement_corrosion": "Examples of corrosion mitigation in harsh environments",
|
131 |
+
"structural_crack": "Analysis of crack progression and successful interventions",
|
132 |
+
"dampness": "Case studies of effective moisture control solutions",
|
133 |
+
"no_damage": "Preventive maintenance success stories"
|
134 |
}
|
|
|
|
|
135 |
|
136 |
+
def get_enhanced_analysis(self, damage_type, confidence, custom_query=None):
|
137 |
+
"""Get enhanced analysis with dynamic content generation"""
|
138 |
+
try:
|
139 |
+
if not custom_query:
|
140 |
+
base_query = f"""
|
141 |
+
Provide a comprehensive analysis for {damage_type} damage with {confidence}% confidence level.
|
142 |
+
Include technical assessment, safety implications, and expert recommendations.
|
143 |
+
"""
|
144 |
+
else:
|
145 |
+
base_query = custom_query
|
146 |
+
|
147 |
+
# Get relevant documents
|
148 |
+
results = self.qa_chain.run(base_query)
|
149 |
+
|
150 |
+
# Process and categorize the response
|
151 |
+
enhanced_info = {
|
152 |
+
"technical_details": self._extract_technical_details(results, damage_type),
|
153 |
+
"safety_considerations": self._extract_safety_considerations(results),
|
154 |
+
"expert_recommendations": self._extract_recommendations(results, confidence)
|
155 |
+
}
|
156 |
+
return enhanced_info
|
157 |
+
except Exception as e:
|
158 |
+
logging.error(f"Failed to generate enhanced analysis: {str(e)}")
|
159 |
+
return None
|
160 |
+
|
161 |
+
def _extract_technical_details(self, results, damage_type):
|
162 |
+
return [f"Detailed technical analysis for {damage_type}", results]
|
163 |
+
|
164 |
+
def _extract_safety_considerations(self, results):
|
165 |
+
return [f"Safety analysis based on current conditions", results]
|
166 |
|
167 |
+
def _extract_recommendations(self, results, confidence):
|
168 |
+
return [f"Prioritized recommendations based on {confidence}% confidence", results]
|
|
|
|
|
|
|
|
|
|
|
|
|
|