seawolf2357 committed on
Commit
4ad9b62
·
verified ·
1 Parent(s): d726220

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -0
app.py CHANGED
@@ -25,6 +25,25 @@ dataset = load_dataset("not-lain/wikipedia", revision="embedded")
25
  data = dataset["train"]
26
  data = data.add_faiss_index("embeddings")
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def search(query: str, k: int = 3):
29
  embedded_query = ST.encode(query)
30
  scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
 
25
  data = dataset["train"]
26
  data = data.add_faiss_index("embeddings")
27
 
28
+ def generate(formatted_prompt):  # Prepend SYS_PROMPT to the prompt and sample a completion from the model.
29
+ prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
30
+ # Tokenize and also produce the attention_mask; the input is padded/truncated to exactly 512 tokens.
31
+ encoding = tokenizer(prompt_text, return_tensors="pt", padding="max_length", max_length=512, truncation=True)  # NOTE(review): for a single prompt, max_length padding may be unnecessary and its effect depends on the tokenizer's padding side - confirm
32
+ input_ids = encoding['input_ids'].to(accelerator.device)
33
+ attention_mask = encoding['attention_mask'].to(accelerator.device)
34
+
35
+ outputs = model.generate(
36
+ input_ids,
37
+ attention_mask=attention_mask, # pass the attention_mask so padded positions are identified
38
+ max_new_tokens=1024,
39
+ eos_token_id=tokenizer.eos_token_id,
40
+ do_sample=True,
41
+ temperature=0.6,
42
+ top_p=0.9
43
+ )
44
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)  # decodes the first returned sequence; NOTE(review): presumably includes the prompt text as well - confirm
45
+
46
+
47
  def search(query: str, k: int = 3):
48
  embedded_query = ST.encode(query)
49
  scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)