sabssag commited on
Commit
8678a4f
·
verified ·
1 Parent(s): 7fd26f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -24
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForQuestionAnswering
3
- import torch
4
 
5
- # Load the tokenizer and model
6
- model_name = './QAModel'
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForQuestionAnswering.from_pretrained('./results') # Path to your fine-tuned model
 
9
 
10
  # Set the title for the Streamlit app
11
  st.title("Movie Trivia Question Answering")
@@ -14,27 +14,17 @@ st.title("Movie Trivia Question Answering")
14
  context = st.text_area("Enter the context (movie-related text):")
15
  question = st.text_area("Enter your question:")
16
 
17
- def get_answer(context, question):
18
- inputs = tokenizer.encode_plus(question, context, return_tensors='pt', truncation=True, padding=True)
19
- input_ids = inputs['input_ids'].tolist()[0]
20
-
21
- # Get the model's answer
22
- outputs = model(**inputs)
23
- answer_start_scores = outputs.start_logits
24
- answer_end_scores = outputs.end_logits
25
-
26
- # Get the most likely beginning and end of the answer span
27
- answer_start = torch.argmax(answer_start_scores)
28
- answer_end = torch.argmax(answer_end_scores) + 1
29
-
30
- answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
31
- return answer
32
 
33
  if st.button("Get Answer"):
34
  if context and question:
35
- answer = get_answer(context, question)
36
- st.subheader("Answer")
37
- st.write(answer)
 
38
  else:
39
  st.warning("Please enter both context and question.")
40
 
 
1
  import streamlit as st
2
+ from transformers import pipeline
 
3
 
4
+ # Define the path to the saved model
5
+ model_path = './results' # Path to your fine-tuned model
6
+
7
+ # Load the pipeline
8
+ qa_pipeline = pipeline("question-answering", model=model_path, tokenizer=model_path)
9
 
10
  # Set the title for the Streamlit app
11
  st.title("Movie Trivia Question Answering")
 
14
  context = st.text_area("Enter the context (movie-related text):")
15
  question = st.text_area("Enter your question:")
16
 
17
+ def generate_answer(question, context):
18
+ # Perform question answering
19
+ result = qa_pipeline(question=question, context=context)
20
+ return result['answer']
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  if st.button("Get Answer"):
23
  if context and question:
24
+ generated_answer = generate_answer(question, context)
25
+ # Display the generated answer
26
+ st.subheader("Generated Answer")
27
+ st.write(generated_answer)
28
  else:
29
  st.warning("Please enter both context and question.")
30