seawolf2357 committed on
Commit
d2de08e
Β·
verified Β·
1 Parent(s): 836a3ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -15
app.py CHANGED
@@ -1,42 +1,37 @@
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from accelerate import Accelerator
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import gradio as gr

# Hugging Face API key comes from the environment — never hard-code it.
# NOTE(review): os.getenv returns None when HF_API_KEY is unset; the
# from_pretrained calls below would then run unauthenticated.
hf_api_key = os.getenv('HF_API_KEY')

# Gated model — requires an authorized HF token to download.
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
accelerator = Accelerator()

# Load the model with 4-bit NF4 quantization (double quantization enabled,
# bfloat16 compute) to fit Mixtral-8x7B into limited GPU memory.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
)
model = accelerator.prepare(model)

# Sentence-embedding model plus a Wikipedia dataset that ships with
# precomputed embeddings; add_faiss_index builds the search index over them.
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")
 
 
 
 
40
  # Define functions for search, prompt formatting, and generation
41
  def search(query: str, k: int = 3):
42
  embedded_query = ST.encode(query)
 
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import faiss
import gradio as gr
from accelerate import Accelerator

# Load the Hugging Face API key from the environment.
# NOTE(review): os.getenv returns None when HF_API_KEY is unset; the
# from_pretrained calls below would then run unauthenticated.
hf_api_key = os.getenv('HF_API_KEY')

# Model ID and tokenizer setup.
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
accelerator = Accelerator()

# Load the model WITHOUT quantization — temporary workaround per this commit.
# NOTE(review): full-precision float32 for Mixtral-8x7B needs ~180 GB of
# memory; confirm the deployment target can hold it, or restore bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=hf_api_key,
    torch_dtype=torch.float32  # use the default dtype
)
model = accelerator.prepare(model)

# Load the dataset and build the FAISS index over its precomputed embeddings.
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia", revision="embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings")

# Remaining functions and the Gradio interface are unchanged from before.
35
  # Define functions for search, prompt formatting, and generation
36
  def search(query: str, k: int = 3):
37
  embedded_query = ST.encode(query)