seawolf2357 committed
Commit 9918198 Β· verified Β· 1 Parent(s): 1f97769

Update app.py

Files changed (1)
  app.py +18 -42
app.py CHANGED
@@ -1,11 +1,13 @@
+ import os
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  from sentence_transformers import SentenceTransformer
  from datasets import load_dataset, Dataset
- import faiss  # Import faiss if needed.
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+ import faiss
  import torch
- import os
+ import gradio as gr

- tokenkey=os.getenv('HF_API_KEY')
+ # Set the Hugging Face API key environment variable
+ os.environ['HF_API_KEY'] = os.getenv('HF_API_KEY')

  # Model and tokenizer setup
  model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
@@ -31,9 +33,7 @@ data = data.add_faiss_index("embeddings")
  # Search and response generation functions
  def search(query: str, k: int = 3):
      embedded_query = ST.encode(query)
-     scores, retrieved_examples = data.get_nearest_examples(
-         "embeddings", embedded_query, k=k
-     )
+     scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=k)
      return scores, retrieved_examples

  def format_prompt(prompt, retrieved_documents, k):
@@ -44,12 +44,12 @@ def format_prompt(prompt, retrieved_documents, k):

  def generate(formatted_prompt):
      formatted_prompt = formatted_prompt[:2000]  # Account for the GPU memory limit
-     messages = [{"role": "system", "content": "You are an assistant..."}, {"role": "user", "content": formatted_prompt}]
+     messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
      input_ids = tokenizer(messages, return_tensors="pt", padding=True).input_ids.to(model.device)
      outputs = model.generate(
          input_ids,
          max_new_tokens=1024,
-         eos_token_id=[tokenizer.eos_token_id],
+         eos_token_id=tokenizer.eos_token_id,
          do_sample=True,
          temperature=0.6,
          top_p=0.9
@@ -57,42 +57,18 @@ def generate(formatted_prompt):
      response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
      return response

- def rag_chatbot(prompt: str, k: int = 2):
+ def rag_chatbot_interface(prompt: str, k: int = 2):
      scores, retrieved_documents = search(prompt, k)
      formatted_prompt = format_prompt(prompt, retrieved_documents, k)
      return generate(formatted_prompt)

- rag_chatbot("What is anarchy?", k=2)
-
- def rag_chatbot_interface(prompt: str, k: int = 2):
-     scores, retrieved_documents = search(prompt, k)
-     formatted_prompt = format_prompt(prompt, retrieved_documents, k)
-     return generate(formatted_prompt)
-
- SYS_PROMPT = """You are an assistant for answering questions.
- You are given the extracted parts of a long document and a question. Provide a conversational answer.
- If you don't know the answer, just say "I do not know." Don't make up an answer."""
-
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
-     quantization_config=bnb_config
- )
- terminators = [
-     tokenizer.eos_token_id,
-     tokenizer.convert_tokens_to_ids("<|eot_id|>")
- ]
-
- iface = gr.Interface(fn=rag_chatbot_interface,
-     inputs="text",
-     outputs="text",
-     input_types=["text"],
-     output_types=["text"],
-     title="Retrieval-Augmented Generation Chatbot",
-     description="This is a chatbot that uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
- )
-
- iface.launch()
+ # Gradio interface setup
+ iface = gr.Interface(
+     fn=rag_chatbot_interface,
+     inputs=gr.inputs.Textbox(label="Enter your question"),
+     outputs=gr.outputs.Textbox(label="Answer"),
+     title="Retrieval-Augmented Generation Chatbot",
+     description="This is a chatbot that uses a retrieval-augmented generation approach to provide more accurate answers. It first searches for relevant documents and then generates a response based on the prompt and the retrieved documents."
+ )

+ iface.launch()
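
The search path relies on the FAISS integration in the datasets library: data.add_faiss_index("embeddings") builds an index over the embeddings column, and get_nearest_examples returns similarity scores alongside the matching rows. A minimal self-contained sketch of that flow, using a toy corpus and an illustrative all-MiniLM-L6-v2 embedder (neither is part of this commit):

from datasets import Dataset
from sentence_transformers import SentenceTransformer

ST = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative embedder, not the one in app.py

docs = ["Anarchy is a society without rulers.", "FAISS performs fast similarity search."]
data = Dataset.from_dict({"text": docs, "embeddings": [ST.encode(d) for d in docs]})
data = data.add_faiss_index("embeddings")  # build a FAISS index over the embeddings column

# k nearest documents for a query, returned as (scores, columnar examples)
scores, examples = data.get_nearest_examples("embeddings", ST.encode("What is anarchy?"), k=1)
print(scores, examples["text"])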
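
In generate, the messages list of role dicts is passed straight to tokenizer(...), which expects plain strings, and the commit deletes the SYS_PROMPT definition that the new code still references. Instruction-tuned Mixtral is normally driven through the tokenizer's chat template instead. A hedged sketch of that variant, assuming tokenizer and model are loaded elsewhere in the file; the Mixtral-8x7B-Instruct-v0.1 template accepts no separate system role, so the system text is folded into the user turn:

# Sketch only: chat-templated generation, with tokenizer and model assumed loaded.
def generate(formatted_prompt, sys_prompt="You are an assistant for answering questions."):
    # Fold the system prompt into the user turn; Mixtral's template has no system role.
    messages = [{"role": "user", "content": f"{sys_prompt}\n\n{formatted_prompt[:2000]}"}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=1024,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)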
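
gr.inputs.Textbox and gr.outputs.Textbox come from the legacy pre-3.0 Gradio namespace, which was deprecated in Gradio 3 and removed in Gradio 4. On a current Gradio release the components are passed directly; a minimal sketch:

import gradio as gr

iface = gr.Interface(
    fn=rag_chatbot_interface,  # the handler defined in app.py
    inputs=gr.Textbox(label="Enter your question"),
    outputs=gr.Textbox(label="Answer"),
    title="Retrieval-Augmented Generation Chatbot",
)
iface.launch()

Pinning the intended Gradio version in the Space's requirements would make the choice between the two APIs explicit.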