seawolf2357 committed (verified)
Commit 9ae4071 · Parent(s): 44a6b17

Update app.py

Files changed (1):
  1. app.py +17 -17
app.py CHANGED
@@ -1,25 +1,27 @@
 import os
-import torch  # import torch
-import faiss
+import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from accelerate import Accelerator
 from sentence_transformers import SentenceTransformer
 from datasets import load_dataset
+import faiss
 import gradio as gr
-from accelerate import Accelerator
 
-# Load the Hugging Face API key from an environment variable
+# Set Hugging Face API key from environment variable
 hf_api_key = os.getenv('HF_API_KEY')
 
-# Set model ID and tokenizer
+# Define model ID
 model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+# Initialize tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
-accelerator = Accelerator()  # create an Accelerator instance
+accelerator = Accelerator()
 
-# Load the model
+# Load the model with custom quantization using BitsAndBytesConfig
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     token=hf_api_key,
-    torch_dtype=torch.bfloat16,  # specify the data type via torch
+    torch_dtype=torch.bfloat16,
     quantization_config=BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_use_double_quant=True,
@@ -27,23 +29,20 @@ model = AutoModelForCausalLM.from_pretrained(
         bnb_4bit_compute_dtype=torch.bfloat16
     )
 )
-model = accelerator.prepare(model)  # prepare the model with Accelerator
+model = accelerator.prepare(model)
 
-# Load data and build a faiss index
+# Load dataset and create FAISS index
 ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
 dataset = load_dataset("not-lain/wikipedia", revision="embedded")
 data = dataset["train"]
 data = data.add_faiss_index("embeddings")
 
-# Search and response-generation functions
+# Define functions for search, prompt formatting, and generation
 def search(query: str, k: int = 3):
     embedded_query = ST.encode(query)
     scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
     return scores, retrieved_examples
 
-# The rest of the code is kept the same as before
-
-
 def format_prompt(prompt, retrieved_documents, k):
     PROMPT = f"Question:{prompt}\nContext:"
     for idx in range(k):
@@ -51,7 +50,7 @@ def format_prompt(prompt, retrieved_documents, k):
     return PROMPT
 
 def generate(formatted_prompt):
-    formatted_prompt = formatted_prompt[:2000]  # to respect GPU memory limits
+    formatted_prompt = formatted_prompt[:2000]  # Limit due to GPU memory constraints
     messages = [{"role": "system", "content": "You are an assistant..."}, {"role": "user", "content": formatted_prompt}]
     input_ids = tokenizer(messages, return_tensors="pt", padding=True).input_ids.to(accelerator.device)
     outputs = model.generate(
@@ -62,16 +61,17 @@ def generate(formatted_prompt):
         temperature=0.6,
         top_p=0.9
     )
-    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
-    return response
+    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
 
 def rag_chatbot_interface(prompt: str, k: int = 2):
     scores, retrieved_documents = search(prompt, k)
     formatted_prompt = format_prompt(prompt, retrieved_documents, k)
     return generate(formatted_prompt)
 
+# Define system prompt for the chatbot
 SYS_PROMPT = "You are an assistant for answering questions. You are given the extracted parts of a long document and a question. Provide a conversational answer. If you don't know the answer, just say 'I do not know.' Don't make up an answer."
 
+# Set up Gradio interface
 iface = gr.Interface(
     fn=rag_chatbot_interface,
     inputs=gr.inputs.Textbox(label="Enter your question"),
 
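Review note: the first hunk elides one line of the BitsAndBytesConfig (old line 26 / new line 28 of app.py). For orientation only, here is a self-contained sketch of a complete 4-bit config; the bnb_4bit_quant_type="nf4" argument is a guess at the elided line (it is the usual companion of double quantization), not something visible in this diff.

import torch
from transformers import BitsAndBytesConfig

# Hypothetical reconstruction; bnb_4bit_quant_type is assumed, not read
# from the commit. The other three arguments appear verbatim in the diff.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit precision
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # assumed value of the elided line
    bnb_4bit_compute_dtype=torch.bfloat16,  # dequantize to bf16 for matmuls
)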
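Review note: both hunk boundaries skip the body of format_prompt's loop (old line 50 / new line 49), so the diff ends at `for idx in range(k):`. A minimal sketch of how that loop typically consumes the dict returned by get_nearest_examples; the "text" column name is an assumption about the not-lain/wikipedia rows, not something shown here.

def format_prompt(prompt, retrieved_documents, k):
    PROMPT = f"Question:{prompt}\nContext:"
    for idx in range(k):
        # get_nearest_examples returns a dict mapping column name -> list
        # of k values; the "text" column name is assumed
        PROMPT += f"\n{retrieved_documents['text'][idx]}"
    return PROMPT

Used together with the committed search function:

scores, docs = search("What is a black hole?", k=2)
prompt = format_prompt("What is a black hole?", docs, k=2)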
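Review note: inside generate, the committed code passes a list of role/content dicts straight to tokenizer(...), which expects strings and will fail at call time; the hard-coded "You are an assistant..." string also ignores the SYS_PROMPT defined below it. A hedged sketch of the conventional fix using the tokenizer's chat template follows; the max_new_tokens value is assumed, since the diff elides the first generation arguments (old lines 58-61).

def generate(formatted_prompt):
    formatted_prompt = formatted_prompt[:2000]  # keep the commit's GPU-memory cap
    # Mixtral-Instruct's stock chat template rejects a separate system role,
    # so the system text is folded into the user turn here (a judgment call,
    # not the committed code).
    messages = [{"role": "user", "content": f"{SYS_PROMPT}\n\n{formatted_prompt}"}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(accelerator.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=512,  # assumed; the diff does not show this argument
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Decode only the newly generated tokens, as in the committed code
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)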
 
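Review note: the diff view is truncated here; the outputs component, any title or description, and the iface.launch() call are not shown. Separately, gr.inputs.Textbox belongs to the legacy Gradio input namespace, which newer Gradio releases removed. A minimal sketch against the current component API, with the output component and launch call assumed since the commit does not show them:

iface = gr.Interface(
    fn=rag_chatbot_interface,
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),  # assumed; not visible in the diff
)

iface.launch()  # assumed default launch; the committed flags are not shown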