Anas Bader committed
Commit 4cbe4e9 · 1 Parent(s): c0d5f87
.gitignore ADDED
@@ -0,0 +1,16 @@
+ .env*
+ myenv*
+ pyproject.toml
+
+ .env*
+ !.env.example
+ myenv*
+ pyproject.toml
+ test.*
+ .conda
+ docs/*
+ __pycache__/
+ .vscode
+ certif_extraction/certifs_md
+
+ *.log
Dockerfile ADDED
@@ -0,0 +1,85 @@
+ FROM ubuntu:22.04
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     curl \
+     openjdk-11-jdk \
+     python3 \
+     python3-pip \
+     wget \
+     apt-transport-https \
+     gnupg \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install Elasticsearch
+ ENV ES_VERSION=8.8.0
+ RUN curl -O https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-${ES_VERSION}-linux-x86_64.tar.gz && \
+     tar -xzf elasticsearch-${ES_VERSION}-linux-x86_64.tar.gz && \
+     mv elasticsearch-${ES_VERSION} /usr/share/elasticsearch && \
+     rm elasticsearch-${ES_VERSION}-linux-x86_64.tar.gz
+
+ # Create elasticsearch.yml with proper YAML format
+ RUN echo "discovery.type: single-node" > /usr/share/elasticsearch/config/elasticsearch.yml && \
+     echo "xpack.security.enabled: false" >> /usr/share/elasticsearch/config/elasticsearch.yml && \
+     echo "network.host: 0.0.0.0" >> /usr/share/elasticsearch/config/elasticsearch.yml
+
+ # Set Elasticsearch environment variables
+ ENV ES_JAVA_OPTS="-Xms1g -Xmx1g"
+
+ # Create non-root user for running the services
+ RUN useradd -m -u 1000 appuser
+ RUN mkdir -p /app /usr/share/elasticsearch/data && \
+     chown -R appuser:appuser /app /usr/share/elasticsearch
+
+ # Create app directory
+ WORKDIR /app
+
+ # Copy the project files
+ COPY --chown=appuser:appuser app.py streamlit.py requirements.txt ./
+ COPY --chown=appuser:appuser chunking ./chunking
+ COPY --chown=appuser:appuser embeddings ./embeddings
+ COPY --chown=appuser:appuser prompting ./prompting
+ COPY --chown=appuser:appuser elastic ./elastic
+ COPY --chown=appuser:appuser file_processing.py ./
+ COPY --chown=appuser:appuser ingestion.py ./
+
+ # Copy ES data if needed - consider if this is actually necessary
+ COPY --chown=appuser:appuser es_data /usr/share/elasticsearch/data
+
+ # Install Python dependencies
+ RUN pip3 install -r requirements.txt
+
+ # Set environment variables for Streamlit
+ ENV STREAMLIT_SERVER_HEADLESS=true
+ ENV STREAMLIT_SERVER_PORT=7860
+ ENV STREAMLIT_SERVER_ENABLE_CORS=false
+ ENV ES_HOST=localhost
+ ENV ES_PORT=9200
+ ENV ELASTICSEARCH_HOSTS="http://localhost:9200"
+
+ # Expose required ports (Elasticsearch and Streamlit)
+ EXPOSE 9200 7860
+
+ # Switch to non-root user
+ USER appuser
+
+ # Create startup script
+ RUN echo '#!/bin/bash\n\
+ # Start Elasticsearch in the background\n\
+ /usr/share/elasticsearch/bin/elasticsearch &\n\
+ \n\
+ # Wait for Elasticsearch to become available\n\
+ echo "Waiting for Elasticsearch to start..."\n\
+ until curl -s http://localhost:9200 > /dev/null; do\n\
+ sleep 2\n\
+ echo "Still waiting for Elasticsearch..."\n\
+ done\n\
+ echo "Elasticsearch is up and running!"\n\
+ \n\
+ # Start Streamlit\n\
+ echo "Starting Streamlit application..."\n\
+ streamlit run /app/streamlit.py\n\
+ ' > /app/start.sh && chmod +x /app/start.sh
+
+ # Command to run
+ CMD ["/app/start.sh"]
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
- title: Rag Hydro
- emoji: 🚀
- colorFrom: purple
- colorTo: yellow
+ title: Hydro Rag
+ emoji: 🐢
+ colorFrom: indigo
+ colorTo: green
  sdk: docker
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,113 @@
+ import os
+ from fastapi import FastAPI, HTTPException
+ from langchain.prompts import PromptTemplate
+ from pydantic import BaseModel
+ from dotenv import load_dotenv
+
+ from embeddings.embeddings import generate_embeddings
+ from elastic.retrieval import search_certification_chunks
+ from prompting.rewrite_question import classify_certification, initialize_llms, process_query
+
+ load_dotenv()
+
+ app = FastAPI(
+     title="Hydrogen Certification RAG System",
+     description="API for querying hydrogen certification documents using RAG",
+     version="0.1.0"
+ )
+
+ # Initialize the LLMs once and reuse them for every request
+ llms = initialize_llms()
+ llm = llms["rewrite_llm"]
+
+
+ # Request models
+ class QueryRequest(BaseModel):
+     query: str
+
+
+ # Endpoints
+ @app.post("/query")
+ async def handle_query(request: QueryRequest):
+     """
+     Process a query through the full RAG pipeline:
+     1. Classify certification (if not provided)
+     2. Optimize query based on specificity
+     3. Search relevant chunks
+     """
+     try:
+         # Step 1: Determine certification
+         query = request.query
+         certification = classify_certification(query, llms["rewrite_llm"])
+         if "no certification mentioned" in certification.lower():
+             raise HTTPException(
+                 status_code=400,
+                 detail="No certification specified in query and none provided"
+             )
+
+         # Step 2: Process query
+         processed_query = process_query(query, llms)
+         question_vector = generate_embeddings(processed_query)
+
+         # Step 3: Search both indices with the same hybrid query
+         results = search_certification_chunks(
+             index_name="certif_index",
+             certification_name=certification,
+             text_query=processed_query,
+             vector_query=question_vector,
+         )
+
+         results_ = search_certification_chunks(
+             index_name="certification_index",
+             certification_name=certification,
+             text_query=processed_query,
+             vector_query=question_vector,
+         )
+
+         results_merged = ". ".join([result["text"] for result in results])
+         results_merged_ = ". ".join([result["text"] for result in results_])
+
+         template = """
+         You are an AI assistant tasked with providing answers based on the given context about a specific hydrogen certification.
+
+         Provide a clear, concise response that directly addresses the question without unnecessary information.
+
+         Question: {question}
+         Certification: {certification}
+         Context: {context}
+
+         Answer:
+         """
+         prompt = PromptTemplate(
+             input_variables=["question", "certification", "context"],
+             template=template
+         )
+
+         chain = prompt | llm
+         answer = chain.invoke({"question": processed_query, "certification": certification, "context": results_merged}).content
+         answer_ = chain.invoke({"question": processed_query, "certification": certification, "context": results_merged_}).content
+
+         return {
+             "certification": certification,
+             "certif_index": answer,
+             "certification_index": answer_,
+         }
+
+     except HTTPException:
+         # Re-raise HTTP errors (e.g. the 400 above) instead of converting them to 500s
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.get("/certifications", response_model=list[str])
+ async def list_certifications():
+     """List all available certifications"""
+     try:
+         certs_dir = "docs/processed"
+         return [f for f in os.listdir(certs_dir) if os.path.isdir(os.path.join(certs_dir, f))]
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
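For reference, a minimal client-side sketch of calling the /query endpoint above once the API is running (not part of the commit; the host, port, and example question are assumptions based on the __main__ block):

# Hypothetical client call against the /query endpoint defined in app.py
import requests

resp = requests.post(
    "http://localhost:8000/query",
    json={"query": "What are the purity requirements in the GH2 Standard?"},
)
resp.raise_for_status()
data = resp.json()
print(data["certification"])        # certification picked by the classifier
print(data["certif_index"])         # answer grounded in the certif_index results
print(data["certification_index"])  # answer grounded in the certification_index results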
chunking/semantic_chunking.py ADDED
@@ -0,0 +1,96 @@
+ import re
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
+
+
+ def hybrid_split(text: str, max_len: int = 1024) -> list[str]:
+     """
+     Split text into chunks, respecting sentence boundaries when possible.
+
+     Args:
+         text: The text to split
+         max_len: Maximum length for each chunk
+
+     Returns:
+         List of text chunks
+     """
+     # Normalize text
+     text = text.replace("\r", "").replace("\n", " ").strip()
+
+     # Extract sentences (split on sentence-ending punctuation)
+     sentences = re.split(r"(?<=[.!?])\s+", text)
+
+     chunks = []
+     current_chunk = ""
+
+     for sentence in sentences:
+         if len(sentence) > max_len:
+             # Flush the current chunk first, then keep the long sentence as its own chunk
+             if current_chunk:
+                 chunks.append(current_chunk)
+                 current_chunk = ""
+             chunks.append(sentence)
+
+         # Normal case - see if adding the sentence exceeds max_len
+         elif len(current_chunk) + len(sentence) + 1 > max_len:
+             # Close the current chunk and start a new one with this sentence
+             chunks.append(current_chunk)
+             current_chunk = sentence
+         else:
+             # Add to the current chunk
+             if current_chunk:
+                 current_chunk += " " + sentence
+             else:
+                 current_chunk = sentence
+
+     if current_chunk:
+         chunks.append(current_chunk)
+
+     return chunks
+
+
+ def cosine_similarity(vec1, vec2):
+     """Calculate the cosine similarity between two vectors."""
+     dot_product = np.dot(vec1, vec2)
+     norm_vec1 = np.linalg.norm(vec1)
+     norm_vec2 = np.linalg.norm(vec2)
+     return dot_product / (norm_vec1 * norm_vec2)
+
+
+ def get_embedding(text):
+     """Generate an embedding using SBERT."""
+     return embedding_model.encode(text, convert_to_numpy=True)
+
+
+ def semantic_chunking(text, threshold=0.75, max_chunk_size=8191):
+     """
+     Splits text into semantic chunks based on sentence similarity.
+     - threshold: Lower = more splits, Higher = fewer splits
+     - max_chunk_size: Maximum size of each chunk in characters
+     """
+     text = text.replace("\n", " ").replace("\r", " ").strip()
+     sentences = hybrid_split(text)
+     if not sentences:
+         return []
+     embeddings = [get_embedding(sent) for sent in sentences]
+
+     chunks = []
+     current_chunk = [sentences[0]]
+
+     for i in range(1, len(sentences)):
+         sim = cosine_similarity(embeddings[i - 1], embeddings[i])
+         if (
+             sim < threshold
+             or len(" ".join(current_chunk + [sentences[i]])) > max_chunk_size
+         ):
+             chunks.append(" ".join(current_chunk))
+             current_chunk = [sentences[i]]
+         else:
+             current_chunk.append(sentences[i])
+
+     if current_chunk:
+         chunks.append(" ".join(current_chunk))
+
+     return chunks
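A small usage sketch of the chunking helpers above (not part of the commit; the sample text is made up):

# Hypothetical usage of semantic_chunking on a short passage
from chunking.semantic_chunking import semantic_chunking

sample = (
    "Green hydrogen is produced by electrolysis powered by renewable electricity. "
    "Certification schemes verify the carbon intensity of the production pathway. "
    "An unrelated sentence about shipping logistics and port infrastructure."
)
chunks = semantic_chunking(sample, threshold=0.75)
for i, chunk in enumerate(chunks):
    print(i, chunk[:80])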
elastic/es_client.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import logging
+ from elasticsearch import Elasticsearch, ConnectionError, AuthenticationException
+
+ # Configure logging at the application level
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+ # Load environment variables
+ ES_CLIENT_URL = os.getenv("ELASTICSEARCH_HOSTS", "http://localhost:9200")
+
+ class ElasticsearchClientError(Exception):
+     """Custom exception for Elasticsearch client errors."""
+     pass
+
+ def get_es_client() -> Elasticsearch:
+     """
+     Establish connection to Elasticsearch and return the client instance.
+     Raises ElasticsearchClientError if the connection cannot be established.
+     """
+
+     try:
+         logger.info("Connecting to Elasticsearch at %s", ES_CLIENT_URL)
+         # Initialize Elasticsearch client
+         es_client = Elasticsearch(
+             hosts=[ES_CLIENT_URL],
+         )
+
+         # Verify connection
+         if not es_client.ping():
+             error_message = "Elasticsearch cluster is not reachable!"
+             logger.error(error_message)
+             raise ElasticsearchClientError(error_message)
+
+         logger.info("Successfully connected to Elasticsearch")
+         return es_client
+
+     except (ConnectionError, AuthenticationException) as e:
+         error_message = f"Elasticsearch connection error: {e}"
+         logger.error(error_message)
+         raise ElasticsearchClientError(error_message) from e
elastic/es_index.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import logging
+ from elasticsearch import Elasticsearch, ConnectionError, AuthenticationException
+
+ # Configure logging at the application level
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+ # Load environment variables
+ ES_CLIENT_URL = os.getenv("ELASTICSEARCH_HOSTS")
+
+ class ElasticsearchClientError(Exception):
+     """Custom exception for Elasticsearch client errors."""
+     pass
+
+ def get_es_client() -> Elasticsearch:
+     """
+     Establish connection to Elasticsearch and return the client instance.
+     Raises ElasticsearchClientError if the connection cannot be established.
+     """
+
+     try:
+         logger.info("Connecting to Elasticsearch at %s", ES_CLIENT_URL)
+         # Initialize Elasticsearch client
+         es_client = Elasticsearch(
+             hosts=[ES_CLIENT_URL],
+         )
+
+         # Verify connection
+         if not es_client.ping():
+             error_message = "Elasticsearch cluster is not reachable!"
+             logger.error(error_message)
+             raise ElasticsearchClientError(error_message)
+
+         logger.info("Successfully connected to Elasticsearch")
+         return es_client
+
+     except (ConnectionError, AuthenticationException) as e:
+         error_message = f"Elasticsearch connection error: {e}"
+         logger.error(error_message)
+         raise ElasticsearchClientError(error_message) from e
elastic/indexing.py ADDED
@@ -0,0 +1,60 @@
+ import logging
+ from elasticsearch import Elasticsearch, exceptions
+ from typing import Dict, Any
+
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.DEBUG)
+
+ embedding_dimension = 1536
+
+ def create_mapping(properties: Dict[str, Any]) -> Dict[str, Any]:
+     """Helper function to create index mappings with predefined settings."""
+     return {
+         "settings": {"number_of_shards": 1, "number_of_replicas": 1},
+         "mappings": {"properties": properties},
+     }
+
+
+ def retrieval_index() -> Dict[str, Any]:
+     """Returns the Elasticsearch mapping for retrieval indices."""
+     return create_mapping(
+         {
+             "chunk_id": {"type": "keyword"},
+             "chunk": {"type": "text"},
+             "embedding": {
+                 "type": "dense_vector",
+                 "dims": embedding_dimension,
+             },
+             "certification": {"type": "keyword"},
+             "source_file": {"type": "keyword"},
+             "timestamp": {"type": "date"},
+         }
+     )
+
+
+ def create_elasticsearch_index(es_client: Elasticsearch, index_name: str) -> bool:
+     """
+     Create an Elasticsearch index with the appropriate mapping.
+
+     Args:
+         es_client (Elasticsearch): The Elasticsearch client instance.
+         index_name (str): The name of the index to create.
+
+     Returns:
+         bool: True if the index was created successfully, False otherwise.
+     """
+     try:
+         mapping = retrieval_index()
+
+         if es_client.indices.exists(index=index_name):
+             logger.warning(f"Index '{index_name}' already exists. Skipping creation.")
+             return True
+
+         es_client.indices.create(index=index_name, body=mapping)
+         logger.info(f"Index '{index_name}' created successfully.")
+         return True
+
+     except Exception as e:
+         logger.error(f"Unexpected error while creating index '{index_name}': {e}")
+         return False
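A short sketch of wiring the client and the mapping together to create the retrieval index (not part of the commit; the index name mirrors the one used in app.py and is an assumption):

# Hypothetical setup step: create "certif_index" with the mapping above
from elastic.es_client import get_es_client
from elastic.indexing import create_elasticsearch_index

es = get_es_client()
if create_elasticsearch_index(es, "certif_index"):
    print("certif_index is ready")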
elastic/retrieval.py ADDED
@@ -0,0 +1,86 @@
+ from typing import List, Dict, Any
+ import logging
+
+ from elastic.es_client import get_es_client
+
+ logger = logging.getLogger(__name__)
+
+ es_client = get_es_client()
+
+
+ def search_certification_chunks(
+     index_name: str,
+     text_query: str,
+     vector_query: List[float],
+     certification_name: str,
+     es_client=es_client,
+     vector_field: str = "embedding",
+     text_field: str = "chunk",
+     size: int = 5,
+     min_score: float = 0.1,  # Lowered threshold
+     boost_text: float = 1.0,
+     boost_vector: float = 1.0,
+ ) -> List[Dict[str, Any]]:
+     """Hybrid (BM25 + vector) search restricted to a single certification."""
+
+     # First verify the certification value exists
+     cert_check = es_client.search(
+         index=index_name,
+         body={
+             "query": {"term": {"certification": certification_name}},
+             "size": 1,
+         },
+     )
+
+     if not cert_check["hits"]["hits"]:
+         logger.error(f"No documents found with certification: {certification_name}")
+         return []
+
+     # Then run the hybrid search, filtered to the requested certification
+     query_body = {
+         "size": size,
+         "min_score": min_score,
+         "query": {
+             "bool": {
+                 "filter": [{"term": {"certification": certification_name}}],
+                 "should": [
+                     {"match": {text_field: {"query": text_query, "boost": boost_text}}},
+                     {
+                         "script_score": {
+                             "query": {"match_all": {}},
+                             "script": {
+                                 "source": f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0",
+                                 "params": {"query_vector": vector_query},
+                             },
+                             "boost": boost_vector,
+                         }
+                     },
+                 ],
+             }
+         },
+     }
+     logger.debug(f"Elasticsearch query body: {query_body}")
+
+     logger.info(f"Executing search on index '{index_name}'")
+     response = es_client.search(index=index_name, body=query_body)
+     hits = response.get("hits", {}).get("hits", [])
+     logger.info(f"Found {len(hits)} matching documents")
+
+     # Process results with the field names used at indexing time
+     results = [
+         {
+             "id": hit["_id"],
+             "score": hit["_score"],
+             "text": hit["_source"][text_field],
+             "source_file": hit["_source"]["source_file"],
+         }
+         for hit in hits
+     ]
+
+     if results:
+         logger.debug(f"Top result score: {results[0]['score']}")
+         logger.debug(f"Top result source: {results[0]['source_file']}")
+     else:
+         logger.warning("No results returned from Elasticsearch")
+
+     return results
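A usage sketch for the hybrid search helper (not part of the commit; the question text and certification name are assumptions, the certification name being taken from the list in prompting/rewrite_question.py):

# Hypothetical call combining the embedding helper with the hybrid search
from embeddings.embeddings import generate_embeddings
from elastic.retrieval import search_certification_chunks

question = "What carbon intensity threshold applies under the GH2_Standard?"
hits = search_certification_chunks(
    index_name="certif_index",
    text_query=question,
    vector_query=generate_embeddings(question),
    certification_name="GH2_Standard",
    size=5,
)
for hit in hits:
    print(hit["score"], hit["source_file"])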
embeddings/embeddings.py ADDED
@@ -0,0 +1,35 @@
+ import os
+ import logging
+ from typing import List
+ from openai import OpenAI
+ from dotenv import load_dotenv
+
+ # Configure logging at the application level
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+ load_dotenv()
+
+ embedding_dimension = 1536
+
+ model = "text-embedding-3-small"
+
+ openai_api_key = (os.getenv("OPENAI_API_KEY") or "").strip()
+ client = OpenAI(api_key=openai_api_key)
+
+
+ def generate_embeddings(text: str) -> List[float]:
+     """Get embeddings from the OpenAI API."""
+     logger.info("Embedding model: %s", model)
+
+     if not text:
+         raise ValueError("Cannot generate embeddings for empty text")
+
+     try:
+         response = client.embeddings.create(
+             model=model, input=text, dimensions=embedding_dimension
+         )
+         return response.data[0].embedding
+     except Exception as e:
+         logger.error(f"OpenAI API error: {e}")
+         raise
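A quick sanity-check sketch (not part of the commit): the embedding length must match the dense_vector dims declared in elastic/indexing.py, since both are set to 1536 here.

# Hypothetical check that the embedding size matches the index mapping
from embeddings.embeddings import generate_embeddings

vec = generate_embeddings("hydrogen certification test sentence")
assert len(vec) == 1536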
es_data/node.lock ADDED
File without changes
es_data/nodes ADDED
@@ -0,0 +1 @@
+ written by Elasticsearch v8.8.0 to prevent a downgrade to a version prior to v8.0.0 which would result in data loss
file_processing.py ADDED
@@ -0,0 +1,103 @@
+ import os
+ import logging
+ from typing import List
+
+ import pdfplumber
+ from docx import Document
+ from openpyxl import load_workbook
+
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ def extract_pdf_content(pdf_path: str) -> List[str]:
+     """
+     Extract text and tables from PDF in their natural reading order.
+     Simplified version without positional processing.
+
+     Args:
+         pdf_path (str): Path to the PDF file
+
+     Returns:
+         List[str]: List of extracted content chunks (text and tables)
+     """
+     if not os.path.exists(pdf_path):
+         logger.error(f"PDF file not found: {pdf_path}")
+         return []
+
+     try:
+         with pdfplumber.open(pdf_path) as pdf:
+             content = []
+
+             for page in pdf.pages:
+                 # First extract tables
+                 tables = page.extract_tables()
+                 for table in tables:
+                     if table:
+                         # Convert table to string representation
+                         table_str = "\n".join(
+                             ["\t".join("" if cell is None else str(cell) for cell in row) for row in table]
+                         )
+                         content.append(f"[TABLE]\n{table_str}\n[/TABLE]")
+
+                 # Then extract regular text
+                 text = page.extract_text()
+                 if text and text.strip():
+                     content.append(text.strip())
+
+             logger.info(f"Successfully extracted content from {pdf_path}")
+             return content
+
+     except Exception as e:
+         logger.error(f"Error processing {pdf_path}: {str(e)}")
+         return []
+
+
+ def extract_docx_content(docx_path: str) -> List[str]:
+     """
+     Extract text and tables from DOCX file with clear table markers.
+
+     Args:
+         docx_path (str): Path to the DOCX file
+
+     Returns:
+         List[str]: List of extracted content chunks with tables marked as [TABLE]...[/TABLE]
+     """
+     if not os.path.exists(docx_path):
+         raise FileNotFoundError(f"DOCX file not found: {docx_path}")
+
+     doc = Document(docx_path)
+     content = []
+
+     # Process all paragraphs first
+     for paragraph in doc.paragraphs:
+         text = paragraph.text.strip()
+         if text:
+             content.append(text)
+
+     # Process all tables after paragraphs
+     for table in doc.tables:
+         table_str = "\n".join(
+             ["\t".join(cell.text.strip() for cell in row.cells)
+              for row in table.rows]
+         )
+         if table_str.strip():
+             content.append(f"[TABLE]\n{table_str}\n[/TABLE]")
+
+     return content
+
+
+ def extract_xlsx_content(file_path: str) -> List[str]:
+     wb = load_workbook(file_path)
+     sheets_text = []
+
+     for sheet in wb:
+         sheet_str = f"--- Sheet: {sheet.title} ---\n"
+         for row in sheet.iter_rows():
+             row_str = "\t".join(str(cell.value) if cell.value is not None else "" for cell in row)
+             sheet_str += row_str + "\n"
+         sheets_text.append(sheet_str.strip())
+
+     return sheets_text
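A small sketch of running the extractors above over a folder of source documents (not part of the commit; the docs/raw path and the extension routing are assumptions):

# Hypothetical driver: route files to the matching extractor by extension
import os
from file_processing import extract_pdf_content, extract_docx_content, extract_xlsx_content

def extract_any(path: str):
    ext = os.path.splitext(path)[1].lower()
    if ext == ".pdf":
        return extract_pdf_content(path)
    if ext == ".docx":
        return extract_docx_content(path)
    if ext in (".xlsx", ".xlsm"):
        return extract_xlsx_content(path)
    return []

for name in os.listdir("docs/raw"):
    chunks = extract_any(os.path.join("docs/raw", name))
    print(name, len(chunks), "chunks")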
ingestion.py ADDED
File without changes
prompting/rewrite_question.py ADDED
@@ -0,0 +1,212 @@
+ import os
+ from dotenv import load_dotenv
+ from langchain.prompts import PromptTemplate
+ from langchain_groq import ChatGroq
+ from typing import Literal
+
+ # Load environment variables
+ load_dotenv()
+
+ # Initialize LLMs
+ def initialize_llms():
+     """Initialize and return the LLM instances"""
+     groq_api_key = os.getenv("GROQ_API_KEY")
+
+     return {
+         "rewrite_llm": ChatGroq(
+             temperature=0.1,
+             model="llama-3.3-70b-versatile",
+             api_key=groq_api_key
+         ),
+         "step_back_llm": ChatGroq(
+             temperature=0,
+             model="Gemma2-9B-IT",
+             api_key=groq_api_key
+         )
+     }
+
+ # Certification classification
+ def classify_certification(
+     query: str,
+     llm: ChatGroq,
+     certs_dir: str = "docs/processed"
+ ) -> str:
+     """
+     Classify which certification a query is referring to.
+     Returns certification name or 'no certification mentioned'.
+     """
+     available_certs = "2BSvs, CertifHy - National Green Certificate (NGC), CertifHy - RFNBO, Certified_Hydrogen_Producer, GH2_Standard, Green_Hydrogen_Certification, ISCC CORSIA, ISCC EU (International Sustainability & Carbon Certification), ISCC PLUS, ISO_19880_Hydrogen_Quality, REDcert-EU, RSB, Scottish Quality Farm Assured Combinable Crops (SQC), TUV Rheinland H2.21, UK RTFO_regulation"
+
+     template = """
+     You are an AI assistant classifying user queries based on the certification they are asking for in a RAG system.
+     Classify the given query into one of the following certifications:
+     - {available_certifications}
+
+     Do not give any explanation; return only the name of the certification.
+
+     Use the exact name of the certification as it appears in the list above.
+     If the query refers to multiple certifications, return the most relevant one.
+
+     If the query doesn't mention any certification, respond with "no certification mentioned".
+
+     Original query: {original_query}
+
+     Classification:
+     """
+
+     prompt = PromptTemplate(
+         input_variables=["original_query", "available_certifications"],
+         template=template
+     )
+
+     chain = prompt | llm
+     response = chain.invoke({
+         "original_query": query,
+         "available_certifications": available_certs
+     }).content.strip()
+
+     return response
+
+ # Query specificity classification
+ def classify_query_specificity(
+     query: str,
+     llm: ChatGroq
+ ) -> Literal["specific", "general", "too narrow"]:
+     """
+     Classify query specificity.
+     Returns one of: 'specific', 'general', or 'too narrow'.
+     """
+     template = """
+     You are an AI assistant classifying user queries based on their specificity for a RAG system.
+     Classify the given query into one of:
+     - "specific" → If it asks for exact values, certifications, or well-defined facts.
+     - "general" → If it is broad and needs refinement for better retrieval.
+     - "too narrow" → If it is very specific and might need broader context.
+     DO NOT output explanations, only return one of: "specific", "general", or "too narrow".
+
+     Original query: {original_query}
+
+     Classification:
+     """
+
+     prompt = PromptTemplate(
+         input_variables=["original_query"],
+         template=template
+     )
+
+     chain = prompt | llm
+     response = chain.invoke({"original_query": query}).content.strip().lower()
+     return response.split("\n")[0].strip()  # type: ignore
+
+ # Query refinement
+ def refine_query(
+     query: str,
+     llm: ChatGroq
+ ) -> str:
+     """Rewrite a query to be clearer and more detailed while keeping the original intent"""
+     template = """
+     You are an AI assistant that improves queries for retrieving precise certification and compliance data.
+     Rewrite the query to be clearer while keeping the intent unchanged.
+
+     Original query: {original_query}
+
+     Refined query:
+     """
+
+     prompt = PromptTemplate(
+         input_variables=["original_query"],
+         template=template
+     )
+
+     chain = prompt | llm
+     return chain.invoke({"original_query": query}).content
+
+ # Step-back query generation
+ def generate_step_back_query(
+     query: str,
+     llm: ChatGroq
+ ) -> str:
+     """Generate a broader step-back query to retrieve relevant background information"""
+     template = """
+     You are an AI assistant generating broader queries to improve retrieval context.
+     Given the original query, generate a more general step-back query to retrieve relevant background information.
+
+     Original query: {original_query}
+
+     Step-back query:
+     """
+
+     prompt = PromptTemplate(
+         input_variables=["original_query"],
+         template=template
+     )
+
+     chain = prompt | llm
+     return chain.invoke({"original_query": query}).content
+
+ # Main query processing pipeline
+ def process_query(
+     original_query: str,
+     llms: dict
+ ) -> str:
+     """
+     Process a query through the full pipeline:
+     1. Classify specificity
+     2. Apply appropriate refinement
+     """
+     specificity = classify_query_specificity(original_query, llms["rewrite_llm"])
+
+     if specificity == "specific":
+         return refine_query(original_query, llms["rewrite_llm"])
+     elif specificity == "general":
+         return refine_query(original_query, llms["rewrite_llm"])
+     elif specificity == "too narrow":
+         return generate_step_back_query(original_query, llms["step_back_llm"])
+     return original_query
+
+ # Test setup
+ def test_hydrogen_certification_functions():
+     # Initialize LLMs
+     llms = initialize_llms()
+
+     # Create a test directory with hydrogen certifications
+     test_certs_dir = "docs/processed"
+     os.makedirs(test_certs_dir, exist_ok=True)
+
+     # Create some dummy certification folders
+     hydrogen_certifications = [
+         "GH2_Standard",
+         "Certified_Hydrogen_Producer",
+         "Green_Hydrogen_Certification",
+         "ISO_19880_Hydrogen_Quality"
+     ]
+
+     for cert in hydrogen_certifications:
+         os.makedirs(os.path.join(test_certs_dir, cert), exist_ok=True)
+
+     # Test queries
+     test_queries = [
+         ("What are the purity requirements in GH2 Standard?", "specific"),
+         ("How does hydrogen certification work?", "general"),
+         ("What's the exact ppm of CO2 allowed in ISO_19880_Hydrogen_Quality section 4.2?", "too narrow"),
+         ("What safety protocols exist for hydrogen storage?", "general")
+     ]
+
+     print("=== Testing Certification Classification ===")
+     for query, _ in test_queries:
+         cert = classify_certification(query, llms["rewrite_llm"], test_certs_dir)
+         print(f"Query: {query}\nClassification: {cert}\n")
+
+     print("\n=== Testing Specificity Classification ===")
+     for query, expected_type in test_queries:
+         specificity = classify_query_specificity(query, llms["rewrite_llm"])
+         print(f"Query: {query}\nExpected: {expected_type}, Got: {specificity}\n")
+
+     print("\n=== Testing Full Query Processing ===")
+     for query, _ in test_queries:
+         processed = process_query(query, llms)
+         print(f"Original: {query}\nProcessed: {processed}\n")
+
+ # Run the tests
+ if __name__ == "__main__":
+     test_hydrogen_certification_functions()
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ langchain
+ langchain-groq
+ langchain-community
+ chromadb
+ jq
+ fastembed
+ python-dotenv
+ langchain_chroma
+ unstructured
+ openai
+ elastic-transport==8.17.0
+ elasticsearch==8.17.1
+ sentence-transformers
+ fastapi
+ pdfplumber
+ pdfminer.six
+ python-docx
+ openpyxl
+ PyPDF2
+ streamlit
+ uvicorn
streamlit.py ADDED
@@ -0,0 +1,49 @@
+ import streamlit as st
+ import logging
+ import asyncio
+ from app import QueryRequest  # Import the request model
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Set page config
+ st.set_page_config(page_title="Certification Chat", layout="centered")
+
+ st.title("🎓 Certification Chat Assistant")
+
+ # Create a function to handle the async call
+ async def async_query(query_text):
+     from app import handle_query  # Import here to avoid circular imports
+     request = QueryRequest(query=query_text)
+     return await handle_query(request)
+
+ # Function to run async code in Streamlit
+ def run_async(coroutine):
+     try:
+         loop = asyncio.get_event_loop()
+     except RuntimeError:
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+     return loop.run_until_complete(coroutine)
+
+ # User input
+ user_input = st.text_input("💬 Enter your prompt:")
+
+ if user_input:
+     st.markdown("## 🧠 Response")
+
+     try:
+         # Use try-except to handle errors
+         with st.spinner("Processing your query..."):
+             # Run the async function
+             result = run_async(async_query(user_input))
+
+         # Display output
+         st.write("**Certification:**", result["certification"])
+         st.write("**Answer from certif_index:**", result["certif_index"])
+         st.write("**Answer from certification_index:**", result["certification_index"])
+
+     except Exception as e:
+         st.error(f"An error occurred: {str(e)}")
+         logger.error(f"Error processing query: {e}", exc_info=True)