seawolf2357 committed on
Commit
dbd7c99
·
verified ·
1 Parent(s): 4ad9b62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -9,7 +9,13 @@ from accelerate import Accelerator
9
 
10
  hf_api_key = os.getenv('HF_API_KEY')
11
  model_id = "microsoft/phi-2"
 
 
12
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
 
 
 
 
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_id,
15
  token=hf_api_key,
@@ -27,14 +33,13 @@ data = data.add_faiss_index("embeddings")
27
 
28
  def generate(formatted_prompt):
29
  prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
30
- # ν† ν¬λ‚˜μ΄μ§• μ‹œ attention_mask도 ν•¨κ»˜ 생성
31
  encoding = tokenizer(prompt_text, return_tensors="pt", padding="max_length", max_length=512, truncation=True)
32
  input_ids = encoding['input_ids'].to(accelerator.device)
33
  attention_mask = encoding['attention_mask'].to(accelerator.device)
34
 
35
  outputs = model.generate(
36
  input_ids,
37
- attention_mask=attention_mask, # attention_mask 전달
38
  max_new_tokens=1024,
39
  eos_token_id=tokenizer.eos_token_id,
40
  do_sample=True,
@@ -43,7 +48,6 @@ def generate(formatted_prompt):
43
  )
44
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
45
 
46
-
47
  def search(query: str, k: int = 3):
48
  embedded_query = ST.encode(query)
49
  scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
 
9
 
10
  hf_api_key = os.getenv('HF_API_KEY')
11
  model_id = "microsoft/phi-2"
12
+
13
+ # ν† ν¬λ‚˜μ΄μ € 및 λͺ¨λΈ μ„€μ •
14
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
15
+ # 패딩 토큰 설정
16
+ if tokenizer.pad_token is None:
17
+ tokenizer.pad_token = tokenizer.eos_token
18
+
19
  model = AutoModelForCausalLM.from_pretrained(
20
  model_id,
21
  token=hf_api_key,
 
33
 
34
  def generate(formatted_prompt):
35
  prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
 
36
  encoding = tokenizer(prompt_text, return_tensors="pt", padding="max_length", max_length=512, truncation=True)
37
  input_ids = encoding['input_ids'].to(accelerator.device)
38
  attention_mask = encoding['attention_mask'].to(accelerator.device)
39
 
40
  outputs = model.generate(
41
  input_ids,
42
+ attention_mask=attention_mask,
43
  max_new_tokens=1024,
44
  eos_token_id=tokenizer.eos_token_id,
45
  do_sample=True,
 
48
  )
49
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
50
 
 
51
  def search(query: str, k: int = 3):
52
  embedded_query = ST.encode(query)
53
  scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)