seawolf2357 committed on
Commit 4cc10ce · verified · 1 Parent(s): 2d84b3b

Update app.py

Files changed (1)
  1. app.py +12 -16
app.py CHANGED
@@ -1,37 +1,34 @@
 import os
-import torch
-import faiss
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, Accelerate
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from accelerate import Accelerator  # import Accelerator separately
 from sentence_transformers import SentenceTransformer
 from datasets import load_dataset
+import faiss
 import gradio as gr
 
-# Load the Hugging Face API key from an environment variable
-hf_api_key = os.getenv('HF_API_KEY')
+hf_api_key = os.getenv('HF_API_KEY')  # load the API key from an environment variable
 
-# Model ID and tokenizer setup
 model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
+accelerator = Accelerator()  # create an Accelerator instance
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     token=hf_api_key,
     torch_dtype=torch.bfloat16,
-    device_map="auto",
     quantization_config=BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_quant_type="nf4",
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
         bnb_4bit_compute_dtype=torch.bfloat16
     )
 )
+model = accelerator.prepare(model)  # hand the model to the Accelerator
 
-# Load the data and build the FAISS index
 ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
 dataset = load_dataset("not-lain/wikipedia", revision="embedded")
 data = dataset["train"]
 data = data.add_faiss_index("embeddings")
 
-# Search and response-generation functions
 def search(query: str, k: int = 3):
     embedded_query = ST.encode(query)
     scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
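
Review note on this hunk: the old `from transformers import …, Accelerate` was itself broken (transformers does not export `Accelerate`), which the switch to `from accelerate import Accelerator` fixes. However, the new revision also drops `import torch` while still referencing `torch.bfloat16` twice, so the file would fail with a NameError on import. A minimal sketch of the import block the new version still appears to need:

```python
import os

import faiss
import gradio as gr
import torch  # still required: torch.bfloat16 is referenced in the model setup
from accelerate import Accelerator
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
```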
@@ -45,8 +42,8 @@ def format_prompt(prompt, retrieved_documents, k):
 
 def generate(formatted_prompt):
     formatted_prompt = formatted_prompt[:2000]  # account for the GPU memory limit
-    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
-    input_ids = tokenizer(messages, return_tensors="pt", padding=True).input_ids.to(model.device)
+    messages = [{"role": "system", "content": "You are an assistant..."}, {"role": "user", "content": formatted_prompt}]
+    input_ids = tokenizer(messages, return_tensors="pt", padding=True).input_ids.to(accelerator.device)
     outputs = model.generate(
         input_ids,
         max_new_tokens=1024,
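
Review note: `tokenizer(...)` expects plain strings, so calling it on a list of role/content dicts will not produce usable chat input. A hedged sketch of the usual alternative on recent transformers releases, `tokenizer.apply_chat_template` (assuming the chat template rejects a separate system role, as Mixtral-Instruct's historically did, the system prompt is folded into the user turn):

```python
# Sketch only: render the chat messages through the model's chat template
# instead of passing a list of dicts to tokenizer().
messages = [
    # Folding the system prompt into the user turn, since Mixtral-Instruct's
    # template expects strictly alternating user/assistant roles.
    {"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt},
]
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant-turn marker
    return_tensors="pt",
).to(accelerator.device)
```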
@@ -65,13 +62,12 @@ def rag_chatbot_interface(prompt: str, k: int = 2):
 
 SYS_PROMPT = "You are an assistant for answering questions. You are given the extracted parts of a long document and a question. Provide a conversational answer. If you don't know the answer, just say 'I do not know.' Don't make up an answer."
 
-# Gradio interface setup
 iface = gr.Interface(
     fn=rag_chatbot_interface,
     inputs=gr.inputs.Textbox(label="Enter your question"),
     outputs=gr.outputs.Textbox(label="Answer"),
     title="Retrieval-Augmented Generation Chatbot",
-    description="This is a chatbot that uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
+    description="This chatbot uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
 )
 
 iface.launch()
 
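Review note: `gr.inputs.Textbox` and `gr.outputs.Textbox` belong to the deprecated namespace that was removed in Gradio 4.x; on current releases the components are constructed directly. A minimal sketch of the same interface:

```python
# Sketch only: the same interface with current Gradio component constructors.
iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Retrieval-Augmented Generation Chatbot",
    description="This chatbot uses a retrieval-augmented generation approach to provide more accurate answers.",
)
iface.launch()
```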