Commit 4cbe4e9
Anas Bader committed
1 Parent(s): c0d5f87

redo
Files changed:
- .gitignore +16 -0
- Dockerfile +85 -0
- README.md +4 -4
- app.py +113 -0
- chunking/semantic_chunking.py +93 -0
- elastic/es_client.py +41 -0
- elastic/es_index.py +41 -0
- elastic/indexing.py +60 -0
- elastic/retrieval.py +83 -0
- embeddings/embeddings.py +33 -0
- es_data/node.lock +0 -0
- es_data/nodes +1 -0
- file_processing.py +106 -0
- ingestion.py +0 -0
- prompting/rewrite_question.py +212 -0
- requirements.txt +21 -0
- streamlit.py +50 -0
.gitignore
ADDED
@@ -0,0 +1,16 @@
+.env*
+myenv*
+pyproject.toml
+
+.env*
+!.env.example
+myenv*
+pyproject.toml
+test.*
+.conda
+docs/*
+__pycache__/
+.vscode
+certif_extraction/certifs_md
+
+*.log
Dockerfile
ADDED
@@ -0,0 +1,85 @@
+FROM ubuntu:22.04
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    curl \
+    openjdk-11-jdk \
+    python3 \
+    python3-pip \
+    wget \
+    apt-transport-https \
+    gnupg \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Elasticsearch
+ENV ES_VERSION=8.8.0
+RUN curl -O https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-8.8.0-linux-x86_64.tar.gz && \
+    tar -xzf elasticsearch-8.8.0-linux-x86_64.tar.gz && \
+    mv elasticsearch-8.8.0 /usr/share/elasticsearch && \
+    rm elasticsearch-8.8.0-linux-x86_64.tar.gz
+
+# Create elasticsearch.yml with proper YAML format
+RUN echo "discovery.type: single-node" > /usr/share/elasticsearch/config/elasticsearch.yml && \
+    echo "xpack.security.enabled: false" >> /usr/share/elasticsearch/config/elasticsearch.yml && \
+    echo "network.host: 0.0.0.0" >> /usr/share/elasticsearch/config/elasticsearch.yml
+
+# Set Elasticsearch environment variables
+ENV ES_JAVA_OPTS="-Xms1g -Xmx1g"
+
+# Create non-root user for running the services
+RUN useradd -m -u 1000 appuser
+RUN mkdir -p /app /usr/share/elasticsearch/data && \
+    chown -R appuser:appuser /app /usr/share/elasticsearch
+
+# Create app directory
+WORKDIR /app
+
+# Copy your project files
+COPY --chown=appuser:appuser app.py streamlit.py requirements.txt ./
+COPY --chown=appuser:appuser chunking ./chunking
+COPY --chown=appuser:appuser embeddings ./embeddings
+COPY --chown=appuser:appuser prompting ./prompting
+COPY --chown=appuser:appuser elastic ./elastic
+COPY --chown=appuser:appuser file_processing.py ./
+COPY --chown=appuser:appuser ingestion.py ./
+
+# Copy ES data if needed - consider if this is actually necessary
+COPY --chown=appuser:appuser es_data /usr/share/elasticsearch/data
+
+# Install Python dependencies
+RUN pip3 install -r requirements.txt
+
+# Set environment variables for Streamlit
+ENV STREAMLIT_SERVER_HEADLESS=true
+ENV STREAMLIT_SERVER_PORT=7860
+ENV STREAMLIT_SERVER_ENABLE_CORS=false
+ENV ES_HOST=localhost
+ENV ES_PORT=9200
+ENV ELASTICSEARCH_HOSTS="http://localhost:9200"
+
+# Expose required ports (Elasticsearch and Streamlit)
+EXPOSE 9200 7860
+
+# Switch to non-root user
+USER appuser
+
+# Create startup script
+RUN echo '#!/bin/bash\n\
+# Start Elasticsearch in the background\n\
+/usr/share/elasticsearch/bin/elasticsearch &\n\
+\n\
+# Wait for Elasticsearch to become available\n\
+echo "Waiting for Elasticsearch to start..."\n\
+until curl -s http://localhost:9200 > /dev/null; do\n\
+    sleep 2\n\
+    echo "Still waiting for Elasticsearch..."\n\
+done\n\
+echo "Elasticsearch is up and running!"\n\
+\n\
+# Start Streamlit\n\
+echo "Starting Streamlit application..."\n\
+streamlit run /app/streamlit.py\n\
+' > /app/start.sh && chmod +x /app/start.sh
+
+# Command to run
+CMD ["/app/start.sh"]
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title: Rag
-emoji:
-colorFrom:
-colorTo:
+title: Hydro Rag
+emoji: 🐢
+colorFrom: indigo
+colorTo: green
 sdk: docker
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,113 @@
+import os
+from fastapi import FastAPI, HTTPException
+from langchain.prompts import PromptTemplate
+from pydantic import BaseModel
+from dotenv import load_dotenv
+
+from embeddings.embeddings import generate_embeddings
+from elastic.retrieval import search_certification_chunks
+from prompting.rewrite_question import classify_certification, initialize_llms, process_query
+
+load_dotenv()
+
+app = FastAPI(
+    title="Hydrogen Certification RAG System",
+    description="API for querying hydrogen certification documents using RAG",
+    version="0.1.0"
+)
+
+# Initialize LLMs once and reuse the rewrite model for answer generation
+llms = initialize_llms()
+llm = llms["rewrite_llm"]
+
+# Request models
+class QueryRequest(BaseModel):
+    query: str
+
+
+# Endpoints
+@app.post("/query")
+async def handle_query(request: QueryRequest):
+    """
+    Process a query through the full RAG pipeline:
+    1. Classify certification (if not provided)
+    2. Optimize query based on specificity
+    3. Search relevant chunks
+    """
+    try:
+        # Step 1: Determine certification
+        certification = classify_certification(request.query, llms["rewrite_llm"])
+        if "no certification mentioned" in certification:
+            raise HTTPException(
+                status_code=400,
+                detail="No certification specified in query and none provided"
+            )
+
+        # Step 2: Process query
+        processed_query = process_query(request.query, llms)
+        question_vector = generate_embeddings(processed_query)
+
+        # Step 3: Search both indices (vector_query is a required parameter
+        # of search_certification_chunks, so it is passed in both calls)
+        results = search_certification_chunks(
+            index_name="certif_index",
+            certification_name=certification,
+            text_query=processed_query,
+            vector_query=question_vector,
+        )
+
+        results_ = search_certification_chunks(
+            index_name="certification_index",
+            certification_name=certification,
+            text_query=processed_query,
+            vector_query=question_vector,
+        )
+
+        results_merged = ". ".join([result["text"] for result in results])
+        results_merged_ = ". ".join([result["text"] for result in results_])
+
+        template = """
+        You are an AI assistant tasked with providing answers based on the given context about a specific hydrogen certification.
+
+        Provide a clear, concise response that directly addresses the question without unnecessary information.
+
+        Question: {question}
+        Certification: {certification}
+        Context: {context}
+
+        Answer:
+        """
+        prompt = PromptTemplate(
+            input_variables=["question", "certification", "context"],
+            template=template
+        )
+
+        chain = prompt | llm
+        answer = chain.invoke({"question": processed_query, "certification": certification, "context": results_merged}).content
+        answer_ = chain.invoke({"question": processed_query, "certification": certification, "context": results_merged_}).content
+
+        return {
+            "certification": certification,
+            "certif_index": answer,
+            "certification_index": answer_,
+        }
+
+    except HTTPException:
+        # Re-raise explicit HTTP errors (e.g. the 400 above) unchanged instead
+        # of letting the generic handler convert them into a 500
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/certifications", response_model=list[str])
+async def list_certifications():
+    """List all available certifications"""
+    try:
+        certs_dir = "docs/processed"
+        return [f for f in os.listdir(certs_dir) if os.path.isdir(os.path.join(certs_dir, f))]
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
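For a quick smoke test of the /query endpoint above, here is a minimal sketch using only the Python standard library. It assumes the API is already running locally on port 8000; the query string is purely illustrative.

import json
from urllib import request

payload = json.dumps({"query": "What are the purity requirements in GH2 Standard?"}).encode()
req = request.Request(
    "http://localhost:8000/query",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    # Prints the classified certification plus both index answers
    print(json.load(resp))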
chunking/semantic_chunking.py
ADDED
@@ -0,0 +1,93 @@
+import re
+
+import numpy as np
+from sentence_transformers import SentenceTransformer
+
+embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+
+
+def hybrid_split(text: str, max_len: int = 1024) -> list[str]:
+    """
+    Split text into chunks, respecting sentence boundaries when possible.
+
+    Args:
+        text: The text to split
+        max_len: Maximum length for each chunk
+
+    Returns:
+        List of text chunks
+    """
+    # Normalize text
+    text = text.replace("\r", "").replace("\n", " ").strip()
+
+    # Extract sentences (split on sentence-ending punctuation)
+    sentences = re.split(r"(?<=[.!?])\s+", text)
+
+    chunks = []
+    current_chunk = ""
+
+    for sentence in sentences:
+        if len(sentence) > max_len:
+            # First flush the current chunk if it exists, then emit the
+            # over-long sentence as its own chunk to preserve ordering
+            if current_chunk:
+                chunks.append(current_chunk)
+                current_chunk = ""
+            chunks.append(sentence)
+
+        # Normal case - see if adding the sentence exceeds max_len
+        elif len(current_chunk) + len(sentence) + 1 > max_len:
+            # Close the current chunk and start a new one with this sentence
+            # (so the sentence that triggered the overflow is not dropped)
+            if current_chunk:
+                chunks.append(current_chunk)
+            current_chunk = sentence
+        else:
+            # Add to the current chunk
+            if current_chunk:
+                current_chunk += " " + sentence
+            else:
+                current_chunk = sentence
+
+    if current_chunk:
+        chunks.append(current_chunk)
+
+    return chunks
+
+
+def cosine_similarity(vec1, vec2):
+    """Calculate the cosine similarity between two vectors."""
+    dot_product = np.dot(vec1, vec2)
+    norm_vec1 = np.linalg.norm(vec1)
+    norm_vec2 = np.linalg.norm(vec2)
+    return dot_product / (norm_vec1 * norm_vec2)
+
+
+def get_embedding(text):
+    """Generate an embedding using SBERT."""
+    return embedding_model.encode(text, convert_to_numpy=True)
+
+
+def semantic_chunking(text, threshold=0.75, max_chunk_size=8191):
+    """
+    Splits text into semantic chunks based on sentence similarity.
+    - threshold: Lower = more splits, Higher = fewer splits
+    - max_chunk_size: Maximum size of each chunk in characters
+    """
+    text = text.replace("\n", " ").replace("\r", " ").strip()
+    sentences = hybrid_split(text)
+    if not sentences:
+        return []
+    embeddings = [get_embedding(sent) for sent in sentences]
+
+    chunks = []
+    current_chunk = [sentences[0]]
+
+    for i in range(1, len(sentences)):
+        sim = cosine_similarity(embeddings[i - 1], embeddings[i])
+        if (
+            sim < threshold
+            or len(" ".join(current_chunk + [sentences[i]])) > max_chunk_size
+        ):
+            chunks.append(" ".join(current_chunk))
+            current_chunk = [sentences[i]]
+        else:
+            current_chunk.append(sentences[i])
+
+    if current_chunk:
+        chunks.append(" ".join(current_chunk))
+
+    return chunks
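To see how the two-stage splitter behaves, here is a small usage sketch; the sample text is hypothetical and the threshold is the default from the function signature.

from chunking.semantic_chunking import semantic_chunking

sample = (
    "Hydrogen purity must exceed 99.97 percent. "
    "Impurity limits are defined per contaminant. "
    "Shipping logistics are handled separately."
)
# Adjacent sentences whose embeddings fall below the 0.75 similarity
# threshold start a new chunk
for i, chunk in enumerate(semantic_chunking(sample, threshold=0.75)):
    print(i, chunk)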
elastic/es_client.py
ADDED
@@ -0,0 +1,41 @@
+import os
+import logging
+from elasticsearch import Elasticsearch, ConnectionError, AuthenticationException
+
+# Configure logging at the application level
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+# Load environment variables
+ES_CLIENT_URL = os.getenv("ELASTICSEARCH_HOSTS", "http://localhost:9200")
+
+class ElasticsearchClientError(Exception):
+    """Custom exception for Elasticsearch client errors."""
+    pass
+
+def get_es_client() -> Elasticsearch:
+    """
+    Establish connection to Elasticsearch and return the client instance.
+    Raises ElasticsearchClientError if the connection cannot be established.
+    """
+
+    try:
+        logger.info("Connecting to Elasticsearch at %s", ES_CLIENT_URL)
+        # Initialize Elasticsearch client
+        es_client = Elasticsearch(
+            hosts=[ES_CLIENT_URL],
+        )
+
+        # Verify connection
+        if not es_client.ping():
+            error_message = "Elasticsearch cluster is not reachable!"
+            logger.error(error_message)
+            raise ElasticsearchClientError(error_message)
+
+        logger.info("Successfully connected to Elasticsearch")
+        return es_client
+
+    except (ConnectionError, AuthenticationException) as e:
+        error_message = f"Elasticsearch connection error: {e}"
+        logger.error(error_message)
+        raise ElasticsearchClientError(error_message) from e
elastic/es_index.py
ADDED
@@ -0,0 +1,41 @@
+import os
+import logging
+from elasticsearch import Elasticsearch, ConnectionError, AuthenticationException
+
+# Configure logging at the application level
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+# Load environment variables (default mirrors es_client.py so that
+# hosts is never [None] when the variable is unset)
+ES_CLIENT_URL = os.getenv("ELASTICSEARCH_HOSTS", "http://localhost:9200")
+
+class ElasticsearchClientError(Exception):
+    """Custom exception for Elasticsearch client errors."""
+    pass
+
+def get_es_client() -> Elasticsearch:
+    """
+    Establish connection to Elasticsearch and return the client instance.
+    Raises ElasticsearchClientError if the connection cannot be established.
+    """
+
+    try:
+        logger.info("Connecting to Elasticsearch at %s", ES_CLIENT_URL)
+        # Initialize Elasticsearch client
+        es_client = Elasticsearch(
+            hosts=[ES_CLIENT_URL],
+        )
+
+        # Verify connection
+        if not es_client.ping():
+            error_message = "Elasticsearch cluster is not reachable!"
+            logger.error(error_message)
+            raise ElasticsearchClientError(error_message)
+
+        logger.info("Successfully connected to Elasticsearch")
+        return es_client
+
+    except (ConnectionError, AuthenticationException) as e:
+        error_message = f"Elasticsearch connection error: {e}"
+        logger.error(error_message)
+        raise ElasticsearchClientError(error_message) from e
elastic/indexing.py
ADDED
@@ -0,0 +1,60 @@
+import logging
+from elasticsearch import Elasticsearch
+from typing import Dict, Any
+
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.DEBUG)
+
+embedding_dimension = 1536
+
+def create_mapping(properties: Dict[str, Any]) -> Dict[str, Any]:
+    """Helper function to create index mappings with predefined settings."""
+    return {
+        "settings": {"number_of_shards": 1, "number_of_replicas": 1},
+        "mappings": {"properties": properties},
+    }
+
+
+def retrieval_index() -> Dict[str, Any]:
+    """Returns the Elasticsearch mapping for retrieval indices."""
+    return create_mapping(
+        {
+            "chunk_id": {"type": "keyword"},
+            "chunk": {"type": "text"},
+            "embedding": {
+                "type": "dense_vector",
+                "dims": embedding_dimension,
+            },
+            "certification": {"type": "keyword"},
+            "source_file": {"type": "keyword"},
+            "timestamp": {"type": "date"},
+        }
+    )
+
+
+def create_elasticsearch_index(es_client: Elasticsearch, index_name: str) -> bool:
+    """
+    Create an Elasticsearch index with the appropriate mapping.
+
+    Args:
+        es_client (Elasticsearch): The Elasticsearch client instance.
+        index_name (str): The name of the index to create.
+
+    Returns:
+        bool: True if the index was created successfully, False otherwise.
+    """
+    try:
+        mapping = retrieval_index()
+
+        if es_client.indices.exists(index=index_name):
+            logger.warning(f"Index '{index_name}' already exists. Skipping creation.")
+            return True
+
+        es_client.indices.create(index=index_name, body=mapping)
+        logger.info(f"Index '{index_name}' created successfully.")
+        return True
+
+    except Exception as e:
+        logger.error(f"Unexpected error while creating index '{index_name}': {e}")
+        return False
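A minimal sketch of how es_client.py and indexing.py compose, assuming Elasticsearch is reachable at ELASTICSEARCH_HOSTS; the index name matches the one queried in app.py.

from elastic.es_client import get_es_client
from elastic.indexing import create_elasticsearch_index

es = get_es_client()
# Creates the index with the dense_vector mapping above, or skips if it exists
created = create_elasticsearch_index(es, "certif_index")
print("ready:", created)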
elastic/retrieval.py
ADDED
@@ -0,0 +1,83 @@
+from typing import List, Dict, Any
+import logging
+
+from elastic.es_client import get_es_client
+
+logger = logging.getLogger(__name__)
+
+es_client = get_es_client()
+
+
+def search_certification_chunks(
+    index_name: str,
+    text_query: str,
+    vector_query: List[float],
+    certification_name: str,
+    es_client=es_client,
+    vector_field: str = "embedding",
+    text_field: str = "chunk",
+    size: int = 5,
+    min_score: float = 0.1,  # Lowered threshold
+    boost_text: float = 1.0,
+    boost_vector: float = 1.0,
+) -> List[Dict[str, Any]]:
+    """
+    Hybrid (BM25 + cosine-similarity) search over one certification's chunks.
+    Returns a list of {id, score, text, source_file} dicts.
+    """
+    # First verify the certification value exists
+    cert_check = es_client.search(
+        index=index_name,
+        body={
+            "query": {"term": {"certification": certification_name}},
+            "size": 1,
+        },
+    )
+
+    if not cert_check["hits"]["hits"]:
+        logger.error(f"No documents found with certification: {certification_name}")
+        return []
+
+    # Then proceed with hybrid search, restricted to the requested
+    # certification via a filter clause so hits cannot leak in from
+    # other certifications
+    query_body = {
+        "size": size,
+        "min_score": min_score,
+        "query": {
+            "bool": {
+                "filter": [{"term": {"certification": certification_name}}],
+                "should": [
+                    {"match": {text_field: {"query": text_query, "boost": boost_text}}},
+                    {
+                        "script_score": {
+                            "query": {"match_all": {}},
+                            "script": {
+                                "source": f"cosineSimilarity(params.query_vector, '{vector_field}') * params.boost + 1.0",
+                                "params": {"query_vector": vector_query, "boost": boost_vector},
+                            },
+                        }
+                    },
+                ]
+            }
+        },
+    }
+    logger.debug(f"Elasticsearch query body: {query_body}")
+
+    logger.info(f"Executing search on index '{index_name}'")
+    response = es_client.search(index=index_name, body=query_body)
+    hits = response.get("hits", {}).get("hits", [])
+    logger.info(f"Found {len(hits)} matching documents")
+
+    # Process results with correct field names
+    results = [
+        {
+            "id": hit["_id"],
+            "score": hit["_score"],
+            "text": hit["_source"]["chunk"],
+            "source_file": hit["_source"]["source_file"],
+        }
+        for hit in hits
+    ]
+
+    if results:
+        logger.debug(f"Top result score: {results[0]['score']}")
+        logger.debug(f"Top result source: {results[0]['source_file']}")
+    else:
+        logger.warning("No results returned from Elasticsearch")
+
+    return results
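Putting retrieval together with the embedding helper from the next file, a hedged end-to-end sketch; it assumes the index is already populated and borrows one certification name from the list in prompting/rewrite_question.py.

from embeddings.embeddings import generate_embeddings
from elastic.retrieval import search_certification_chunks

question = "What feedstocks does ISCC EU accept?"
hits = search_certification_chunks(
    index_name="certif_index",
    text_query=question,
    vector_query=generate_embeddings(question),
    certification_name="ISCC EU (International Sustainability & Carbon Certification)",
)
for hit in hits:
    print(hit["score"], hit["source_file"])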
embeddings/embeddings.py
ADDED
@@ -0,0 +1,33 @@
+import os
+import logging
+from typing import List
+from openai import OpenAI
+from dotenv import load_dotenv
+
+# Configure logging at the application level
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+load_dotenv()
+
+embedding_dimension = 1536
+
+model = "text-embedding-3-small"
+
+# Guard against a missing key so we fail on the API call with a clear
+# error rather than an AttributeError on None; strip() already removes
+# trailing newlines
+openai_api_key = (os.getenv("OPENAI_API_KEY") or "").strip()
+client = OpenAI(api_key=openai_api_key)
+
+
+def generate_embeddings(text: str) -> List[float]:
+    """Get embeddings from OpenAI API."""
+    logger.info("Embedding model: %s", model)
+
+    if not text:
+        # Returning None implicitly would violate the declared return type
+        raise ValueError("Cannot embed empty text")
+    try:
+        response = client.embeddings.create(
+            model=model, input=text, dimensions=embedding_dimension
+        )
+        return response.data[0].embedding
+    except Exception as e:
+        logger.error(f"OpenAI API error: {e}")
+        raise
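As a sanity check that the configured dimension is honored, a one-line sketch (requires OPENAI_API_KEY in the environment):

from embeddings.embeddings import generate_embeddings

vec = generate_embeddings("green hydrogen certification scheme")
print(len(vec))  # 1536, matching embedding_dimension above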
es_data/node.lock
ADDED
File without changes
es_data/nodes
ADDED
@@ -0,0 +1 @@
+written by Elasticsearch v8.8.0 to prevent a downgrade to a version prior to v8.0.0 which would result in data loss
file_processing.py
ADDED
@@ -0,0 +1,106 @@
+import logging
+import os
+from typing import List
+
+import pdfplumber
+from docx import Document
+from openpyxl import load_workbook
+
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def extract_pdf_content(pdf_path: str) -> List[str]:
+    """
+    Extract text and tables from PDF in their natural reading order.
+    Simplified version without positional processing.
+
+    Args:
+        pdf_path (str): Path to the PDF file
+
+    Returns:
+        List[str]: List of extracted content chunks (text and tables)
+    """
+    if not os.path.exists(pdf_path):
+        logger.error(f"PDF file not found: {pdf_path}")
+        return []
+
+    try:
+        with pdfplumber.open(pdf_path) as pdf:
+            content = []
+
+            for page in pdf.pages:
+                # First extract tables
+                tables = page.extract_tables()
+                for table in tables:
+                    if table:
+                        # Convert table to string representation
+                        table_str = "\n".join(
+                            ["\t".join(str(cell) for cell in row) for row in table]
+                        )
+                        content.append(f"[TABLE]\n{table_str}\n[/TABLE]")
+
+                # Then extract regular text
+                text = page.extract_text()
+                if text and text.strip():
+                    content.append(text.strip())
+
+        logger.info(f"Successfully extracted content from {pdf_path}")
+        return content
+
+    except Exception as e:
+        logger.error(f"Error processing {pdf_path}: {str(e)}")
+        return []
+
+
+def extract_docx_content(docx_path: str) -> List[str]:
+    """
+    Extract text and tables from DOCX file with clear table markers.
+
+    Args:
+        docx_path (str): Path to the DOCX file
+
+    Returns:
+        List[str]: List of extracted content chunks with tables marked as [TABLE]...[/TABLE]
+    """
+    if not os.path.exists(docx_path):
+        raise FileNotFoundError(f"DOCX file not found: {docx_path}")
+
+    doc = Document(docx_path)
+    content = []
+
+    # Process all paragraphs first
+    for paragraph in doc.paragraphs:
+        text = paragraph.text.strip()
+        if text:
+            content.append(text)
+
+    # Process all tables after paragraphs
+    for table in doc.tables:
+        table_str = "\n".join(
+            ["\t".join(cell.text.strip() for cell in row.cells)
+             for row in table.rows]
+        )
+        if table_str.strip():
+            content.append(f"[TABLE]\n{table_str}\n[/TABLE]")
+
+    return content
+
+
+def extract_xlsx_content(file_path: str) -> List[str]:
+    wb = load_workbook(file_path)
+    sheets_text = []
+
+    for sheet in wb:
+        sheet_str = f"--- Sheet: {sheet.title} ---\n"
+        for row in sheet.iter_rows():
+            # Explicit None check so cells holding 0 or False are kept
+            row_str = "\t".join(str(cell.value) if cell.value is not None else "" for cell in row)
+            sheet_str += row_str + "\n"
+        sheets_text.append(sheet_str.strip())
+
+    return sheets_text
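A usage sketch for the extractors; the file paths are placeholders, not files from this repository.

from file_processing import extract_pdf_content, extract_docx_content

for block in extract_pdf_content("docs/raw/sample.pdf"):  # hypothetical path
    print(block[:80])

for block in extract_docx_content("docs/raw/sample.docx"):  # hypothetical path
    print(block[:80])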
ingestion.py
ADDED
File without changes
prompting/rewrite_question.py
ADDED
@@ -0,0 +1,212 @@
+import os
+from dotenv import load_dotenv
+from langchain.prompts import PromptTemplate
+from langchain_groq import ChatGroq
+from typing import Literal
+
+# Load environment variables
+load_dotenv()
+
+# Initialize LLMs
+def initialize_llms():
+    """Initialize and return the LLM instances"""
+    groq_api_key = os.getenv("GROQ_API_KEY")
+
+    return {
+        "rewrite_llm": ChatGroq(
+            temperature=0.1,
+            model="llama-3.3-70b-versatile",
+            api_key=groq_api_key
+        ),
+        "step_back_llm": ChatGroq(
+            temperature=0,
+            model="Gemma2-9B-IT",
+            api_key=groq_api_key
+        )
+    }
+
+# Certification classification
+def classify_certification(
+    query: str,
+    llm: ChatGroq,
+    certs_dir: str = "docs/processed"
+) -> str:
+    """
+    Classify which certification a query is referring to.
+    Returns certification name or 'no certification mentioned'.
+    Note: the certification list is currently hardcoded; certs_dir is
+    accepted but not yet used for a directory-based lookup.
+    """
+    available_certs = "2BSvs, CertifHy - National Green Certificate (NGC), CertifHy - RFNBO, Certified_Hydrogen_Producer, GH2_Standard, Green_Hydrogen_Certification, ISCC CORSIA, ISCC EU (International Sustainability & Carbon Certification), ISCC PLUS, ISO_19880_Hydrogen_Quality, REDcert-EU, RSB, Scottish Quality Farm Assured Combinable Crops (SQC), TUV Rheinland H2.21, UK RTFO_regulation"
+
+    template = """
+    You are an AI assistant classifying user queries based on the certification they are asking for in a RAG system.
+    Classify the given query into one of the following certifications:
+    - {available_certifications}
+
+    No explanation is needed; just return the name of the certification.
+
+    Use the exact name of the certification as it appears in the directory.
+    If the query refers to multiple certifications, return the most relevant one.
+
+    If the query doesn't mention any certification, respond with "no certification mentioned".
+
+    Original query: {original_query}
+
+    Classification:
+    """
+
+    prompt = PromptTemplate(
+        input_variables=["original_query", "available_certifications"],
+        template=template
+    )
+
+    chain = prompt | llm
+    response = chain.invoke({
+        "original_query": query,
+        "available_certifications": available_certs
+    }).content.strip()
+
+    return response
+
+# Query specificity classification
+def classify_query_specificity(
+    query: str,
+    llm: ChatGroq
+) -> Literal["specific", "general", "too narrow"]:
+    """
+    Classify query specificity.
+    Returns one of: 'specific', 'general', or 'too narrow'.
+    """
+    template = """
+    You are an AI assistant classifying user queries based on their specificity for a RAG system.
+    Classify the given query into one of:
+    - "specific" → If it asks for exact values, certifications, or well-defined facts.
+    - "general" → If it is broad and needs refinement for better retrieval.
+    - "too narrow" → If it is very specific and might need broader context.
+    DO NOT output explanations, only return one of: "specific", "general", or "too narrow".
+
+    Original query: {original_query}
+
+    Classification:
+    """
+
+    prompt = PromptTemplate(
+        input_variables=["original_query"],
+        template=template
+    )
+
+    chain = prompt | llm
+    response = chain.invoke({"original_query": query}).content.strip().lower()
+    return response.split("\n")[0].strip()  # type: ignore
+
+# Query refinement
+def refine_query(
+    query: str,
+    llm: ChatGroq
+) -> str:
+    """Rewrite a query to be clearer and more detailed while keeping the original intent"""
+    template = """
+    You are an AI assistant that improves queries for retrieving precise certification and compliance data.
+    Rewrite the query to be clearer while keeping the intent unchanged.
+
+    Original query: {original_query}
+
+    Refined query:
+    """
+
+    prompt = PromptTemplate(
+        input_variables=["original_query"],
+        template=template
+    )
+
+    chain = prompt | llm
+    return chain.invoke({"original_query": query}).content
+
+# Step-back query generation
+def generate_step_back_query(
+    query: str,
+    llm: ChatGroq
+) -> str:
+    """Generate a broader step-back query to retrieve relevant background information"""
+    template = """
+    You are an AI assistant generating broader queries to improve retrieval context.
+    Given the original query, generate a more general step-back query to retrieve relevant background information.
+
+    Original query: {original_query}
+
+    Step-back query:
+    """
+
+    prompt = PromptTemplate(
+        input_variables=["original_query"],
+        template=template
+    )
+
+    chain = prompt | llm
+    return chain.invoke({"original_query": query}).content
+
+# Main query processing pipeline
+def process_query(
+    original_query: str,
+    llms: dict
+) -> str:
+    """
+    Process a query through the full pipeline:
+    1. Classify specificity
+    2. Apply appropriate refinement
+    """
+    specificity = classify_query_specificity(original_query, llms["rewrite_llm"])
+
+    # "specific" and "general" queries both get the same refinement pass
+    if specificity in ("specific", "general"):
+        return refine_query(original_query, llms["rewrite_llm"])
+    elif specificity == "too narrow":
+        return generate_step_back_query(original_query, llms["step_back_llm"])
+    return original_query
+
+# Test setup
+def test_hydrogen_certification_functions():
+    # Initialize LLMs
+    llms = initialize_llms()
+
+    # Create a test directory with hydrogen certifications
+    test_certs_dir = "docs/processed"
+    os.makedirs(test_certs_dir, exist_ok=True)
+
+    # Create some dummy certification folders
+    hydrogen_certifications = [
+        "GH2_Standard",
+        "Certified_Hydrogen_Producer",
+        "Green_Hydrogen_Certification",
+        "ISO_19880_Hydrogen_Quality"
+    ]
+
+    for cert in hydrogen_certifications:
+        os.makedirs(os.path.join(test_certs_dir, cert), exist_ok=True)
+
+    # Test queries
+    test_queries = [
+        ("What are the purity requirements in GH2 Standard?", "specific"),
+        ("How does hydrogen certification work?", "general"),
+        ("What's the exact ppm of CO2 allowed in ISO_19880_Hydrogen_Quality section 4.2?", "too narrow"),
+        ("What safety protocols exist for hydrogen storage?", "general")
+    ]
+
+    print("=== Testing Certification Classification ===")
+    for query, _ in test_queries:
+        cert = classify_certification(query, llms["rewrite_llm"], test_certs_dir)
+        print(f"Query: {query}\nClassification: {cert}\n")
+
+    print("\n=== Testing Specificity Classification ===")
+    for query, expected_type in test_queries:
+        specificity = classify_query_specificity(query, llms["rewrite_llm"])
+        print(f"Query: {query}\nExpected: {expected_type}, Got: {specificity}\n")
+
+    print("\n=== Testing Full Query Processing ===")
+    for query, _ in test_queries:
+        processed = process_query(query, llms)
+        print(f"Original: {query}\nProcessed: {processed}\n")
+
+# Run the tests
+if __name__ == "__main__":
+    test_hydrogen_certification_functions()
requirements.txt
ADDED
@@ -0,0 +1,21 @@
+langchain
+langchain-groq
+langchain-community
+chromadb
+jq
+fastembed
+python-dotenv
+langchain_chroma
+unstructured
+openai
+elastic-transport==8.17.0
+elasticsearch==8.17.1
+sentence-transformers
+fastapi
+pdfplumber
+pdfminer.six
+python-docx
+openpyxl
+PyPDF2
+streamlit
+uvicorn
streamlit.py
ADDED
@@ -0,0 +1,50 @@
+import streamlit as st
+import logging
+import asyncio
+from app import QueryRequest  # Import the request model
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Set page config
+st.set_page_config(page_title="Certification Chat", layout="centered")
+
+st.title("🎓 Certification Chat Assistant")
+
+# Create a function to handle the async call
+async def async_query(query_text):
+    from app import handle_query  # Import here to avoid circular imports
+    request = QueryRequest(query=query_text)
+    return await handle_query(request)
+
+# Function to run async code in Streamlit
+def run_async(coroutine):
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+    return loop.run_until_complete(coroutine)
+
+# User input
+user_input = st.text_input("💬 Enter your prompt:")
+
+if user_input:
+    st.markdown("## 🧠 Response")
+
+    try:
+        # Use try-except to handle errors
+        with st.spinner("Processing your query..."):
+            # Run the async function
+            result = run_async(async_query(user_input))
+
+        # Display output
+        st.write("**Certification:**", result["certification"])
+        st.write("**Answer from certif_index:**", result["certif_index"])
+        st.write("**Answer from certification_index:**", result["certification_index"])
+
+    except Exception as e:
+        st.error(f"An error occurred: {str(e)}")
+        logger.error(f"Error processing query: {e}", exc_info=True)