seawolf2357 committed on
Commit 1b6e08f Β· verified Β· 1 Parent(s): 396d718

Update app.py

Files changed (1)
  1. app.py +52 -40
app.py CHANGED
@@ -1,54 +1,66 @@
  from sentence_transformers import SentenceTransformer
- from datasets import load_dataset
- import gradio as gr
-
- ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
-
- dataset = load_dataset("not-lain/wikipedia", revision="embedded")
  data = dataset["train"]
- data = data.add_faiss_index("embeddings")  # column name that has the embeddings of the dataset
-
- def search(query: str, k: int = 3):
-     """a function that embeds a new query and returns the most probable results"""
-     embedded_query = ST.encode(query)  # embed new query
-     scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
-         "embeddings", embedded_query,  # compare our new embedded query with the dataset embeddings
-         k=k  # get only top k results
      )
      return scores, retrieved_examples

- def format_prompt(prompt, retrieved_documents, k):
-     """using the retrieved documents we will prompt the model to generate our responses"""
-     PROMPT = f"Question:{prompt}\nContext:"
-     for idx in range(k):
-         PROMPT += f"{retrieved_documents['text'][idx]}\n"
-     return PROMPT

  def generate(formatted_prompt):
-     formatted_prompt = formatted_prompt[:2000]  # to avoid GPU OOM
-     messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
-     # tell the model to generate
-     input_ids = tokenizer.apply_chat_template(
-         messages,
-         add_generation_prompt=True,
-         return_tensors="pt"
-     ).to(model.device)
-     outputs = model.generate(
-         input_ids,
-         max_new_tokens=1024,
-         eos_token_id=terminators,
-         do_sample=True,
-         temperature=0.6,
-         top_p=0.9,
-     )
-     response = outputs[0][input_ids.shape[-1]:]
-     return tokenizer.decode(response, skip_special_tokens=True)

- def rag_chatbot(prompt: str, k: int = 2):
-     scores, retrieved_documents = search(prompt, k)
-     formatted_prompt = format_prompt(prompt, retrieved_documents, k)
-     return generate(formatted_prompt)

  def rag_chatbot_interface(prompt: str, k: int = 2):
      scores, retrieved_documents = search(prompt, k)
 
  from sentence_transformers import SentenceTransformer
+ from datasets import load_dataset, Dataset
+ import faiss  # import faiss explicitly if needed
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ import torch
+
+ # Model and tokenizer setup
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     quantization_config=BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+ )
+
+ # Load the data and build the faiss index
+ ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+ dataset = load_dataset("not-lain/wikipedia", revision="embedded")
  data = dataset["train"]
+ data = data.add_faiss_index("embeddings")
+
+ # Search and response-generation functions
+ def search(query: str, k: int = 3):
+     embedded_query = ST.encode(query)
+     scores, retrieved_examples = data.get_nearest_examples(
+         "embeddings", embedded_query, k=k
      )
      return scores, retrieved_examples

+ def format_prompt(prompt, retrieved_documents, k):
+     PROMPT = f"Question:{prompt}\nContext:"
+     for idx in range(k):
+         PROMPT += f"{retrieved_documents['text'][idx]}\n"
+     return PROMPT

  def generate(formatted_prompt):
+     formatted_prompt = formatted_prompt[:2000]  # truncate to stay within GPU memory limits
+     messages = [{"role": "system", "content": "You are an assistant..."}, {"role": "user", "content": formatted_prompt}]
+     # render the chat messages with the model's chat template; a plain tokenizer()
+     # call cannot take a list of message dicts
+     input_ids = tokenizer.apply_chat_template(
+         messages, add_generation_prompt=True, return_tensors="pt"
+     ).to(model.device)
+     outputs = model.generate(
+         input_ids,
+         max_new_tokens=1024,
+         eos_token_id=[tokenizer.eos_token_id],
+         do_sample=True,
+         temperature=0.6,
+         top_p=0.9
+     )
+     response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
+     return response
+
+ def rag_chatbot(prompt: str, k: int = 2):
+     scores, retrieved_documents = search(prompt, k)
+     formatted_prompt = format_prompt(prompt, retrieved_documents, k)
+     return generate(formatted_prompt)
+
+ rag_chatbot("What is anarchy?", k=2)

  def rag_chatbot_interface(prompt: str, k: int = 2):
      scores, retrieved_documents = search(prompt, k)
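
The diff is cut off inside rag_chatbot_interface. A minimal sketch of how that wrapper and its UI wiring might look, assuming it mirrors rag_chatbot and that gradio is still a dependency (the previous version imported it as gr); the function body and widget choices below are illustrative assumptions, not part of this commit:

import gradio as gr  # assumption: gradio remains installed

def rag_chatbot_interface(prompt: str, k: int = 2):
    # assumed to mirror rag_chatbot(), returning the generated answer for display
    k = int(k)  # Gradio sliders deliver numbers as floats
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    return generate(formatted_prompt)

# hypothetical wiring: a textbox for the question, a slider for top-k
demo = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=[gr.Textbox(label="Question"), gr.Slider(1, 10, step=1, value=2, label="k")],
    outputs=gr.Textbox(label="Answer"),
)
demo.launch()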