seawolf2357 committed
Commit 953debe · verified · 1 Parent(s): ccfa30a

Create app.py

Files changed (1)
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
+ from sentence_transformers import SentenceTransformer
+ from datasets import load_dataset
+
+ ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+
+ dataset = load_dataset("not-lain/wikipedia", revision="embedded")
+
+ data = dataset["train"]
+ data = data.add_faiss_index("embeddings")  # column name that has the embeddings of the dataset
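+ # the FAISS index searches the stored vectors directly, so incoming queries must be
+ # encoded with the same embedding model that produced the "embeddings" column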
+
+ def search(query: str, k: int = 3):
+     """Embed a new query and return the k nearest dataset examples with their scores."""
+     embedded_query = ST.encode(query)  # embed new query
+     scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
+         "embeddings", embedded_query,  # compare our new embedded query with the dataset embeddings
+         k=k  # get only top k results
+     )
+     return scores, retrieved_examples
+
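+ # generation side: a quantized Llama 3 8B Instruct model answers from the retrieved context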
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ import torch
+
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+ # use quantization to lower GPU usage
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+ )
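+ # note: nf4 stores the weights in 4 bits and de-quantizes them to bfloat16 for compute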
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     quantization_config=bnb_config,
+ )
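+ # Llama 3 chat models end their turns with <|eot_id|>, so stop on it in addition to the EOS token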
+ terminators = [
+     tokenizer.eos_token_id,
+     tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+ ]
+
+ SYS_PROMPT = """You are an assistant for answering questions.
+ You are given the extracted parts of a long document and a question. Provide a conversational answer.
+ If you don't know the answer, just say "I do not know." Don't make up an answer."""
+
+ def format_prompt(prompt, retrieved_documents, k):
+     """Build the user prompt: the question followed by the k retrieved documents as context."""
+     PROMPT = f"Question:{prompt}\nContext:"
+     for idx in range(k):
+         PROMPT += f"{retrieved_documents['text'][idx]}\n"
+     return PROMPT
+
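+ # generation: format the messages with the tokenizer's chat template and decode only the newly generated tokens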
+ def generate(formatted_prompt):
+     formatted_prompt = formatted_prompt[:2000]  # to avoid GPU OOM
+     messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
+     # tell the model to generate
+     input_ids = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt",
+     ).to(model.device)
+     outputs = model.generate(
+         input_ids,
+         max_new_tokens=1024,
+         eos_token_id=terminators,
+         do_sample=True,
+         temperature=0.6,
+         top_p=0.9,
+     )
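+     # keep only the tokens generated after the prompt (the first input_ids.shape[-1] positions)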
+     response = outputs[0][input_ids.shape[-1]:]
+     return tokenizer.decode(response, skip_special_tokens=True)
+
+ def rag_chatbot(prompt: str, k: int = 2):
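+     """Retrieve k documents for the prompt and generate an answer grounded in them."""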
+     scores, retrieved_documents = search(prompt, k)
+     formatted_prompt = format_prompt(prompt, retrieved_documents, k)
+     return generate(formatted_prompt)
+
+ rag_chatbot("what's anarchy ?", k=2)
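+ # example run: retrieve two Wikipedia passages about anarchy and answer from them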