guirnd commited on
Commit
f8ba7a5
·
verified ·
1 Parent(s): 3bf0978

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +182 -0
  2. packages.txt +1 -0
  3. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import soundfile as sf
4
+ import torch
5
+ import torchaudio
6
+ import torchaudio.transforms as T
7
+ from datasets import load_dataset
8
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, AutoModel
9
+ from langchain.document_loaders import PyPDFLoader
10
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain.vectorstores import FAISS
12
+ from langchain.embeddings import HuggingFaceEmbeddings
13
+ from langchain.prompts import PromptTemplate
14
+ from langchain.chains import LLMChain, StuffDocumentsChain, RetrievalQA
15
+ from langchain.llms import LlamaCpp
16
+ import gradio as gr
17
+
18
+ class PDFProcessor:
19
+ def __init__(self, pdf_path):
20
+ self.pdf_path = pdf_path
21
+
22
+ def load_and_split_pdf(self):
23
+ loader = PyPDFLoader(self.pdf_path)
24
+ documents = loader.load()
25
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=20)
26
+ docs = text_splitter.split_documents(documents)
27
+ return docs
28
+
29
+ class FAISSManager:
30
+ def __init__(self):
31
+ self.vectorstore_cache = {}
32
+
33
+ def build_faiss_index(self, docs):
34
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
35
+ vectorstore = FAISS.from_documents(docs, embeddings)
36
+ return vectorstore
37
+
38
+ def save_faiss_index(self, vectorstore, file_path):
39
+ vectorstore.save_local(file_path)
40
+ print(f"Vectorstore saved to {file_path}")
41
+
42
+ def load_faiss_index(self, file_path):
43
+ if not os.path.exists(f"{file_path}/index.faiss") or not os.path.exists(f"{file_path}/index.pkl"):
44
+ raise FileNotFoundError(f"Could not find FAISS index or metadata files in {file_path}")
45
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
46
+ vectorstore = FAISS.load_local(file_path, embeddings, allow_dangerous_deserialization=True)
47
+ print(f"Vectorstore loaded from {file_path}")
48
+ return vectorstore
49
+
50
+ def build_faiss_index_with_cache_and_file(self, pdf_processor, vectorstore_path):
51
+ if os.path.exists(vectorstore_path):
52
+ print(f"Loading vectorstore from file {vectorstore_path}")
53
+ return self.load_faiss_index(vectorstore_path)
54
+
55
+ print(f"Building new vectorstore for {pdf_processor.pdf_path}")
56
+ docs = pdf_processor.load_and_split_pdf()
57
+ vectorstore = self.build_faiss_index(docs)
58
+ self.save_faiss_index(vectorstore, vectorstore_path)
59
+ return vectorstore
60
+
61
+ class LLMChainFactory:
62
+ def __init__(self, prompt_template):
63
+ self.prompt_template = prompt_template
64
+
65
+ def create_llm_chain(self, llm, max_tokens=80):
66
+ prompt = PromptTemplate(template=self.prompt_template, input_variables=["documents", "question"])
67
+ llm_chain = LLMChain(llm=llm, prompt=prompt)
68
+ llm_chain.llm.max_tokens = max_tokens
69
+ combine_documents_chain = StuffDocumentsChain(
70
+ llm_chain=llm_chain,
71
+ document_variable_name="documents"
72
+ )
73
+ return combine_documents_chain
74
+
75
+ class LLMManager:
76
+ def __init__(self, model_path):
77
+ self.llm = LlamaCpp(model_path=model_path)
78
+ self.llm.max_tokens = 80
79
+
80
+ def create_rag_chain(self, llm_chain_factory, vectorstore):
81
+ retriever = vectorstore.as_retriever()
82
+ combine_documents_chain = llm_chain_factory.create_llm_chain(self.llm)
83
+ qa_chain = RetrievalQA(combine_documents_chain=combine_documents_chain, retriever=retriever)
84
+ return qa_chain
85
+
86
+ def main_rag_pipeline(self, pdf_processor, query, vectorstore_manager, vectorstore_file):
87
+ vectorstore = vectorstore_manager.build_faiss_index_with_cache_and_file(pdf_processor, vectorstore_file)
88
+ llm_chain_factory = LLMChainFactory(prompt_template="""You are a helpful AI. Based on the context below, answer the question politely.
89
+ Context: {documents}
90
+ Question: {question}
91
+ Answer:""")
92
+ rag_chain = self.create_rag_chain(llm_chain_factory, vectorstore)
93
+ result = rag_chain.run(query)
94
+ return result
95
+
96
+ class WhisperManager:
97
+ def __init__(self):
98
+ self.model_id = "openai/whisper-small"
99
+ self.whisper_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
100
+ self.whisper_processor = WhisperProcessor.from_pretrained(self.model_id)
101
+ self.forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="english", task="transcribe")
102
+
103
+ def transcribe_speech(self, filepath):
104
+ if not os.path.isfile(filepath):
105
+ raise ValueError(f"Invalid file path: {filepath}")
106
+ waveform, sample_rate = torchaudio.load(filepath)
107
+ target_sample_rate = 16000
108
+ if sample_rate != target_sample_rate:
109
+ resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
110
+ waveform = resampler(waveform)
111
+ input_features = self.whisper_processor(waveform.squeeze(), sampling_rate=target_sample_rate, return_tensors="pt").input_features
112
+ generated_ids = self.whisper_model.generate(input_features, forced_decoder_ids=self.forced_decoder_ids)
113
+ transcribed_text = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
114
+ cleaned_text = re.sub(r"<[^>]*>", "", transcribed_text).strip()
115
+ return cleaned_text
116
+
117
+ class SpeechT5Manager:
118
+ def __init__(self):
119
+ self.SpeechT5_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
120
+ self.SpeechT5_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
121
+ self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
122
+ self.speaker_embedding_model = AutoModel.from_pretrained("microsoft/speecht5_vc")
123
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
124
+ self.pretrained_speaker_embeddings = torch.tensor(embeddings_dataset[7000]["xvector"]).unsqueeze(0)
125
+
126
+ def text_to_speech(self, text, output_file="output_speechT5.wav"):
127
+ inputs = self.SpeechT5_processor(text=[text], return_tensors="pt")
128
+ speech = self.SpeechT5_model.generate_speech(inputs["input_ids"], self.pretrained_speaker_embeddings, vocoder=self.vocoder)
129
+ sf.write(output_file, speech.numpy(), 16000)
130
+ return output_file
131
+
132
+ # --- Gradio Interface ---
133
+ def asr_to_text(audio_file):
134
+ transcribed_text = whisper_manager.transcribe_speech(audio_file)
135
+ return transcribed_text
136
+
137
+ def process_with_llm_and_tts(transcribed_text):
138
+ response_text = llm_manager.main_rag_pipeline(pdf_processor, transcribed_text, vectorstore_manager, vectorstore_file)
139
+ audio_output = speech_manager.text_to_speech(response_text)
140
+ return response_text, audio_output
141
+
142
+ # Instantiate Managers
143
+ pdf_processor = PDFProcessor('./files/LawsoftheGame2024_25.pdf')
144
+ vectorstore_manager = FAISSManager()
145
+ llm_manager = LLMManager(model_path="./files/mistral-7b-instruct-v0.2.Q2_K.gguf")
146
+ whisper_manager = WhisperManager()
147
+ speech_manager = SpeechT5Manager()
148
+ vectorstore_file = "./vectorstore_faiss"
149
+
150
+ # Define Gradio Interface
151
+ with gr.Blocks() as demo:
152
+ gr.Markdown("<h1 style='text-align: center;'>:robot: RAG Powered Voice Assistant :robot:</h1>")
153
+ gr.Markdown("<h1 style='text-align: center;'>Ask me anything about the rules of Football!</h1>")
154
+
155
+ # Step 1: Audio input and ASR output
156
+ with gr.Row():
157
+ audio_input = gr.Audio(type="filepath", label="Speak your question")
158
+ asr_output = gr.Textbox(label="ASR Output (Edit if necessary)", interactive=True)
159
+
160
+ # Button to process audio (ASR)
161
+ asr_button = gr.Button("1 - Transform Voice to Text")
162
+
163
+ # Step 2: LLM Response and TTS output
164
+ with gr.Row():
165
+ llm_response = gr.Textbox(label="LLM Response")
166
+ tts_audio_output = gr.Audio(label="TTS Audio")
167
+
168
+ # Button to process text with LLM
169
+ llm_button = gr.Button("2 - Submit Question")
170
+
171
+ # When ASR button is clicked, the audio is transcribed
172
+ asr_button.click(fn=asr_to_text, inputs=audio_input, outputs=asr_output)
173
+
174
+ # When LLM button is clicked, the text is processed with the LLM and converted to speech
175
+ llm_button.click(fn=process_with_llm_and_tts, inputs=asr_output, outputs=[llm_response, tts_audio_output])
176
+
177
+ # Disclaimer
178
+ gr.Markdown(
179
+ "<p style='text-align: center; color: gray;'>Disclaimer: This application was developed solely for educational purposes to demonstrate AI capabilities and should not be used as a source of information or for any other purpose.</p>"
180
+ )
181
+
182
+ demo.launch(debug=True)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libsndfile1
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-community
3
+ faiss-cpu
4
+ llama-cpp-python
5
+ PyPDF2
6
+ pypdf
7
+ sentence-transformers
8
+ datasets
9
+ torch
10
+ torchaudio
11
+ sentencepiece
12
+ soundfile
13
+ gradio