import os
import gradio as gr
import whisper
from gtts import gTTS
import io
from groq import Groq
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Read the Groq API key from the environment (set it as a Space secret
# rather than hard-coding it in the source)
# Initialize the Groq client
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Load the Whisper model
whisper_model = whisper.load_model("base") # You can choose other models like "small", "medium", "large"
# Initialize the tokenizer and model from the saved checkpoint for RAG,
# offloading weights that do not fit in memory to disk

# Specify the folder where offloaded model parts will be stored
offload_folder = "./offload"

# Ensure the offload folder exists
os.makedirs(offload_folder, exist_ok=True)

# Initialize the tokenizer
rag_tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")

# Load the model once: device_map="auto" places layers on the GPU when one
# is available, falls back to CPU, and spills anything that does not fit
# into the offload folder, so no manual .to("cuda") call is needed afterwards
rag_model = AutoModelForCausalLM.from_pretrained(
    "himmeow/vi-gemma-2b-RAG",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder=offload_folder,
)
# Load PDF content
def load_pdf(pdf_path):
    pdf_text = ""
    with open(pdf_path, "rb") as file:
        reader = PdfReader(file)
        for page in reader.pages:
            # extract_text() can return None for pages without a text layer
            text = page.extract_text() or ""
            pdf_text += text + "\n"
    return pdf_text
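
# Note: the handler below re-reads the PDF on every request; if the document
# is fixed, its text could instead be loaded once at startup, e.g.:
#   PDF_TEXT = load_pdf("/content/BN_Cotton.pdf")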
# Define the prompt format for the RAG model
prompt_template = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}
### Response:
{}
"""
# Function to process audio and generate a response using RAG and Groq
def process_audio_rag(file_path):
    try:
        # Load and transcribe the audio using Whisper
        audio = whisper.load_audio(file_path)
        result = whisper_model.transcribe(audio)
        text = result["text"]
        # Load the PDF content (update with your PDF path or pass it as an argument)
        pdf_path = "/content/BN_Cotton.pdf"
        pdf_text = load_pdf(pdf_path)
        # Prepare the input data for the RAG model
        query = text
        input_text = prompt_template.format(pdf_text, query, " ")
        # Tokenize the input text for the RAG model
        inputs = rag_tokenizer(input_text, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to("cuda")
        # Generate text using the RAG model
        outputs = rag_model.generate(
            **inputs,
            max_new_tokens=500,
            no_repeat_ngram_size=5,
        )
        rag_response = rag_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Refine the RAG output with a Groq-hosted chat model
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": rag_response}],
            model="llama3-8b-8192",  # Replace with the correct model if necessary
        )
        response_message = chat_completion.choices[0].message.content.strip()
        # Convert the response text to speech
        tts = gTTS(response_message)
        response_audio_io = io.BytesIO()
        tts.write_to_fp(response_audio_io)
        response_audio_io.seek(0)
        # Save the audio to a file so Gradio can serve it
        with open("response.mp3", "wb") as audio_file:
            audio_file.write(response_audio_io.getvalue())
        # Return the response text and the path to the saved audio file
        return response_message, "response.mp3"
    except Exception as e:
        return f"An error occurred: {e}", None
# Create a Gradio interface
iface = gr.Interface(
    fn=process_audio_rag,
    inputs=gr.Audio(type="filepath"),
    outputs=[gr.Textbox(label="Response Text"), gr.Audio(label="Response Audio")],
    live=True,
    title="Agriculture Assistant",
)

# Launch the interface
iface.launch()
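
# When running in a notebook or behind a firewall, a temporary public URL
# can be requested with iface.launch(share=True) instead.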