arif670 commited on
Commit
2795690
·
verified ·
1 Parent(s): d25c039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -47
app.py CHANGED
@@ -1,64 +1,45 @@
1
- import os
2
- import json
3
- import firebase_admin
4
- from firebase_admin import credentials, db
5
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
6
- from transformers import RagRetriever
7
 
 
 
 
 
8
  retriever = RagRetriever.from_pretrained(
9
  "facebook/rag-token-base",
10
  use_dummy_dataset=True,
11
  trust_remote_code=True
12
  )
13
 
14
- import gradio as gr
15
-
16
- # Initialize Firebase Admin SDK
17
- firebase_credential = os.getenv("FIREBASE_CREDENTIALS")
18
- if not firebase_credential:
19
- raise RuntimeError("FIREBASE_CREDENTIALS environment variable is not set.")
20
-
21
- # Save Firebase credentials to a temporary file
22
- with open("serviceAccountKey.json", "w") as f:
23
- f.write(firebase_credential)
24
-
25
- # Initialize Firebase App
26
- cred = credentials.Certificate("serviceAccountKey.json")
27
- firebase_admin.initialize_app(cred, {"databaseURL": "https://your-database-name.firebaseio.com/"})
28
-
29
- # Load the RAG model, tokenizer, and retriever
30
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
31
- retriever = RagRetriever.from_pretrained("facebook/rag-token-base", use_dummy_dataset=True, trust_remote_code=True)
 
32
  model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base")
33
 
34
- # Function to generate answers using the RAG model
35
- def generate_answer(question, context=""):
36
- # Tokenize the question and context
37
  inputs = tokenizer(question, return_tensors="pt")
38
 
39
- # Retrieve relevant documents (dummy dataset for this example)
40
- # In a real-world case, you would provide a proper knowledge base or corpus
41
- retrieved_docs = retriever(question=question, input_ids=inputs["input_ids"])
42
 
43
- # Generate the answer using the RAG model
44
- outputs = model.generate(input_ids=inputs["input_ids"],
45
- context_input_ids=retrieved_docs["context_input_ids"])
 
 
46
 
47
- # Decode the generated answer
48
- answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
- return answer
50
-
51
- # Gradio interface function
52
- def dashboard(question):
53
- # Generate the answer from the RAG model
54
- answer = generate_answer(question)
55
  return answer
56
 
57
- # Gradio Interface Setup
58
- interface = gr.Interface(fn=dashboard, inputs="text", outputs="text")
59
-
60
- # Launch the Gradio app
61
  if __name__ == "__main__":
62
- interface.launch()
63
-
64
-
 
 
 
 
1
+ import torch
2
+ from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration
3
+ from datasets import load_dataset
 
 
 
4
 
5
+ # Step 1: Load the dataset with the trust_remote_code flag enabled
6
+ dataset = load_dataset("wiki_dpr", trust_remote_code=True)
7
+
8
+ # Step 2: Load the retriever using the pre-trained model, with use_dummy_dataset=True and trust_remote_code=True
9
  retriever = RagRetriever.from_pretrained(
10
  "facebook/rag-token-base",
11
  use_dummy_dataset=True,
12
  trust_remote_code=True
13
  )
14
 
15
+ # Step 3: Load the tokenizer for the RAG model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
17
+
18
+ # Step 4: Initialize the RAG model
19
  model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-base")
20
 
21
+ # Step 5: Define a function to generate an answer using the retriever and model
22
+ def generate_answer(question):
23
+ # Tokenize the question
24
  inputs = tokenizer(question, return_tensors="pt")
25
 
26
+ # Retrieve relevant documents using the retriever
27
+ input_ids = inputs["input_ids"]
28
+ retrieved_doc_ids = retriever.retrieve(input_ids)
29
 
30
+ # Use the model to generate an answer based on the retrieved documents
31
+ generated_ids = model.generate(input_ids, context_input_ids=retrieved_doc_ids["context_input_ids"])
32
+
33
+ # Decode the generated answer back to text
34
+ answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
35
 
 
 
 
 
 
 
 
 
36
  return answer
37
 
38
+ # Step 6: Example usage
 
 
 
39
  if __name__ == "__main__":
40
+ question = "Who was the first president of the United States?"
41
+ print(f"Question: {question}")
42
+
43
+ # Generate and print the answer
44
+ answer = generate_answer(question)
45
+ print(f"Answer: {answer}")