guirnd committed (verified)
Commit fbbb97a · 1 parent: f8ba7a5

Update app.py

Files changed (1):
  1. app.py (+181 -181)
app.py CHANGED
@@ -1,182 +1,182 @@
- import os
- import re
- import soundfile as sf
- import torch
- import torchaudio
- import torchaudio.transforms as T
- from datasets import load_dataset
- from transformers import WhisperForConditionalGeneration, WhisperProcessor, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, AutoModel
- from langchain.document_loaders import PyPDFLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.vectorstores import FAISS
- from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.prompts import PromptTemplate
- from langchain.chains import LLMChain, StuffDocumentsChain, RetrievalQA
- from langchain.llms import LlamaCpp
- import gradio as gr
-
- class PDFProcessor:
-     def __init__(self, pdf_path):
-         self.pdf_path = pdf_path
-
-     def load_and_split_pdf(self):
-         loader = PyPDFLoader(self.pdf_path)
-         documents = loader.load()
-         text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=20)
-         docs = text_splitter.split_documents(documents)
-         return docs
-
- class FAISSManager:
-     def __init__(self):
-         self.vectorstore_cache = {}
-
-     def build_faiss_index(self, docs):
-         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-         vectorstore = FAISS.from_documents(docs, embeddings)
-         return vectorstore
-
-     def save_faiss_index(self, vectorstore, file_path):
-         vectorstore.save_local(file_path)
-         print(f"Vectorstore saved to {file_path}")
-
-     def load_faiss_index(self, file_path):
-         if not os.path.exists(f"{file_path}/index.faiss") or not os.path.exists(f"{file_path}/index.pkl"):
-             raise FileNotFoundError(f"Could not find FAISS index or metadata files in {file_path}")
-         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-         vectorstore = FAISS.load_local(file_path, embeddings, allow_dangerous_deserialization=True)
-         print(f"Vectorstore loaded from {file_path}")
-         return vectorstore
-
-     def build_faiss_index_with_cache_and_file(self, pdf_processor, vectorstore_path):
-         if os.path.exists(vectorstore_path):
-             print(f"Loading vectorstore from file {vectorstore_path}")
-             return self.load_faiss_index(vectorstore_path)
-
-         print(f"Building new vectorstore for {pdf_processor.pdf_path}")
-         docs = pdf_processor.load_and_split_pdf()
-         vectorstore = self.build_faiss_index(docs)
-         self.save_faiss_index(vectorstore, vectorstore_path)
-         return vectorstore
-
- class LLMChainFactory:
-     def __init__(self, prompt_template):
-         self.prompt_template = prompt_template
-
-     def create_llm_chain(self, llm, max_tokens=80):
-         prompt = PromptTemplate(template=self.prompt_template, input_variables=["documents", "question"])
-         llm_chain = LLMChain(llm=llm, prompt=prompt)
-         llm_chain.llm.max_tokens = max_tokens
-         combine_documents_chain = StuffDocumentsChain(
-             llm_chain=llm_chain,
-             document_variable_name="documents"
-         )
-         return combine_documents_chain
-
- class LLMManager:
-     def __init__(self, model_path):
-         self.llm = LlamaCpp(model_path=model_path)
-         self.llm.max_tokens = 80
-
-     def create_rag_chain(self, llm_chain_factory, vectorstore):
-         retriever = vectorstore.as_retriever()
-         combine_documents_chain = llm_chain_factory.create_llm_chain(self.llm)
-         qa_chain = RetrievalQA(combine_documents_chain=combine_documents_chain, retriever=retriever)
-         return qa_chain
-
-     def main_rag_pipeline(self, pdf_processor, query, vectorstore_manager, vectorstore_file):
-         vectorstore = vectorstore_manager.build_faiss_index_with_cache_and_file(pdf_processor, vectorstore_file)
-         llm_chain_factory = LLMChainFactory(prompt_template="""You are a helpful AI. Based on the context below, answer the question politely.
- Context: {documents}
- Question: {question}
- Answer:""")
-         rag_chain = self.create_rag_chain(llm_chain_factory, vectorstore)
-         result = rag_chain.run(query)
-         return result
-
- class WhisperManager:
-     def __init__(self):
-         self.model_id = "openai/whisper-small"
-         self.whisper_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
-         self.whisper_processor = WhisperProcessor.from_pretrained(self.model_id)
-         self.forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="english", task="transcribe")
-
-     def transcribe_speech(self, filepath):
-         if not os.path.isfile(filepath):
-             raise ValueError(f"Invalid file path: {filepath}")
-         waveform, sample_rate = torchaudio.load(filepath)
-         target_sample_rate = 16000
-         if sample_rate != target_sample_rate:
-             resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
-             waveform = resampler(waveform)
-         input_features = self.whisper_processor(waveform.squeeze(), sampling_rate=target_sample_rate, return_tensors="pt").input_features
-         generated_ids = self.whisper_model.generate(input_features, forced_decoder_ids=self.forced_decoder_ids)
-         transcribed_text = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-         cleaned_text = re.sub(r"<[^>]*>", "", transcribed_text).strip()
-         return cleaned_text
-
- class SpeechT5Manager:
-     def __init__(self):
-         self.SpeechT5_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
-         self.SpeechT5_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
-         self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
-         self.speaker_embedding_model = AutoModel.from_pretrained("microsoft/speecht5_vc")
-         embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
-         self.pretrained_speaker_embeddings = torch.tensor(embeddings_dataset[7000]["xvector"]).unsqueeze(0)
-
-     def text_to_speech(self, text, output_file="output_speechT5.wav"):
-         inputs = self.SpeechT5_processor(text=[text], return_tensors="pt")
-         speech = self.SpeechT5_model.generate_speech(inputs["input_ids"], self.pretrained_speaker_embeddings, vocoder=self.vocoder)
-         sf.write(output_file, speech.numpy(), 16000)
-         return output_file
-
- # --- Gradio Interface ---
- def asr_to_text(audio_file):
-     transcribed_text = whisper_manager.transcribe_speech(audio_file)
-     return transcribed_text
-
- def process_with_llm_and_tts(transcribed_text):
-     response_text = llm_manager.main_rag_pipeline(pdf_processor, transcribed_text, vectorstore_manager, vectorstore_file)
-     audio_output = speech_manager.text_to_speech(response_text)
-     return response_text, audio_output
-
- # Instantiate Managers
- pdf_processor = PDFProcessor('./files/LawsoftheGame2024_25.pdf')
- vectorstore_manager = FAISSManager()
- llm_manager = LLMManager(model_path="./files/mistral-7b-instruct-v0.2.Q2_K.gguf")
- whisper_manager = WhisperManager()
- speech_manager = SpeechT5Manager()
- vectorstore_file = "./vectorstore_faiss"
-
- # Define Gradio Interface
- with gr.Blocks() as demo:
-     gr.Markdown("<h1 style='text-align: center;'>:robot: RAG Powered Voice Assistant :robot:</h1>")
-     gr.Markdown("<h1 style='text-align: center;'>Ask me anything about the rules of Football!</h1>")
-
-     # Step 1: Audio input and ASR output
-     with gr.Row():
-         audio_input = gr.Audio(type="filepath", label="Speak your question")
-         asr_output = gr.Textbox(label="ASR Output (Edit if necessary)", interactive=True)
-
-     # Button to process audio (ASR)
-     asr_button = gr.Button("1 - Transform Voice to Text")
-
-     # Step 2: LLM Response and TTS output
-     with gr.Row():
-         llm_response = gr.Textbox(label="LLM Response")
-         tts_audio_output = gr.Audio(label="TTS Audio")
-
-     # Button to process text with LLM
-     llm_button = gr.Button("2 - Submit Question")
-
-     # When ASR button is clicked, the audio is transcribed
-     asr_button.click(fn=asr_to_text, inputs=audio_input, outputs=asr_output)
-
-     # When LLM button is clicked, the text is processed with the LLM and converted to speech
-     llm_button.click(fn=process_with_llm_and_tts, inputs=asr_output, outputs=[llm_response, tts_audio_output])
-
-     # Disclaimer
-     gr.Markdown(
-         "<p style='text-align: center; color: gray;'>Disclaimer: This application was developed solely for educational purposes to demonstrate AI capabilities and should not be used as a source of information or for any other purpose.</p>"
-     )
-
  demo.launch(debug=True)
 
+ import os
+ import re
+ import soundfile as sf
+ import torch
+ import torchaudio
+ import torchaudio.transforms as T
+ from datasets import load_dataset
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, AutoModel
+ from langchain.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import FAISS
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import LLMChain, StuffDocumentsChain, RetrievalQA
+ from langchain.llms import LlamaCpp
+ import gradio as gr
+
+ class PDFProcessor:
+     def __init__(self, pdf_path):
+         self.pdf_path = pdf_path
+
+     def load_and_split_pdf(self):
+         loader = PyPDFLoader(self.pdf_path)
+         documents = loader.load()
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=20)
+         docs = text_splitter.split_documents(documents)
+         return docs
+
+ class FAISSManager:
+     def __init__(self):
+         self.vectorstore_cache = {}
+
+     def build_faiss_index(self, docs):
+         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+         vectorstore = FAISS.from_documents(docs, embeddings)
+         return vectorstore
+
+     def save_faiss_index(self, vectorstore, file_path):
+         vectorstore.save_local(file_path)
+         print(f"Vectorstore saved to {file_path}")
+
+     def load_faiss_index(self, file_path):
+         if not os.path.exists(f"{file_path}/index.faiss") or not os.path.exists(f"{file_path}/index.pkl"):
+             raise FileNotFoundError(f"Could not find FAISS index or metadata files in {file_path}")
+         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+         vectorstore = FAISS.load_local(file_path, embeddings, allow_dangerous_deserialization=True)
+         print(f"Vectorstore loaded from {file_path}")
+         return vectorstore
+
+     def build_faiss_index_with_cache_and_file(self, pdf_processor, vectorstore_path):
+         if os.path.exists(vectorstore_path):
+             print(f"Loading vectorstore from file {vectorstore_path}")
+             return self.load_faiss_index(vectorstore_path)
+
+         print(f"Building new vectorstore for {pdf_processor.pdf_path}")
+         docs = pdf_processor.load_and_split_pdf()
+         vectorstore = self.build_faiss_index(docs)
+         self.save_faiss_index(vectorstore, vectorstore_path)
+         return vectorstore
+
+ class LLMChainFactory:
+     def __init__(self, prompt_template):
+         self.prompt_template = prompt_template
+
+     def create_llm_chain(self, llm, max_tokens=80):
+         prompt = PromptTemplate(template=self.prompt_template, input_variables=["documents", "question"])
+         llm_chain = LLMChain(llm=llm, prompt=prompt)
+         llm_chain.llm.max_tokens = max_tokens
+         combine_documents_chain = StuffDocumentsChain(
+             llm_chain=llm_chain,
+             document_variable_name="documents"
+         )
+         return combine_documents_chain
+
+ class LLMManager:
+     def __init__(self, model_path):
+         self.llm = LlamaCpp(model_path=model_path)
+         self.llm.max_tokens = 80
+
+     def create_rag_chain(self, llm_chain_factory, vectorstore):
+         retriever = vectorstore.as_retriever()
+         combine_documents_chain = llm_chain_factory.create_llm_chain(self.llm)
+         qa_chain = RetrievalQA(combine_documents_chain=combine_documents_chain, retriever=retriever)
+         return qa_chain
+
+     def main_rag_pipeline(self, pdf_processor, query, vectorstore_manager, vectorstore_file):
+         vectorstore = vectorstore_manager.build_faiss_index_with_cache_and_file(pdf_processor, vectorstore_file)
+         llm_chain_factory = LLMChainFactory(prompt_template="""You are a helpful AI. Based on the context below, answer the question politely.
+ Context: {documents}
+ Question: {question}
+ Answer:""")
+         rag_chain = self.create_rag_chain(llm_chain_factory, vectorstore)
+         result = rag_chain.run(query)
+         return result
+
+ class WhisperManager:
+     def __init__(self):
+         self.model_id = "openai/whisper-small"
+         self.whisper_model = WhisperForConditionalGeneration.from_pretrained(self.model_id)
+         self.whisper_processor = WhisperProcessor.from_pretrained(self.model_id)
+         self.forced_decoder_ids = self.whisper_processor.get_decoder_prompt_ids(language="english", task="transcribe")
+
+     def transcribe_speech(self, filepath):
+         if not os.path.isfile(filepath):
+             raise ValueError(f"Invalid file path: {filepath}")
+         waveform, sample_rate = torchaudio.load(filepath)
+         target_sample_rate = 16000
+         if sample_rate != target_sample_rate:
+             resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
+             waveform = resampler(waveform)
+         input_features = self.whisper_processor(waveform.squeeze(), sampling_rate=target_sample_rate, return_tensors="pt").input_features
+         generated_ids = self.whisper_model.generate(input_features, forced_decoder_ids=self.forced_decoder_ids)
+         transcribed_text = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+         cleaned_text = re.sub(r"<[^>]*>", "", transcribed_text).strip()
+         return cleaned_text
+
+ class SpeechT5Manager:
+     def __init__(self):
+         self.SpeechT5_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
+         self.SpeechT5_model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
+         self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
+         self.speaker_embedding_model = AutoModel.from_pretrained("microsoft/speecht5_vc")
+         embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
+         self.pretrained_speaker_embeddings = torch.tensor(embeddings_dataset[7000]["xvector"]).unsqueeze(0)
+
+     def text_to_speech(self, text, output_file="output_speechT5.wav"):
+         inputs = self.SpeechT5_processor(text=[text], return_tensors="pt")
+         speech = self.SpeechT5_model.generate_speech(inputs["input_ids"], self.pretrained_speaker_embeddings, vocoder=self.vocoder)
+         sf.write(output_file, speech.numpy(), 16000)
+         return output_file
+
+ # --- Gradio Interface ---
+ def asr_to_text(audio_file):
+     transcribed_text = whisper_manager.transcribe_speech(audio_file)
+     return transcribed_text
+
+ def process_with_llm_and_tts(transcribed_text):
+     response_text = llm_manager.main_rag_pipeline(pdf_processor, transcribed_text, vectorstore_manager, vectorstore_file)
+     audio_output = speech_manager.text_to_speech(response_text)
+     return response_text, audio_output
+
+ # Instantiate Managers
+ pdf_processor = PDFProcessor('./files/LawsoftheGame2024_25.pdf')
+ vectorstore_manager = FAISSManager()
+ llm_manager = LLMManager(model_path="./files/mistral-7b-instruct-v0.2.Q2_K.gguf")
+ whisper_manager = WhisperManager()
+ speech_manager = SpeechT5Manager()
+ vectorstore_file = "./vectorstore_faiss"
+
+ # Define Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown("<h1 style='text-align: center;'>RAG Powered Voice Assistant</h1>") #removed emojis
+     gr.Markdown("<h1 style='text-align: center;'>Ask me anything about the rules of Football!</h1>")
+
+     # Step 1: Audio input and ASR output
+     with gr.Row():
+         audio_input = gr.Audio(type="filepath", label="Speak your question")
+         asr_output = gr.Textbox(label="ASR Output (Edit if necessary)", interactive=True)
+
+     # Button to process audio (ASR)
+     asr_button = gr.Button("1 - Transform Voice to Text")
+
+     # Step 2: LLM Response and TTS output
+     with gr.Row():
+         llm_response = gr.Textbox(label="LLM Response")
+         tts_audio_output = gr.Audio(label="TTS Audio")
+
+     # Button to process text with LLM
+     llm_button = gr.Button("2 - Submit Question")
+
+     # When ASR button is clicked, the audio is transcribed
+     asr_button.click(fn=asr_to_text, inputs=audio_input, outputs=asr_output)
+
+     # When LLM button is clicked, the text is processed with the LLM and converted to speech
+     llm_button.click(fn=process_with_llm_and_tts, inputs=asr_output, outputs=[llm_response, tts_audio_output])
+
+     # Disclaimer
+     gr.Markdown(
+         "<p style='text-align: center; color: gray;'>Disclaimer: This application was developed solely for educational purposes to demonstrate AI capabilities and should not be used as a source of information or for any other purpose.</p>"
+     )
+
  demo.launch(debug=True)