not-lain committed
Commit eaca477 • 1 Parent(s): 0b808a5

🌘w🌖

Files changed (2):
  1. app.py +28 -52
  2. requirements.txt +1 -1
app.py CHANGED
@@ -7,7 +7,7 @@ import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import torch
 from threading import Thread
-from ragatouille import RAGPretrainedModel
+from sentence_transformers import SentenceTransformer
 from datasets import load_dataset
 
 
@@ -18,72 +18,48 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.float16,
     token=token,
 )
-tok = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
 device = torch.device("cuda")
 model = model.to(device)
-RAG = RAGPretrainedModel.from_pretrained("mixedbread-ai/mxbai-colbert-v1")
-
+RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+TOP_K = 3
 # prepare data
 # since data is too big we will only select the first 3K lines
 
-dataset = load_dataset(
-    "wikimedia/wikipedia", "20231101.en", split="train", streaming=True
-)
-# init data
-data = Dataset.from_dict({})
-i = 0
-for i, entry in enumerate(dataset):
-    # each entry has the following columns
-    # ['id', 'url', 'title', 'text']
-    data = data.add_item(entry)
-    if i == 3000:
-        break
-# free memory
-del dataset # we keep data
+data = load_dataset("not-lain/wikipedia-small-3000-embedded", subset="train")
 
-# index data
-documents = data["text"]
-RAG.index(documents, index_name="wikipedia", use_faiss=True)
-# free memory
-del documents
+# index dataset
+data.add_faiss_index("embedding", device=1)
 
-def search(query, k: int = 5):
-    results = RAG.search(query, k=k)
-    # results are ordered according to their score
-    # results has the following keys
-    #
-    # {'content' : 'retrieved content'
-    # 'score' : score[float]
-    # 'rank' : "results are sorted using score and each is given a rank, also can be called place, 1 2 3 4 ..."
-    # 'document_id' : "no clue man i just got here"
-    # 'passage_id' : "or original row number"
-    # }
-    #
-    return [result["passage_id"] for result in results]
+@spaces.GPU
+def search(query: str, k: int = TOP_K):
+    embedded_query = model.encode(query)
+    scores, retrieved_examples = data.get_nearest_examples(
+        "embedding", embedded_query, k=k
+    )
+    return retrieved_examples
 
 
-def prepare_prompt(query, indexes, data=data):
+def prepare_prompt(query, retrieved_examples):
     prompt = (
         f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    )
-    titles = []
     urls = []
-    for i in indexes:
-        title = entry["title"][i]
-        text = entry["text"][i]
-        url = entry["url"][i]
-        titles.append(title)
-        urls.append(url)
-        prompt += f"Title: {title}, Text: {text}\n"
-    return prompt, (titles, urls)
+    titles = retrieved_examples["title"][::-1]
+    texts = retrieved_examples["text"][::-1]
+    urls = retrieved_examples["url"][::-1]
+    titles = titles[::-1]
+    for i in range(TOP_K):
+        prompt += f"Title: {titles[i]}, Text: {texts[i]}\n"
+    return prompt, (titles, urls)
 
 
 @spaces.GPU
 def talk(message, history):
-    indexes = search(message)
-    message, metadata = prepare_prompt(message, indexes)
+    retrieved_examples = search(message)
+    message, metadata = prepare_prompt(message, retrieved_examples)
     resources = "\nRESOURCES:\n"
-    for title, url in metadata:
+    for title, url in metadata:
         resources += f"[{title}]({url}), "
     chat = []
     for item in history:
@@ -92,11 +68,11 @@ def talk(message, history):
         cleaned_past = item[1].split("\nRESOURCES:\n")[0]
         chat.append({"role": "assistant", "content": cleaned_past})
     chat.append({"role": "user", "content": message})
-    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+    messages = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
     # Tokenize the messages string
-    model_inputs = tok([messages], return_tensors="pt").to(device)
+    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
     streamer = TextIteratorStreamer(
-        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
     generate_kwargs = dict(
         model_inputs,
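The change above swaps ragatouille's ColBERT index for a SentenceTransformer embedder plus a FAISS index built directly on a pre-embedded dataset. A minimal, self-contained sketch of that retrieval flow, assuming the dataset's `embedding` column was produced with the same mxbai model, loading it with `split="train"`, and building the FAISS index on CPU (the Space pins it to `device=1`):

```python
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Embedding model and pre-embedded Wikipedia subset named in the diff.
retriever = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")

# Index the stored vectors; get_nearest_examples can then query them.
data.add_faiss_index("embedding")

TOP_K = 3

def search(query: str, k: int = TOP_K):
    # Encode the query into the same vector space as the dataset embeddings.
    query_vector = retriever.encode(query)
    scores, retrieved = data.get_nearest_examples("embedding", query_vector, k=k)
    return retrieved

hits = search("printing press")  # hypothetical query
print(hits["title"], hits["url"])
```

Each retrieved example carries the `title`, `text`, and `url` columns that `prepare_prompt` interpolates into the final prompt and the RESOURCES footer.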
 
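The generation path is unchanged apart from the `tok` → `tokenizer` rename: the chat history is templated, tokenized, and generated in a background thread while a `TextIteratorStreamer` yields partial text. A sketch of that pattern in isolation, assuming CUDA when available, omitting the HF token handling for the gated Gemma checkpoint, and picking a `max_new_tokens` value since the hunk cuts off before the remaining `generate_kwargs`:

```python
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it", torch_dtype=torch.float16).to(device)

chat = [{"role": "user", "content": "Query: ...\nContinue to answer the query by using the Search Results:\n..."}]
messages = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([messages], return_tensors="pt").to(device)

# Generation runs in a worker thread; the streamer is iterated on the main thread.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)  # max_new_tokens is an assumed value
Thread(target=model.generate, kwargs=generate_kwargs).start()

partial_text = ""
for new_text in streamer:
    partial_text += new_text  # in the app this partial text would be yielded back to the chat UI
print(partial_text)
```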
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 spaces
 torch==2.2.0
 transformers
-ragatouille
+sentence-transformers
 faiss-gpu
 datasets
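A quick local sanity check that the updated dependency set imports cleanly (a sketch; the `spaces` package is skipped here since it is only needed on Space hardware):

```python
# Import check for the swapped dependency: sentence-transformers in, ragatouille out.
import datasets
import faiss  # provided by the faiss-gpu package
import torch
import transformers
from sentence_transformers import SentenceTransformer

print("torch", torch.__version__, "| transformers", transformers.__version__, "| faiss OK")
```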