import os
import gradio as gr
import whisper
from gtts import gTTS
import io
from groq import Groq
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


# Set up environment variables
# NOTE: in a real deployment the key should come from the environment or a secrets store, not the source code
os.environ["GROQ_API_KEY"] = "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"

# Initialize the Groq client
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Load the Whisper model
whisper_model = whisper.load_model("base")  # You can choose other models like "small", "medium", "large"
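
# Quick standalone transcription check (illustrative; "test.wav" is a placeholder path, not part of this app):
#   print(whisper_model.transcribe("test.wav")["text"])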

# Initialize the tokenizer and model for RAG, offloading weights to disk when they
# do not fit in memory.

# Folder where offloaded model shards will be stored
offload_folder = "./offload"

# Ensure the offload folder exists
os.makedirs(offload_folder, exist_ok=True)

# Initialize the tokenizer
rag_tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")

# Load the model. `device_map="auto"` lets Accelerate place layers across the GPU,
# the CPU and the offload folder automatically, so no manual dispatch, tie_weights()
# or .to("cuda") call is needed afterwards.
rag_model = AutoModelForCausalLM.from_pretrained(
    "himmeow/vi-gemma-2b-RAG",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder=offload_folder,
)

# Load PDF content
def load_pdf(pdf_path):
    pdf_text = ""
    with open(pdf_path, "rb") as file:
        reader = PdfReader(file)
        for page_num in range(len(reader.pages)):
            page = reader.pages[page_num]
            text = page.extract_text()
            pdf_text += text + "\n"
    return pdf_text
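
# Example usage of load_pdf (illustrative; "sample.pdf" is a placeholder path, not part of this app):
#   pdf_text = load_pdf("sample.pdf")
#   print(pdf_text[:200])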

# Define the prompt format for the RAG model
prompt_template = """
### Instruction and Input:
Based on the following context/document:
{}
Please answer the question: {}

### Response:
{}
"""

# Function to process audio and generate a response using RAG and Groq
def process_audio_rag(file_path):
    try:
        # Load and transcribe the audio using Whisper
        audio = whisper.load_audio(file_path)
        result = whisper_model.transcribe(audio)
        text = result["text"]

        # Load the PDF content (update with your PDF path or pass it as an argument)
        pdf_path = "/content/BN_Cotton.pdf"
        pdf_text = load_pdf(pdf_path)

        # Prepare the input data for the RAG model
        query = text
        input_text = prompt_template.format(pdf_text, query, " ")

        # Tokenize the prompt and move it to the same device as the RAG model
        inputs = rag_tokenizer(input_text, return_tensors="pt").to(rag_model.device)

        # Generate text using the RAG model
        outputs = rag_model.generate(
            **inputs,
            max_new_tokens=500,
            no_repeat_ngram_size=5
        )
        # Decode only the newly generated tokens, not the prompt (which contains the whole PDF)
        rag_response = rag_tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:],
            skip_special_tokens=True
        )

        # Generate a response using Groq if needed
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": rag_response}],
            model="llama3-8b-8192",  # Replace with the correct model if necessary
        )
        response_message = chat_completion.choices[0].message.content.strip()

        # Convert the response text to speech
        tts = gTTS(response_message)
        response_audio_io = io.BytesIO()
        tts.write_to_fp(response_audio_io)
        response_audio_io.seek(0)

        # Save audio to a file to ensure it's generated correctly
        with open("response.mp3", "wb") as audio_file:
            audio_file.write(response_audio_io.getvalue())

        # Return the response text and the path to the saved audio file
        return response_message, "response.mp3"

    except Exception as e:
        return f"An error occurred: {e}", None

# Create a Gradio interface
iface = gr.Interface(
    fn=process_audio_rag,
    inputs=gr.Audio(type="filepath"),
    outputs=[gr.Textbox(label="Response Text"), gr.Audio(label="Response Audio")],
    live=True,
    title="Agriculture Assistant"
)

# Launch the Gradio interface
iface.launch()
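
# When running inside a notebook such as Google Colab (as the /content/ PDF path above suggests),
# Gradio can optionally expose a temporary public link instead:
#   iface.launch(share=True)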