rasyosef commited on
Commit
e832688
1 Parent(s): 1e7cadb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -19
app.py CHANGED
@@ -36,20 +36,7 @@ QA_PROMPT = PromptTemplate(
36
  model_id = "microsoft/phi-2"
37
 
38
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
39
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="auto", trust_remote_code=True)
40
-
41
- streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True)
42
- phi2 = pipeline(
43
- "text-generation",
44
- tokenizer=tokenizer,
45
- model=model,
46
- max_new_tokens=128,
47
- pad_token_id=tokenizer.eos_token_id,
48
- eos_token_id=tokenizer.eos_token_id,
49
- device_map="auto",
50
- streamer=streamer
51
- ) # GPU
52
- hf_model = HuggingFacePipeline(pipeline=phi2)
53
 
54
  # Returns a faiss vector store retriever given a txt file
55
  def prepare_vector_store_retriever(filename):
@@ -74,7 +61,7 @@ def prepare_vector_store_retriever(filename):
74
  return VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2})
75
 
76
  # Retrieveal QA chian
77
- def get_retrieval_qa_chain(text_file):
78
  retriever = default_retriever
79
  if text_file != default_text_file:
80
  retriever = prepare_vector_store_retriever(text_file)
@@ -87,8 +74,15 @@ def get_retrieval_qa_chain(text_file):
87
  return chain
88
 
89
  # Generates response using the question answering chain defined earlier
90
- def generate(question, answer, retriever):
91
- qa_chain = get_retrieval_qa_chain(retriever)
 
 
 
 
 
 
 
92
 
93
  query = f"{question}"
94
 
@@ -130,7 +124,8 @@ with gr.Blocks() as demo:
130
  upload_button.upload(upload_file, upload_button, [file_name, text_file])
131
 
132
  gr.Markdown("## Enter your question")
133
-
 
134
  with gr.Row():
135
  with gr.Column():
136
  ques = gr.Textbox(label="Question", placeholder="Enter text here", lines=3)
@@ -142,7 +137,7 @@ with gr.Blocks() as demo:
142
  with gr.Column():
143
  clear = gr.ClearButton([ques, ans])
144
 
145
- btn.click(fn=generate, inputs=[ques, ans, text_file], outputs=[ans])
146
  examples = gr.Examples(
147
  examples=[
148
  "Who portrayed J. Robert Oppenheimer in the new Oppenheimer movie?",
 
36
  model_id = "microsoft/phi-2"
37
 
38
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
39
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # Returns a faiss vector store retriever given a txt file
42
  def prepare_vector_store_retriever(filename):
 
61
  return VectorStoreRetriever(vectorstore=vectorstore, search_kwargs={"k": 2})
62
 
63
  # Retrieveal QA chian
64
+ def get_retrieval_qa_chain(text_file, hf_model):
65
  retriever = default_retriever
66
  if text_file != default_text_file:
67
  retriever = prepare_vector_store_retriever(text_file)
 
74
  return chain
75
 
76
  # Generates response using the question answering chain defined earlier
77
+ def generate(question, answer, text_file, max_new_tokens):
78
+ streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True)
79
+ phi2_pipeline = pipeline(
80
+ "text-generation", tokenizer=tokenizer, model=model, max_new_tokens=max_new_tokens,
81
+ pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
82
+ device_map="cpu", streamer=streamer
83
+ )
84
+ hf_model = HuggingFacePipeline(pipeline=phi2_pipeline)
85
+ qa_chain = get_retrieval_qa_chain(text_file, hf_model)
86
 
87
  query = f"{question}"
88
 
 
124
  upload_button.upload(upload_file, upload_button, [file_name, text_file])
125
 
126
  gr.Markdown("## Enter your question")
127
+ tokens_slider = gr.Slider(8, 256, value=64, label="Maximum new tokens", info="A larger `max_new_tokens` parameter value gives you longer text responses but at the cost of a slower response time.")
128
+
129
  with gr.Row():
130
  with gr.Column():
131
  ques = gr.Textbox(label="Question", placeholder="Enter text here", lines=3)
 
137
  with gr.Column():
138
  clear = gr.ClearButton([ques, ans])
139
 
140
+ btn.click(fn=generate, inputs=[ques, ans, text_file, tokens_slider], outputs=[ans])
141
  examples = gr.Examples(
142
  examples=[
143
  "Who portrayed J. Robert Oppenheimer in the new Oppenheimer movie?",