Luciferalive commited on
Commit
4241b0f
·
verified ·
1 Parent(s): ca1e5d6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import re
4
+ import numpy as np
5
+ import pytesseract
6
+ from PIL import Image
7
+ from typing import List
8
+ from docx import Document
9
+ from sentence_transformers import SentenceTransformer
10
+ from langchain_community.vectorstores import Chroma
11
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
12
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
13
+ from groq import Groq
14
+ import gradio as gr
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ # Ensure the Tesseract OCR path is set correctly
18
+ pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'
19
+
20
+ GROQ_API_KEY = "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"
21
+
22
+ def extract_text_from_doc(doc_content):
23
+ """Extract text from DOC file content."""
24
+ try:
25
+ doc = Document(io.BytesIO(doc_content))
26
+ extracted_text = "\n".join([paragraph.text for paragraph in doc.paragraphs])
27
+ return extracted_text
28
+ except Exception as e:
29
+ print("Failed to extract text from DOC:", e)
30
+ return ""
31
+
32
+ def preprocess_text(text):
33
+ try:
34
+ text = text.replace('\n', ' ').replace('\r', ' ')
35
+ text = re.sub(r'[^\x00-\x7F]+', ' ', text)
36
+ text = text.lower()
37
+ text = re.sub(r'[^\w\s]', '', text)
38
+ text = re.sub(r'\s+', ' ', text).strip()
39
+ return text
40
+ except Exception as e:
41
+ print("Failed to preprocess text:", e)
42
+ return ""
43
+
44
+ def process_files(file_contents: List[bytes]):
45
+ all_text = ""
46
+ for file_content in file_contents:
47
+ extracted_text = extract_text_from_doc(file_content)
48
+ preprocessed_text = preprocess_text(extracted_text)
49
+ all_text += preprocessed_text + " "
50
+ return all_text
51
+
52
+ def compute_cosine_similarity_scores(query, retrieved_docs):
53
+ model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
54
+ query_embedding = model.encode(query, convert_to_tensor=True)
55
+ doc_embeddings = model.encode(retrieved_docs, convert_to_tensor=True)
56
+ cosine_scores = np.dot(doc_embeddings.cpu().numpy(), query_embedding.cpu().numpy().reshape(-1, 1))
57
+ readable_scores = [{"doc": doc, "score": float(score)} for doc, score in zip(retrieved_docs, cosine_scores.flatten())]
58
+ return readable_scores
59
+
60
+ def fetch_files_from_huggingface_space():
61
+ repo_id = "Luciferalive/goosev9"
62
+ file_names = [f"{i}.docx" for i in range(1, 22)]
63
+
64
+ file_contents = []
65
+ for file_name in file_names:
66
+ try:
67
+ file_path = hf_hub_download(repo_id, file_name)
68
+ with open(file_path, "rb") as f:
69
+ file_contents.append(f.read())
70
+ print(f"Successfully downloaded {file_name}")
71
+ except Exception as e:
72
+ print(f"Failed to download {file_name}: {e}")
73
+ return file_contents
74
+
75
+ def create_vector_store(all_text):
76
+ embeddings = SentenceTransformerEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
77
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
78
+ texts = text_splitter.split_text(all_text)
79
+ if not texts:
80
+ print("No text chunks created.")
81
+ return None
82
+
83
+ vector_store = Chroma.from_texts(texts, embeddings, collection_name="insurance_cosine")
84
+ print("Vector DB Successfully Created!")
85
+ return vector_store
86
+
87
+ def load_vector_store():
88
+ embeddings = SentenceTransformerEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
89
+ try:
90
+ db = Chroma(embedding_function=embeddings, collection_name="insurance_cosine")
91
+ print("Vector DB Successfully Loaded!")
92
+ return db
93
+ except Exception as e:
94
+ print("Failed to load Vector DB:", e)
95
+ return None
96
+
97
+ def answer_query_with_similarity(query):
98
+ try:
99
+ vector_store = load_vector_store()
100
+ if not vector_store:
101
+ file_contents = fetch_files_from_huggingface_space()
102
+ if not file_contents:
103
+ print("No files fetched from Hugging Face Space.")
104
+ return None
105
+
106
+ all_text = process_files(file_contents)
107
+ if not all_text.strip():
108
+ print("No text extracted from documents.")
109
+ return None
110
+
111
+ vector_store = create_vector_store(all_text)
112
+ if not vector_store:
113
+ print("Failed to create Vector DB.")
114
+ return None
115
+
116
+ docs = vector_store.similarity_search(query)
117
+ print(f"\n\nDocuments retrieved: {len(docs)}")
118
+
119
+ if not docs:
120
+ print("No documents match the query.")
121
+ return None
122
+
123
+ docs_content = [doc.page_content for doc in docs]
124
+ for i, content in enumerate(docs_content, start=1):
125
+ print(f"\nDocument {i}: {content[:500]}...")
126
+
127
+ cosine_similarity_scores = compute_cosine_similarity_scores(query, docs_content)
128
+ for score in cosine_similarity_scores:
129
+ print(f"\nDocument Score: {score['score']}")
130
+
131
+ all_docs_content = " ".join(docs_content)
132
+
133
+ client = Groq(api_key=GROQ_API_KEY)
134
+ template = """
135
+ ### [INST] Instruction:
136
+ You are an AI assistant named Goose. Your purpose is to provide accurate, relevant, and helpful information to users in a friendly, warm, and supportive manner, similar to ChatGPT. When responding to queries, please keep the following guidelines in mind:
137
+ - When someone says hi, or small talk, only respond in a sentence.
138
+ - Retrieve relevant information from your knowledge base to formulate accurate and informative responses.
139
+ - Always maintain a positive, friendly, and encouraging tone in your interactions with users.
140
+ - Strictly write crisp and clear answers, don't write unnecessary stuff.
141
+ - Only answer the asked question, don't hallucinate or print any pre-information.
142
+ - After providing the answer, always ask for any other help needed in the next paragraph.
143
+ - Writing in bullet format is our top preference.
144
+ Remember, your goal is to be a reliable, friendly, and supportive AI assistant that provides accurate information while creating a positive user experience, just like ChatGPT. Adapt your communication style to best suit each user's needs and preferences.
145
+ ### Docs: {docs}
146
+ ### Question: {question}
147
+ """
148
+
149
+ chat_completion = client.chat.completions.create(
150
+ messages=[
151
+ {
152
+ "role": "system",
153
+ "content": template.format(docs=all_docs_content, question=query)
154
+ },
155
+ {
156
+ "role": "user",
157
+ "content": query
158
+ }
159
+ ],
160
+ model="llama3-8b-8192",
161
+ )
162
+
163
+ answer = chat_completion.choices[0].message.content.strip()
164
+ return answer
165
+ except Exception as e:
166
+ print("An error occurred while getting the answer: ", str(e))
167
+ return None
168
+
169
+ def process_query(query):
170
+ try:
171
+ response = answer_query_with_similarity(query)
172
+ if response:
173
+ return "Answer: " + response
174
+ else:
175
+ return "No answer found."
176
+ except Exception as e:
177
+ print("An error occurred while getting the answer: ", str(e))
178
+ return "An error occurred: " + str(e)
179
+
180
+
181
+ # Set up the Gradio interface
182
+ iface = gr.Interface(
183
+ fn=process_query,
184
+ inputs=gr.Textbox(lines=7, label="Enter your question"),
185
+ outputs="text",
186
+ title="Goose AI Assistant",
187
+ description="Ask a question and get an answer from the AI assistant."
188
+ )
189
+
190
+ iface.launch()