dhmeltzer committed
Commit 8df9ec0
1 Parent(s): 9087ce2

Update app.py

Files changed (1)
1. app.py +49 -25
app.py CHANGED
@@ -2,6 +2,7 @@ import numpy as np
 import requests
 import streamlit as st
 import openai
+import json
 
 def main():
     st.title("Scientific Question Generation")
@@ -12,16 +13,27 @@ def main():
     of the [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl) model and the [GPT-3.5-turbo](https://platform.openai.com/docs/models/gpt-3-5) model.\
     \n\n For a more thorough discussion of question generation see this [report](https://wandb.ai/dmeltzer/Question_Generation/reports/Exploratory-Data-Analysis-for-r-AskScience--Vmlldzo0MjQwODg1?accessToken=fndbu2ar26mlbzqdphvb819847qqth2bxyi4hqhugbnv97607mj01qc7ed35v6w8) for EDA on the r/AskScience dataset and this \
     [report](https://api.wandb.ai/links/dmeltzer/7an677es) for details on our training procedure.\
-    \n\n**Disclaimer**: You may receive an error message when you first run the model. We are using the Hugging Face API to access the BART-Large and FLAN-T5 models, and the Inference API takes around 20 seconds to load each model.\
+    \n\n \
+    The two fine-tuned models (BART-Large and FLAN-T5-Base) are hosted on AWS using a combination of AWS SageMaker, Lambda, and API Gateway. \
+    GPT-3.5 is called using the OpenAI API, and the FLAN-T5-XXL model is hosted by Hugging Face and called with their Inference API.\
+    \n\n \
+    **Disclaimer**: You may receive an error message when calling the FLAN-T5-XXL model, since the Inference API takes around 20 seconds to load the model.\
     ")
 
-    checkpoints = ['dhmeltzer/bart-large_askscience-qg',
-                   'dhmeltzer/flan-t5-base_askscience-qg',
-                   'google/flan-t5-xxl']
+    # API Gateway URLs for the two fine-tuned models hosted on AWS.
+    AWS_checkpoints = {}
+    AWS_checkpoints['BART-Large'] = 'https://8hlnvys7bh.execute-api.us-east-1.amazonaws.com/beta/'
+    AWS_checkpoints['FLAN-T5-Base'] = 'https://gnrxh05827.execute-api.us-east-1.amazonaws.com/beta/'
 
-    headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}
+    # Right now HF_checkpoints just consists of FLAN-T5-XXL, but we may add more models later.
+    HF_checkpoints = ['google/flan-t5-xxl']
+
+    # Token to access the HF Inference API.
+    HF_headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}
+
+    # Token to access the OpenAI API.
     openai.api_key = st.secrets['OpenAI_token']
-
+
+    # Used to query models hosted on Hugging Face.
     def query(checkpoint, payload):
         API_URL = f"https://api-inference.huggingface.co/models/{checkpoint}"
 
@@ -36,12 +48,41 @@ def main():
         """Black holes are the most gravitationally dense objects in the universe.""")
 
     if user_input:
+
+        # Query the two fine-tuned models hosted on AWS.
-        for checkpoint in checkpoints:
+        for name, url in AWS_checkpoints.items():
+            # Assumes the API Gateway key is stored in Streamlit secrets;
+            # the original line referenced an undefined name `key`.
+            headers = {'x-api-key': st.secrets['AWS_key']}
+
+            input_data = json.dumps({'inputs': user_input})
+            r = requests.get(url, data=input_data, headers=headers)
+            output = r.json()[0]['generated_text']
+
+            st.write(f'**{name}**: {output}')
+
+        model_engine = "gpt-3.5-turbo"
+        # Maximum number of tokens to generate.
+        max_tokens = 50
+
+        # Prompt GPT-3.5 with an explicit instruction.
+        prompt = f"generate a question: {user_input}"
+
+        # The system message tells GPT-3.5 to generate questions from text.
+        response = openai.ChatCompletion.create(
+            model=model_engine,
+            max_tokens=max_tokens,
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant that generates questions from text."},
+                {"role": "user", "content": prompt},
+            ])
+
+        output = response['choices'][0]['message']['content']
+        st.write(f'**{model_engine}**: {output}')
+
+        # Query the models hosted on Hugging Face.
+        for checkpoint in HF_checkpoints:
 
             model_name = checkpoint.split('/')[1]
 
+            # FLAN models need to be given instructions explicitly.
             if 'flan' in model_name.lower():
-
                 prompt = 'generate a question: ' + user_input
 
             else:
@@ -57,23 +98,6 @@ def main():
                 return
 
             st.write(f'**{model_name}**: {output}')
-
-        model_engine = "gpt-3.5-turbo"
-        max_tokens = 50
-
-        prompt = f"generate a question: {user_input}"
-
-        response = openai.ChatCompletion.create(
-            model=model_engine,
-            messages=[
-                {"role": "system", "content": "You are a helpful assistant that generates questions from text."},
-                {"role": "user", "content": prompt},
-            ])
-
-        output = response['choices'][0]['message']['content']
-
-        st.write(f'**{model_engine}**: {output}')
-
 
 if __name__ == "__main__":
     main()
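The updated app description says the two fine-tuned models are served through AWS SageMaker, Lambda, and API Gateway, but the backend itself is not part of this commit. As a rough illustration of that architecture only, here is a minimal sketch of a Lambda proxy that forwards the JSON payload from API Gateway to a SageMaker endpoint; the endpoint name and event shape are assumptions, not code from this repository.

```python
import boto3

# Hypothetical endpoint name; the SageMaker/Lambda backend is not in this commit.
ENDPOINT_NAME = "askscience-qg-endpoint"

runtime = boto3.client("sagemaker-runtime")

def lambda_handler(event, context):
    """Forward the request body from API Gateway to a SageMaker endpoint."""
    response = runtime.invoke_endpoint(
        EndpointName=ENDPOINT_NAME,
        ContentType="application/json",
        Body=event["body"],  # e.g. '{"inputs": "Black holes are ..."}'
    )
    # Return the model output so API Gateway can pass it back to the app.
    return {
        "statusCode": 200,
        "body": response["Body"].read().decode("utf-8"),
    }
```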
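The new disclaimer notes that the Inference API needs around 20 seconds to load FLAN-T5-XXL on a cold start. A minimal sketch of one way to absorb that delay, assuming the Inference API's behaviour of answering HTTP 503 with an `estimated_time` field while a model loads; the token placeholder and retry limit are illustrative only.

```python
import time
import requests

API_URL = "https://api-inference.huggingface.co/models/google/flan-t5-xxl"
headers = {"Authorization": "Bearer <HF_token>"}  # placeholder; the app reads this from st.secrets

def query_with_retry(payload, max_retries=3):
    """Retry while the Inference API is still loading the model (HTTP 503)."""
    for _ in range(max_retries):
        r = requests.post(API_URL, headers=headers, json=payload)
        if r.status_code != 503:
            return r.json()
        # A loading response usually reports how long the model needs.
        wait = r.json().get("estimated_time", 20)
        time.sleep(wait)
    r.raise_for_status()

output = query_with_retry({"inputs": "generate a question: Black holes are "
                                     "the most gravitationally dense objects in the universe."})
```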