seawolf2357 committed on
Commit
44a6b17
·
verified ·
1 Parent(s): 4cc10ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -1,20 +1,25 @@
1
  import os
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
- from accelerate import Accelerator # Accelerate๋ฅผ ๋ณ„๋„๋กœ ์ž„ํฌํŠธ
4
  from sentence_transformers import SentenceTransformer
5
  from datasets import load_dataset
6
- import faiss
7
  import gradio as gr
 
8
 
9
- hf_api_key = os.getenv('HF_API_KEY') # ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ API ํ‚ค ๋กœ๋“œ
 
10
 
 
11
  model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
12
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
13
  accelerator = Accelerator() # Accelerator ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ
 
 
14
  model = AutoModelForCausalLM.from_pretrained(
15
  model_id,
16
  token=hf_api_key,
17
- torch_dtype=torch.bfloat16,
18
  quantization_config=BitsAndBytesConfig(
19
  load_in_4bit=True,
20
  bnb_4bit_use_double_quant=True,
@@ -24,16 +29,21 @@ model = AutoModelForCausalLM.from_pretrained(
24
  )
25
  model = accelerator.prepare(model) # ๋ชจ๋ธ์„ Accelerator์— ์ค€๋น„์‹œํ‚ด
26
 
 
27
  ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
28
  dataset = load_dataset("not-lain/wikipedia", revision="embedded")
29
  data = dataset["train"]
30
  data = data.add_faiss_index("embeddings")
31
 
 
32
  def search(query: str, k: int = 3):
33
  embedded_query = ST.encode(query)
34
  scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
35
  return scores, retrieved_examples
36
 
 
 
 
37
  def format_prompt(prompt, retrieved_documents, k):
38
  PROMPT = f"Question:{prompt}\nContext:"
39
  for idx in range(k):
 
1
  import os
2
+ import torch # torch๋ฅผ ์ž„ํฌํŠธ
3
+ import faiss
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
5
  from sentence_transformers import SentenceTransformer
6
  from datasets import load_dataset
 
7
  import gradio as gr
8
+ from accelerate import Accelerator
9
 
10
+ # ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ Hugging Face API ํ‚ค ๋กœ๋“œ
11
+ hf_api_key = os.getenv('HF_API_KEY')
12
 
13
+ # ๋ชจ๋ธ ID ๋ฐ ํ† ํฌ๋‚˜์ด์ € ์„ค์ •
14
  model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
15
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
16
  accelerator = Accelerator() # Accelerator ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ
17
+
18
+ # ๋ชจ๋ธ ๋กœ๋”ฉ
19
  model = AutoModelForCausalLM.from_pretrained(
20
  model_id,
21
  token=hf_api_key,
22
+ torch_dtype=torch.bfloat16, # torch๋ฅผ ์‚ฌ์šฉํ•ด ๋ฐ์ดํ„ฐ ํƒ€์ž… ์ง€์ •
23
  quantization_config=BitsAndBytesConfig(
24
  load_in_4bit=True,
25
  bnb_4bit_use_double_quant=True,
 
29
  )
30
  model = accelerator.prepare(model) # ๋ชจ๋ธ์„ Accelerator์— ์ค€๋น„์‹œํ‚ด
31
 
32
+ # ๋ฐ์ดํ„ฐ ๋กœ๋”ฉ ๋ฐ faiss ์ธ๋ฑ์Šค ์ƒ์„ฑ
33
  ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
34
  dataset = load_dataset("not-lain/wikipedia", revision="embedded")
35
  data = dataset["train"]
36
  data = data.add_faiss_index("embeddings")
37
 
38
+ # ๊ฒ€์ƒ‰ ๋ฐ ์‘๋‹ต ์ƒ์„ฑ ํ•จ์ˆ˜
39
  def search(query: str, k: int = 3):
40
  embedded_query = ST.encode(query)
41
  scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
42
  return scores, retrieved_examples
43
 
44
+ # ๋‚˜๋จธ์ง€ ์ฝ”๋“œ๋Š” ์ด์ „๊ณผ ๋™์ผํ•˜๊ฒŒ ์œ ์ง€
45
+
46
+
47
  def format_prompt(prompt, retrieved_documents, k):
48
  PROMPT = f"Question:{prompt}\nContext:"
49
  for idx in range(k):