seawolf2357 committed on
Commit
3d962a1
·
verified ·
1 Parent(s): 363bbc4

Update app.py

Files changed (1)
  1. app.py +18 -38
app.py CHANGED
@@ -10,40 +10,33 @@ from accelerate import Accelerator
 # Load the Hugging Face API key from environment variables
 hf_api_key = os.getenv('HF_API_KEY')
 
-# Model ID and tokenizer setup
-# Model ID
+# Model ID setup
 model_id = "microsoft/phi-2"
 
-# Configure the model to trust and run custom code
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    trust_remote_code=True  # allow execution of custom code
-)
-tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
-accelerator = Accelerator()
-
-# Load the model without quantization settings (temporary workaround)
+# Load the model and tokenizer, allowing execution of custom code
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     token=hf_api_key,
+    trust_remote_code=True,  # allow execution of custom code
    torch_dtype=torch.float32  # use the default dtype
 )
+
+# Set up the Accelerator
+accelerator = Accelerator()
 model = accelerator.prepare(model)
 
-# Load the data and build the faiss index
+# Load the dataset and FAISS index
 ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
 dataset = load_dataset("not-lain/wikipedia", revision="embedded")
 data = dataset["train"]
 data = data.add_faiss_index("embeddings")
 
-# Other functions and the Gradio interface are unchanged from before
-
-
-# Define functions for search, prompt formatting, and generation
+# Search, prompt formatting, and generation functions
 def search(query: str, k: int = 3):
     embedded_query = ST.encode(query)
     scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
     return scores, retrieved_examples
 
 def format_prompt(prompt, retrieved_documents, k):
     PROMPT = f"Question:{prompt}\nContext:"
@@ -52,39 +45,26 @@ def format_prompt(prompt, retrieved_documents, k):
     return PROMPT
 
 def generate(formatted_prompt):
-    # Combine the prompt into a single string
     prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
-    # Tokenize
     input_ids = tokenizer(prompt_text, return_tensors="pt", padding=True).input_ids.to(accelerator.device)
-    # Generate the response
-    outputs = model.generate(
-        input_ids,
-        max_new_tokens=1024,
-        eos_token_id=tokenizer.eos_token_id,
-        do_sample=True,
-        temperature=0.6,
-        top_p=0.9
-    )
-    # Decode the response to text
-    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
-    return response
-
+    outputs = model.generate(input_ids, max_new_tokens=1024, eos_token_id=tokenizer.eos_token_id, do_sample=True, temperature=0.6, top_p=0.9)
+    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
 
 def rag_chatbot_interface(prompt: str, k: int = 2):
     scores, retrieved_documents = search(prompt, k)
     formatted_prompt = format_prompt(prompt, retrieved_documents, k)
     return generate(formatted_prompt)
 
-# Define the system prompt for the chatbot
-SYS_PROMPT = "You are an assistant for answering questions. You are given the extracted parts of a long document and a question. Provide a conversational answer. If you don't know the answer, just say 'I do not know.' Don't make up an answer."
-
+# System prompt
+SYS_PROMPT = "You are an assistant for answering questions. Provide a conversational answer."
 
+# Gradio interface
 iface = gr.Interface(
     fn=rag_chatbot_interface,
-    inputs="text",  # text input
-    outputs="text",  # text output
+    inputs="text",
+    outputs="text",
     title="Retrieval-Augmented Generation Chatbot",
-    description="This is a chatbot that uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
+    description="This chatbot provides more accurate answers by searching relevant documents and generating responses."
 )
 
 iface.launch()
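
Note: the padding=True call in generate() assumes the tokenizer defines a pad token, which microsoft/phi-2's tokenizer does not ship by default; without one, Transformers raises "Asking to pad but the tokenizer does not have a padding token." A minimal sketch of the common workaround, plus a local smoke test, assuming the definitions in app.py above (the query string is illustrative):

    # Reuse the EOS token for padding; phi-2 defines no pad token by default.
    tokenizer.pad_token = tokenizer.eos_token

    # Hypothetical smoke test of the retrieve-and-generate path before iface.launch().
    print(rag_chatbot_interface("What is retrieval-augmented generation?", k=2))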