benjika committed
Commit 7539883 · verified · 1 Parent(s): 217741c

Upload 3 files

Files changed (3)
  1. app.py +140 -147
  2. chroma.zip +3 -0
  3. requirements.txt +9 -11
app.py CHANGED
@@ -1,147 +1,140 @@
- import os
- import chromadb
- import gradio as gr
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
- from langchain_chroma import Chroma
- from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
- from langchain_community.document_loaders import PyPDFLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.chains import create_retrieval_chain, LLMChain
- from langchain.prompts import PromptTemplate
- from collections import OrderedDict
-
- # Load embeddings model
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-
- # Load Chroma database (Avoid reprocessing documents)
- CHROMA_PATH = "./chroma_db"
- if not os.path.exists(CHROMA_PATH):
-     raise FileNotFoundError("ChromaDB folder not found. Make sure it's uploaded to the repo.")
-
- chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
- db = Chroma(embedding_function=embeddings, client=chroma_client)
-
- # Load the model
- model_name = "google/flan-t5-large"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-
- # Create pipeline
- qa_pipeline = pipeline(
-     "text2text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     device=0,
-     max_length=512,
-     min_length=50,
-     do_sample=False,
-     repetition_penalty=1.2
- )
-
- # Wrap pipeline in LangChain
- llm = HuggingFacePipeline(pipeline=qa_pipeline)
- retriever = db.as_retriever(search_kwargs={"k": 3})
-
-
- def clean_context(context_list, max_tokens=350, min_length=50):
-     """
-     Cleans retrieved document context:
-     - Removes duplicates while preserving order
-     - Limits total token count
-     - Ensures useful, non-repetitive context
-     """
-
-     # Preserve order while removing duplicates
-     unique_texts = list(OrderedDict.fromkeys([doc.page_content.strip() for doc in context_list]))
-
-     # Remove very short texts (e.g., headers)
-     filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]
-
-     # Avoid near-duplicate entries
-     deduplicated_texts = []
-     seen_texts = set()
-     for text in filtered_texts:
-         if not any(text in s for s in seen_texts): # Avoid near-duplicates
-             deduplicated_texts.append(text)
-             seen_texts.add(text)
-
-     # Limit context based on token count
-     trimmed_context = []
-     total_tokens = 0
-     for text in deduplicated_texts:
-         tokenized_text = tokenizer.encode(text, add_special_tokens=False)
-         token_count = len(tokenized_text)
-
-         if total_tokens + token_count > max_tokens:
-             remaining_tokens = max_tokens - total_tokens
-             if remaining_tokens > 20:
-                 trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
-             break
-
-         trimmed_context.append(text)
-         total_tokens += token_count
-
-     return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
-
- # Define prompt
- prompt_template = PromptTemplate(
-     template="""
-     You are a Kubernetes instructor. Answer the question based on the provided context.
-     If the context does not provide an answer, say "I don't have enough information."
-
-     Context:
-     {context}
-
-     Question:
-     {input}
-
-     Answer:
-     """,
-     input_variables=["context", "input"]
- )
-
- llm_chain = LLMChain(llm=llm, prompt=prompt_template)
- qa_chain = create_retrieval_chain(retriever, llm_chain)
-
- # Query function
- def get_k8s_answer(query):
-     retrieved_context = retriever.get_relevant_documents(query)
-     cleaned_context = clean_context(retrieved_context, max_tokens=350) # Ensure context size is within limits
-
-     # Ensure total input tokens < 512 before passing to model
-     input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
-     total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))
-
-     if total_tokens > 512:
-         # Trim context further to fit within the limit
-         allowed_tokens = 512 - len(tokenizer.encode(query, add_special_tokens=True)) - 50 # 50 tokens for the model's response
-         cleaned_context = clean_context(retrieved_context, max_tokens=allowed_tokens)
-
-         # Recalculate total tokens
-         input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
-         total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))
-
-         if total_tokens > 512:
-             return "Error: Even after trimming, input is too large."
-
-     response = qa_chain.invoke({"input": query, "context": cleaned_context})
-     return response
-
- def get_k8s_answer_text(query):
-     model_full_answer = get_k8s_answer(query)
-     if 'answer' in model_full_answer.keys():
-         if 'text' in model_full_answer['answer'].keys():
-             return model_full_answer['answer']['text']
-     return "Error"
-
- # Gradio Interface
- demo = gr.Interface(
-     fn=get_k8s_answer_text,
-     inputs=gr.Textbox(label="Ask a Kubernetes Question"),
-     outputs=gr.Textbox(label="Answer"),
-     title="Kubernetes RAG Assistant",
-     description="Ask any Kubernetes-related question and get a step-by-step answer based on documentation."
- )
-
- if __name__ == "__main__":
-     demo.launch()
 
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
+ import chromadb
+ from langchain.vectorstores import Chroma
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
+ from langchain.chains import create_retrieval_chain, LLMChain
+ from langchain.prompts import PromptTemplate
+ import os
+ import shutil
+ import zipfile
+
+ # 🚀 Step 1: Extract ChromaDB if not already done (only once)
+ if not os.path.exists("./chroma_db"):
+     with zipfile.ZipFile("chroma.zip", "r") as zip_ref:
+         zip_ref.extractall("./chroma_db")
+
+ # 🚀 Step 2: Load Pre-trained Model & Tokenizer (Fast Startup)
+ MODEL_NAME = "google/flan-t5-xl"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
+
+ # 🚀 Step 3: Load Vector Store Efficiently
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+ chroma_client = chromadb.PersistentClient(path="./chroma_db")
+ db = Chroma(embedding_function=embeddings, client=chroma_client)
+
+ # 🚀 Step 4: Optimize Retriever (Lower `k` for Speed)
+ retriever = db.as_retriever(search_kwargs={"k": 10})
+
+ # 🚀 Step 5: Define Prompt for the LLM
+ prompt_template = PromptTemplate(
+     template="""
+     You are a Kubernetes expert.
+     **Answer the question using ONLY the provided context.**
+     If the context does NOT contain enough information, return:
+     `"I don't have enough information to answer this question."`
+     Always include YAML examples when relevant.
+
+     ---
+     **Context:**
+     {context}
+
+     **Question:**
+     {input}
+
+     ---
+     **Answer:**
+     """,
+     input_variables=["context", "input"]
+ )
+
+ # 🚀 Step 6: Build Retrieval Chain
+ qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0,
+                        max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)
+ llm = HuggingFacePipeline(pipeline=qa_pipeline)
+ llm_chain = LLMChain(llm=llm, prompt=prompt_template)
+ qa_chain = create_retrieval_chain(retriever, llm_chain)
+
+ # 🚀 Step 7: Define Fast Answer Function
+ def clean_context(context_list, max_tokens=350, min_length=50):
+     """
+     Improves the retrieved document context:
+     - Removes duplicates while preserving order
+     - Filters out very short or unstructured text
+     - Limits token count for better LLM performance
+     """
+     from collections import OrderedDict
+
+     # Preserve order while removing exact duplicates
+     unique_texts = list(OrderedDict.fromkeys(doc.page_content.strip() for doc in context_list))
+
+     # Remove very short texts (e.g., headers, page numbers)
+     filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]
+
+     # Avoid near-duplicates
+     deduplicated_texts = []
+     seen_texts = set()
+     for text in filtered_texts:
+         normalized_text = " ".join(text.split()) # Normalize spacing
+         if not any(normalized_text in seen for seen in seen_texts): # Avoid near-duplicates
+             deduplicated_texts.append(normalized_text)
+             seen_texts.add(normalized_text)
+
+     # Limit context based on token count
+     trimmed_context = []
+     total_tokens = 0
+     for text in deduplicated_texts:
+         tokenized_text = tokenizer.encode(text, add_special_tokens=False)
+         token_count = len(tokenized_text)
+
+         if total_tokens + token_count > max_tokens:
+             remaining_tokens = max_tokens - total_tokens
+             if remaining_tokens > 20: # Allow partial inclusion if meaningful
+                 trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
+             break
+
+         trimmed_context.append(text)
+         total_tokens += token_count
+
+     return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
+
+
+ def get_k8s_answer(query):
+     retrieved_context = retriever.invoke(query)
+     cleaned_context = clean_context(retrieved_context, max_tokens=350)
+
+     # Ensure input tokens fit within 512 limit
+     input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
+     total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))
+
+     if total_tokens > 512:
+         # Further trim context
+         allowed_tokens = 512 - len(tokenizer.encode(query, add_special_tokens=True)) - 50 # 50 tokens reserved for response
+         cleaned_context = clean_context(retrieved_context, max_tokens=allowed_tokens)
+
+         # Recalculate total tokens
+         input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
+         total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))
+
+         if total_tokens > 512:
+             return "Error: Even after trimming, input is too large."
+
+     response = qa_chain.invoke({"input": query, "context": cleaned_context})
+     return response
+
+ # 🚀 Step 8: Optimize Gradio App with `Blocks()`
+ with gr.Blocks(theme="soft") as demo:
+     gr.Markdown("# ⚡ Kubernetes RAG")
+     gr.Markdown("Ask any Kubernetes-related question!")
+
+     with gr.Row():
+         question = gr.Textbox(label="Ask a Kubernetes Question", lines=1)
+         answer = gr.Textbox(label="Answer", interactive=False)
+
+     submit_button = gr.Button("Get Answer")
+
+     submit_button.click(fn=get_k8s_answer, inputs=question, outputs=answer)
+
+ demo.launch()
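Note: `qa_chain.invoke` returns a dictionary, and the click handler above wires `get_k8s_answer` straight into the answer `Textbox`, so the raw dictionary rather than just the generated text would be displayed. The previous version unwrapped the response with a `get_k8s_answer_text` helper; a minimal sketch of an equivalent wrapper is below (the `"answer"`/`"text"` keys mirror that removed helper and are an assumption about the chain's actual output shape):

```python
def get_k8s_answer_text(query):
    # Unwrap the retrieval chain's dict response into plain text for the Textbox.
    # Key names follow the helper removed in this commit; the real shape depends
    # on how create_retrieval_chain wraps the LLMChain output.
    response = get_k8s_answer(query)
    if isinstance(response, dict):
        answer = response.get("answer", response)
        if isinstance(answer, dict):
            return answer.get("text", str(answer))
        return str(answer)
    return str(response)

# submit_button.click(fn=get_k8s_answer_text, inputs=question, outputs=answer)
```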
 
chroma.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1cb847a7f5e922fead2197320f50734db00b6280e5acc8c202317b67f484e46a
+ size 126566892
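chroma.zip carries a pre-built Chroma store (stored as a Git LFS object) that the new app.py unzips into `./chroma_db` instead of indexing documents at startup. The commit does not include the build script; below is a minimal sketch of how such a store could be produced offline, reusing the PDF loader and splitter imports dropped from the old app.py (the source file name and chunk sizes are assumptions, not part of this commit):

```python
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma

# Hypothetical source document and splitter settings (not part of this commit).
docs = PyPDFLoader("kubernetes-docs.pdf").load()
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)

# Same embedding model as app.py; persist the index so it can be shipped as an archive.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
Chroma.from_documents(chunks, embedding=embeddings, persist_directory="./chroma_db")
# The persisted ./chroma_db contents would then be zipped up as chroma.zip.
```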
requirements.txt CHANGED
@@ -1,11 +1,9 @@
- huggingface_hub==0.25.2
- gradio
- transformers
- sentence-transformers
- chromadb
- pypdf
- torch
- langchain
- langchain-huggingface
- langchain-chroma
- langchain_community
+ torch
+ transformers
+ gradio
+ chromadb
+ langchain
+ langchain-chroma
+ langchain-community
+ langchain-huggingface
+ sentence-transformers