seawolf2357 commited on
Commit
85b887d
Β·
verified Β·
1 Parent(s): 12218a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -1,16 +1,16 @@
1
- import os
2
- import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from sentence_transformers import SentenceTransformer
5
  from datasets import load_dataset
6
  import faiss
7
  import gradio as gr
8
  from accelerate import Accelerator
 
 
9
 
10
  # ν™˜κ²½ λ³€μˆ˜μ—μ„œ Hugging Face API ν‚€ λ‘œλ“œ
11
  hf_api_key = os.getenv('HF_API_KEY')
12
 
13
- # λͺ¨λΈ ID 및 ν† ν¬λ‚˜μ΄μ € μ„€μ •
14
  model_id = "microsoft/phi-2"
15
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
16
  model = AutoModelForCausalLM.from_pretrained(
@@ -20,11 +20,9 @@ model = AutoModelForCausalLM.from_pretrained(
20
  torch_dtype=torch.float32
21
  )
22
 
23
- # Accelerator μ„€μ •
24
  accelerator = Accelerator()
25
  model = accelerator.prepare(model)
26
 
27
- # 데이터셋 및 FAISS 인덱슀 λ‘œλ“œ
28
  ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
29
  dataset = load_dataset("not-lain/wikipedia", revision="embedded")
30
  data = dataset["train"]
@@ -44,7 +42,14 @@ def format_prompt(prompt, retrieved_documents, k):
44
  def generate(formatted_prompt):
45
  prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
46
  input_ids = tokenizer(prompt_text, return_tensors="pt", padding=True).input_ids.to(accelerator.device)
47
- outputs = model.generate(input_ids, max_new_tokens=1024, eos_token_id=tokenizer.eos_token_id, do_sample=True, temperature=0.6, top_p=0.9)
 
 
 
 
 
 
 
48
  return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
49
 
50
  def rag_chatbot_interface(prompt: str, k: int = 2):
@@ -59,8 +64,7 @@ iface = gr.Interface(
59
  inputs="text",
60
  outputs="text",
61
  title="Retrieval-Augmented Generation Chatbot",
62
- description="This chatbot provides more accurate answers by searching relevant documents and generating responses.",
63
- share=True # 곡개 링크 생성
64
  )
65
 
66
- iface.launch()
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM
2
  from sentence_transformers import SentenceTransformer
3
  from datasets import load_dataset
4
  import faiss
5
  import gradio as gr
6
  from accelerate import Accelerator
7
+ import os
8
+ import torch
9
 
10
  # ν™˜κ²½ λ³€μˆ˜μ—μ„œ Hugging Face API ν‚€ λ‘œλ“œ
11
  hf_api_key = os.getenv('HF_API_KEY')
12
 
13
+ # λͺ¨λΈ 및 ν† ν¬λ‚˜μ΄μ € μ„€μ •
14
  model_id = "microsoft/phi-2"
15
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key, trust_remote_code=True)
16
  model = AutoModelForCausalLM.from_pretrained(
 
20
  torch_dtype=torch.float32
21
  )
22
 
 
23
  accelerator = Accelerator()
24
  model = accelerator.prepare(model)
25
 
 
26
  ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
27
  dataset = load_dataset("not-lain/wikipedia", revision="embedded")
28
  data = dataset["train"]
 
42
  def generate(formatted_prompt):
43
  prompt_text = f"{SYS_PROMPT} {formatted_prompt}"
44
  input_ids = tokenizer(prompt_text, return_tensors="pt", padding=True).input_ids.to(accelerator.device)
45
+ outputs = model.generate(
46
+ input_ids,
47
+ max_new_tokens=1024,
48
+ eos_token_id=tokenizer.eos_token_id,
49
+ do_sample=True,
50
+ temperature=0.6,
51
+ top_p=0.9
52
+ )
53
  return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
54
 
55
  def rag_chatbot_interface(prompt: str, k: int = 2):
 
64
  inputs="text",
65
  outputs="text",
66
  title="Retrieval-Augmented Generation Chatbot",
67
+ description="This chatbot provides more accurate answers by searching relevant documents and generating responses."
 
68
  )
69
 
70
+ iface.launch(share=True) # μ—¬κΈ°μ—μ„œ share=Trueλ₯Ό μ„€μ •ν•˜μ—¬ 곡개 링크λ₯Ό 생성