benjika committed on
Commit 7cb0a3b · verified · 1 Parent(s): f2bbf99

Update app.py

Files changed (1)
app.py +129 -140
app.py CHANGED
@@ -1,140 +1,129 @@
- import gradio as gr
- import torch
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
- import chromadb
- from langchain.vectorstores import Chroma
- from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
- from langchain.chains import create_retrieval_chain, LLMChain
- from langchain.prompts import PromptTemplate
- import os
- import shutil
- import zipfile
-
- # 🚀 Step 1: Extract ChromaDB if not already done (only once)
- if not os.path.exists("./chroma_db"):
-     with zipfile.ZipFile("chroma.zip", "r") as zip_ref:
-         zip_ref.extractall("./chroma_db")
-
- # 🚀 Step 2: Load Pre-trained Model & Tokenizer (Fast Startup)
- MODEL_NAME = "google/flan-t5-xl"
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
-
- # 🚀 Step 3: Load Vector Store Efficiently
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
- chroma_client = chromadb.PersistentClient(path="./chroma_db")
- db = Chroma(embedding_function=embeddings, client=chroma_client)
-
- # 🚀 Step 4: Optimize Retriever (Lower `k` for Speed)
- retriever = db.as_retriever(search_kwargs={"k": 10})
-
- # 🚀 Step 5: Define Prompt for the LLM
- prompt_template = PromptTemplate(
-     template="""
- You are a Kubernetes expert.
- **Answer the question using ONLY the provided context.**
- If the context does NOT contain enough information, return:
- `"I don't have enough information to answer this question."`
- Always include YAML examples when relevant.
-
- ---
- **Context:**
- {context}
-
- **Question:**
- {input}
-
- ---
- **Answer:**
- """,
-     input_variables=["context", "input"]
- )
-
- # 🚀 Step 6: Build Retrieval Chain
- qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0,
-                        max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)
- llm = HuggingFacePipeline(pipeline=qa_pipeline)
- llm_chain = LLMChain(llm=llm, prompt=prompt_template)
- qa_chain = create_retrieval_chain(retriever, llm_chain)
-
- # 🚀 Step 7: Define Fast Answer Function
- def clean_context(context_list, max_tokens=350, min_length=50):
-     """
-     Improves the retrieved document context:
-     - Removes duplicates while preserving order
-     - Filters out very short or unstructured text
-     - Limits token count for better LLM performance
-     """
-     from collections import OrderedDict
-
-     # Preserve order while removing exact duplicates
-     unique_texts = list(OrderedDict.fromkeys(doc.page_content.strip() for doc in context_list))
-
-     # Remove very short texts (e.g., headers, page numbers)
-     filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]
-
-     # Avoid near-duplicates
-     deduplicated_texts = []
-     seen_texts = set()
-     for text in filtered_texts:
-         normalized_text = " ".join(text.split())  # Normalize spacing
-         if not any(normalized_text in seen for seen in seen_texts):  # Avoid near-duplicates
-             deduplicated_texts.append(normalized_text)
-             seen_texts.add(normalized_text)
-
-     # Limit context based on token count
-     trimmed_context = []
-     total_tokens = 0
-     for text in deduplicated_texts:
-         tokenized_text = tokenizer.encode(text, add_special_tokens=False)
-         token_count = len(tokenized_text)
-
-         if total_tokens + token_count > max_tokens:
-             remaining_tokens = max_tokens - total_tokens
-             if remaining_tokens > 20:  # Allow partial inclusion if meaningful
-                 trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
-             break
-
-         trimmed_context.append(text)
-         total_tokens += token_count
-
-     return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
-
-
- def get_k8s_answer(query):
-     retrieved_context = retriever.invoke(query)
-     cleaned_context = clean_context(retrieved_context, max_tokens=350)
-
-     # Ensure input tokens fit within 512 limit
-     input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
-     total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))
-
-     if total_tokens > 512:
-         # Further trim context
-         allowed_tokens = 512 - len(tokenizer.encode(query, add_special_tokens=True)) - 50  # 50 tokens reserved for response
-         cleaned_context = clean_context(retrieved_context, max_tokens=allowed_tokens)
-
-         # Recalculate total tokens
-         input_text = f"Context:\n{cleaned_context}\n\nQuestion: {query}\nAnswer:"
-         total_tokens = len(tokenizer.encode(input_text, add_special_tokens=True))
-
-         if total_tokens > 512:
-             return "Error: Even after trimming, input is too large."
-
-     response = qa_chain.invoke({"input": query, "context": cleaned_context})
-     return response
-
- # 🚀 Step 8: Optimize Gradio App with `Blocks()`
- with gr.Blocks(theme="soft") as demo:
-     gr.Markdown("# ⚡ Kubernetes RAG")
-     gr.Markdown("Ask any Kubernetes-related question!")
-
-     with gr.Row():
-         question = gr.Textbox(label="Ask a Kubernetes Question", lines=1)
-         answer = gr.Textbox(label="Answer", interactive=False)
-
-     submit_button = gr.Button("Get Answer")
-
-     submit_button.click(fn=get_k8s_answer, inputs=question, outputs=answer)
-
- demo.launch()
 
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
+ import chromadb
+ from langchain.vectorstores import Chroma
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
+ from langchain.chains import create_retrieval_chain, LLMChain
+ from langchain.prompts import PromptTemplate
+ import os
+ import shutil
+ import zipfile
+
+ # 🚀 Step 1: Extract ChromaDB if not already done (only once)
+ if not os.path.exists("./chroma_db"):
+     with zipfile.ZipFile("chroma.zip", "r") as zip_ref:
+         zip_ref.extractall("./chroma_db")
+
+ # 🚀 Step 2: Load Pre-trained Model & Tokenizer (Fast Startup)
+ MODEL_NAME = "google/flan-t5-xl"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
+
+ # 🚀 Step 3: Load Vector Store Efficiently
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+ chroma_client = chromadb.PersistentClient(path="./chroma_db")
+ db = Chroma(embedding_function=embeddings, client=chroma_client)
+
+ # 🚀 Step 4: Optimize Retriever (Lower `k` for Speed)
+ retriever = db.as_retriever(search_kwargs={"k": 10})
+
+ # 🚀 Step 5: Define Prompt for the LLM
+ prompt_template = PromptTemplate(
+     template="""
+ You are a Kubernetes expert.
+ **Answer the question using ONLY the provided context.**
+ If the context does NOT contain enough information, return:
+ `"I don't have enough information to answer this question."`
+ Always include YAML examples when relevant.
+
+ ---
+ **Context:**
+ {context}
+
+ **Question:**
+ {input}
+
+ ---
+ **Answer:**
+ """,
+     input_variables=["context", "input"]
+ )
+
+ # 🚀 Step 6: Build Retrieval Chain
+ qa_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0,
+                        max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)
+ llm = HuggingFacePipeline(pipeline=qa_pipeline)
+ llm_chain = LLMChain(llm=llm, prompt=prompt_template)
+ qa_chain = create_retrieval_chain(retriever, llm_chain)
+
+ # 🚀 Step 7: Define Fast Answer Function
+ def clean_context(context_list, max_tokens=350, min_length=50):
+     """
+     Improves the retrieved document context:
+     - Removes duplicates while preserving order
+     - Filters out very short or unstructured text
+     - Limits token count for better LLM performance
+     """
+     from collections import OrderedDict
+
+     # Preserve order while removing exact duplicates
+     unique_texts = list(OrderedDict.fromkeys(doc.page_content.strip() for doc in context_list))
+
+     # Remove very short texts (e.g., headers, page numbers)
+     filtered_texts = [text for text in unique_texts if len(text.split()) > min_length]
+
+     # Avoid near-duplicates
+     deduplicated_texts = []
+     seen_texts = set()
+     for text in filtered_texts:
+         normalized_text = " ".join(text.split())  # Normalize spacing
+         if not any(normalized_text in seen for seen in seen_texts):  # Avoid near-duplicates
+             deduplicated_texts.append(normalized_text)
+             seen_texts.add(normalized_text)
+
+     # Limit context based on token count
+     trimmed_context = []
+     total_tokens = 0
+     for text in deduplicated_texts:
+         tokenized_text = tokenizer.encode(text, add_special_tokens=False)
+         token_count = len(tokenized_text)
+
+         if total_tokens + token_count > max_tokens:
+             remaining_tokens = max_tokens - total_tokens
+             if remaining_tokens > 20:  # Allow partial inclusion if meaningful
+                 trimmed_context.append(tokenizer.decode(tokenized_text[:remaining_tokens]))
+             break
+
+         trimmed_context.append(text)
+         total_tokens += token_count
+
+     return "\n\n".join(trimmed_context) if trimmed_context else "No relevant context found."
+
+
+ def get_k8s_answer(query):
+     retrieved_context = retriever.get_relevant_documents(query)
+     cleaned_context = clean_context(retrieved_context, max_tokens=350)
+
+     input_text = prompt_template.format(context=cleaned_context, input=query)
+
+     inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
+     output_ids = model.generate(**inputs, max_length=512, min_length=50, do_sample=True, temperature=0.4, top_p=0.9)
+
+     response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return response
+
+ # 🚀 Step 8: Optimize Gradio App with `Blocks()`
+ with gr.Blocks(theme="soft") as demo:
+     gr.Markdown("# ⚡ Kubernetes RAG")
+     gr.Markdown("Ask any Kubernetes-related question!")
+
+     with gr.Row():
+         question = gr.Textbox(label="Ask a Kubernetes Question", lines=1)
+         answer = gr.Textbox(label="Answer", interactive=False)
+
+     submit_button = gr.Button("Get Answer")
+
+     submit_button.click(fn=get_k8s_answer, inputs=question, outputs=answer)
+
+ demo.launch()