YU-XI committed
Commit f538efb
1 Parent(s): af412b6

Update app.py

Files changed (1):
  1. app.py +53 -69
app.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import gradio as gr
-import asyncio
 from langchain_core.prompts import PromptTemplate
 from langchain_community.document_loaders import PyPDFLoader
 from langchain_google_genai import ChatGoogleGenerativeAI
@@ -9,78 +8,63 @@ from langchain.chains.question_answering import load_qa_chain
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Gemini PDF QA System
-async def initialize_gemini(file_path, question):
-    genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
-    model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)
-    prompt_template = """Answer the question as precise as possible using the provided context. If the answer is
-    not contained in the context, say "answer not available in context" \n\n
-    Context: \n {context}?\n
-    Question: \n {question} \n
-    Answer:
-    """
-    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
-    if os.path.exists(file_path):
-        pdf_loader = PyPDFLoader(file_path)
-        pages = pdf_loader.load_and_split()
-        context = "\n".join(str(page.page_content) for page in pages[:30])
-        stuff_chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
-        stuff_answer = await stuff_chain.acall({"input_documents": pages, "question": question, "context": context}, return_only_outputs=True)
-        return stuff_answer['output_text']
-    else:
-        return "Error: Unable to process the document. Please ensure the PDF file is valid."
-
-# Improved Mistral Text Completion
-class MistralModel:
-    def __init__(self):
-        self.model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.dtype = torch.bfloat16
-        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype=self.dtype, device_map=self.device)
-
-    def generate_text(self, prompt, max_length=200):
-        # Improve the prompt for better context
-        enhanced_prompt = f"Question: {prompt}\n\nAnswer: Let's approach this step-by-step:\n1."
-        inputs = self.tokenizer.encode(enhanced_prompt, return_tensors='pt').to(self.model.device)
-
-        # Generate with more nuanced parameters
-        outputs = self.model.generate(
-            inputs,
-            max_length=max_length,
-            num_return_sequences=1,
-            no_repeat_ngram_size=3,
-            top_k=50,
-            top_p=0.95,
-            temperature=0.7,
-            do_sample=True
-        )
-
-        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-mistral_model = MistralModel()
-
-# Combined function for both models
-async def process_input(file, question):
-    gemini_answer = await initialize_gemini(file.name, question)
-    mistral_answer = mistral_model.generate_text(question)
-    return gemini_answer, mistral_answer
-
-# Gradio Interface
-with gr.Blocks() as demo:
-    gr.Markdown("# Enhanced PDF Question Answering and Text Completion System")
-
-    input_file = gr.File(label="Upload PDF File (Optional)")
-    input_question = gr.Textbox(label="Ask a question or provide a prompt")
-    process_button = gr.Button("Process")
-
-    output_text_gemini = gr.Textbox(label="Answer - Gemini (PDF-based if file uploaded)")
-    output_text_mistral = gr.Textbox(label="Answer - Mistral (General knowledge)")
-
-    process_button.click(
-        fn=process_input,
-        inputs=[input_file, input_question],
-        outputs=[output_text_gemini, output_text_mistral]
-    )
-
-demo.launch()
+# Configure Gemini API
+genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
+
+# Load Mistral model
+model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
+mistral_tokenizer = AutoTokenizer.from_pretrained(model_path)
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+dtype = torch.bfloat16
+mistral_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, device_map=device)
+
+def initialize(file_path, question):
+    try:
+        model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.3)
+        prompt_template = """Answer the question as precise as possible using the provided context. If the answer is
+        not contained in the context, say "answer not available in context" \n\n
+        Context: \n {context}?\n
+        Question: \n {question} \n
+        Answer:
+        """
+        prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
+        if os.path.exists(file_path):
+            pdf_loader = PyPDFLoader(file_path)
+            pages = pdf_loader.load_and_split()
+            context = "\n".join(str(page.page_content) for page in pages[:30])
+            stuff_chain = load_qa_chain(model, chain_type="stuff", prompt=prompt)
+            stuff_answer = stuff_chain({"input_documents": pages, "question": question, "context": context}, return_only_outputs=True)
+            gemini_answer = stuff_answer['output_text']
+
+            # Use Mistral model for additional text generation
+            mistral_prompt = f"Based on this answer: {gemini_answer}\nGenerate a follow-up question:"
+            mistral_inputs = mistral_tokenizer.encode(mistral_prompt, return_tensors='pt').to(device)
+            with torch.no_grad():
+                mistral_outputs = mistral_model.generate(mistral_inputs, max_length=50)
+            mistral_output = mistral_tokenizer.decode(mistral_outputs[0], skip_special_tokens=True)
+
+            combined_output = f"Gemini Answer: {gemini_answer}\n\nMistral Follow-up: {mistral_output}"
+            return combined_output
+        else:
+            return "Error: Unable to process the document. Please ensure the PDF file is valid."
+    except Exception as e:
+        return f"An error occurred: {str(e)}"
+
+# Define Gradio Interface
+input_file = gr.File(label="Upload PDF File")
+input_question = gr.Textbox(label="Ask about the document")
+output_text = gr.Textbox(label="Answer - Combined Gemini and Mistral")
+
+def pdf_qa(file, question):
+    if file is None:
+        return "Please upload a PDF file first."
+    return initialize(file.name, question)
+
+# Create Gradio Interface
+gr.Interface(
+    fn=pdf_qa,
+    inputs=[input_file, input_question],
+    outputs=output_text,
+    title="RAG Knowledge Retrieval using Gemini API and Mistral Model",
+    description="Upload a PDF file and ask questions about the content."
+).launch()
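
Note on the new Mistral follow-up step: the added code calls genai.configure and mistral_model.generate(..., max_length=50); in transformers, max_length counts the prompt tokens as well, so a long Gemini answer in the prompt can leave little or no room for the generated follow-up. The sketch below (a hypothetical standalone script, not part of this commit) shows the same follow-up generation in isolation using max_new_tokens, which bounds only the continuation; the placeholder answer string and file name are assumptions for testing.

# sketch: test_followup.py (hypothetical, assumes the model weights are downloadable and fit in memory)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_path = "nvidia/Mistral-NeMo-Minitron-8B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map=device)

# Placeholder for a Gemini answer, used only to exercise the prompt format from app.py
gemini_answer = "The report describes revenue growth in the third quarter."
prompt = f"Based on this answer: {gemini_answer}\nGenerate a follow-up question:"
inputs = tokenizer.encode(prompt, return_tensors='pt').to(device)
with torch.no_grad():
    # max_new_tokens limits only the generated continuation, independent of prompt length
    outputs = model.generate(inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))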