seawolf2357 committed on
Commit
12218a1
·
verified ·
1 Parent(s): 3d962a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -10
app.py CHANGED
@@ -10,16 +10,14 @@ from accelerate import Accelerator
10
  # ν™˜κ²½ λ³€μˆ˜μ—μ„œ Hugging Face API ν‚€ λ‘œλ“œ
11
  hf_api_key = os.getenv('HF_API_KEY')
12
 
13
- # λͺ¨λΈ ID μ„€μ •
14
  model_id = "microsoft/phi-2"
15
-
16
- # μ‚¬μš©μž μ •μ˜ μ½”λ“œ μ‹€ν–‰ ν—ˆμš©κ³Ό ν•¨κ»˜ λͺ¨λΈ 및 ν† ν¬λ‚˜μ΄μ € λ‘œλ“œ
17
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
18
  model = AutoModelForCausalLM.from_pretrained(
19
  model_id,
20
  token=hf_api_key,
21
- trust_remote_code=True, # 사용자 정의 코드 실행 허용
22
- torch_dtype=torch.float32 # 기본 dtype 사용
23
  )
24
 
25
  # Accelerator 설정
@@ -32,11 +30,10 @@ dataset = load_dataset("not-lain/wikipedia", revision="embedded")
32
  data = dataset["train"]
33
  data = data.add_faiss_index("embeddings")
34
 
35
- # 검색, 프롬프트 포맷팅, 응답 생성 함수
36
  def search(query: str, k: int = 3):
37
  embedded_query = ST.encode(query)
38
  scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
39
- return scores, retrieved_documents
40
 
41
  def format_prompt(prompt, retrieved_documents, k):
42
  PROMPT = f"Question:{prompt}\nContext:"
@@ -55,16 +52,15 @@ def rag_chatbot_interface(prompt: str, k: int = 2):
55
  formatted_prompt = format_prompt(prompt, retrieved_documents, k)
56
  return generate(formatted_prompt)
57
 
58
- # 시스템 프롬프트
59
  SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."
60
 
61
- # Gradio μΈν„°νŽ˜μ΄μŠ€
62
  iface = gr.Interface(
63
  fn=rag_chatbot_interface,
64
  inputs="text",
65
  outputs="text",
66
  title="Retrieval-Augmented Generation Chatbot",
67
- description="This chatbot provides more accurate answers by searching relevant documents and generating responses."
 
68
  )
69
 
70
  iface.launch()
 
10
  # ν™˜κ²½ λ³€μˆ˜μ—μ„œ Hugging Face API ν‚€ λ‘œλ“œ
11
  hf_api_key = os.getenv('HF_API_KEY')
12
 
13
+ # λͺ¨λΈ ID 및 ν† ν¬λ‚˜μ΄μ € μ„€μ •
14
  model_id = "microsoft/phi-2"
 
 
15
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
16
  model = AutoModelForCausalLM.from_pretrained(
17
  model_id,
18
  token=hf_api_key,
19
+ trust_remote_code=True,
20
+ torch_dtype=torch.float32
21
  )
22
 
23
  # Accelerator 설정
 
30
  data = dataset["train"]
31
  data = data.add_faiss_index("embeddings")
32
 
 
33
  def search(query: str, k: int = 3):
34
  embedded_query = ST.encode(query)
35
  scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
36
+ return scores, retrieved_examples
37
 
38
  def format_prompt(prompt, retrieved_documents, k):
39
  PROMPT = f"Question:{prompt}\nContext:"
 
52
  formatted_prompt = format_prompt(prompt, retrieved_documents, k)
53
  return generate(formatted_prompt)
54
 
 
55
  SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."
56
 
 
57
  iface = gr.Interface(
58
  fn=rag_chatbot_interface,
59
  inputs="text",
60
  outputs="text",
61
  title="Retrieval-Augmented Generation Chatbot",
62
+ description="This chatbot provides more accurate answers by searching relevant documents and generating responses.",
63
+ share=True # 공개 링크 생성
64
  )
65
 
66
  iface.launch()