arampacha commited on
Commit
22f88f8
1 Parent(s): f3d44ff

pipeline app

Browse files
Files changed (3) hide show
  1. app.py +19 -46
  2. app_api.py +111 -0
  3. requirements.txt +4 -1
app.py CHANGED
@@ -1,30 +1,12 @@
 
1
  import streamlit as st
2
- import json
3
- import requests
4
 
5
  import time
6
 
 
7
 
8
- API_TOKEN = st.secrets["hf_api_token"]
9
- headers = {"Authorization": f"Bearer {API_TOKEN}"}
10
- API_URL = "https://api-inference.huggingface.co/models/arampacha/DialoGPT-medium-simpsons"
11
-
12
- def query(payload):
13
- data = json.dumps(payload)
14
- response = requests.request("POST", API_URL, headers=headers, data=data)
15
- return json.loads(response.content.decode("utf-8"))
16
-
17
- def fake_query(payload):
18
- user_input = payload["inputs"]["text"]
19
- time.sleep(1)
20
- res = {
21
- "generated_text": user_input[::-1],
22
- "conversation":{
23
- "past_user_inputs": st.session_state.past_user_inputs + [user_input],
24
- "generated_responses": st.session_state.generated_responses + [user_input[::-1]],
25
- },
26
- }
27
- return res
28
 
29
  parameters = {
30
  "min_length":None,
@@ -34,40 +16,31 @@ parameters = {
34
  "repetition_penalty":None,
35
  }
36
 
37
- options = {
38
- "use_cache":False,
39
- "wait_for_model":False
40
- }
41
 
42
  def on_input():
43
- # st.write("Input changed")
44
  if st.session_state.count > 0:
45
  user_input = st.session_state.user_input
46
- st.session_state.full_text += f"_User_ >>> {user_input}\n\n"
47
  dialog_output.markdown(st.session_state.full_text)
48
  st.session_state.user_input = ""
49
 
50
- payload = {
51
- "inputs": {
52
- "text": user_input,
53
- "past_user_inputs": st.session_state.past_user_inputs,
54
- "generated_responses": st.session_state.generated_responses,
55
- },
56
- "parameters": parameters,
57
- "options":options,
58
- }
59
- # result = fake_query(payload)
60
- result = query(payload)
61
  try:
62
- st.session_state.update(result["conversation"])
63
- st.session_state.full_text += f'_Chatbot_ > {result["generated_text"]}\n\n'
64
- except:
 
 
 
65
  st.write("D'oh! Something went wrong. Try to rerun the app.")
66
- st.write(result)
67
  st.session_state.count += 1
68
 
69
-
70
-
71
  # init session state
72
  if "past_user_inputs" not in st.session_state:
73
  st.session_state["past_user_inputs"] = []
@@ -96,7 +69,7 @@ if st.session_state.count > 0:
96
  dialog_output.markdown(st.session_state.full_text)
97
 
98
  user_input = st.text_input(
99
- "> User: ",
100
  # value="Hey Homer! How is it going?",
101
  on_change=on_input(),
102
  key="user_input",
 
1
+ import os
2
  import streamlit as st
3
+ from transformers import pipeline, Conversation
 
4
 
5
  import time
6
 
7
+ model_id = "arampacha/DialoGPT-medium-simpsons"
8
 
9
+ dialog = pipeline("conversational", model=model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  parameters = {
12
  "min_length":None,
 
16
  "repetition_penalty":None,
17
  }
18
 
 
 
 
 
19
 
20
  def on_input():
 
21
  if st.session_state.count > 0:
22
  user_input = st.session_state.user_input
23
+ st.session_state.full_text += f"_user_ >>> {user_input}\n\n"
24
  dialog_output.markdown(st.session_state.full_text)
25
  st.session_state.user_input = ""
26
 
27
+ conv = Conversation(
28
+ text = user_input,
29
+ past_user_inputs = st.session_state.past_user_inputs,
30
+ generated_responses = st.session_state.generated_responses,
31
+ )
32
+ conv = dialog(conv)
 
 
 
 
 
33
  try:
34
+ st.session_state.update({
35
+ "past_user_inputs": conv.past_user_inputs,
36
+ "generated_responses": conv.generated_responses,
37
+ })
38
+ st.session_state.full_text += f'_chatbot_ > {conv.generated_text[-1]}\n\n'
39
+ except Exception as e:
40
  st.write("D'oh! Something went wrong. Try to rerun the app.")
41
+ st.write(e)
42
  st.session_state.count += 1
43
 
 
 
44
  # init session state
45
  if "past_user_inputs" not in st.session_state:
46
  st.session_state["past_user_inputs"] = []
 
69
  dialog_output.markdown(st.session_state.full_text)
70
 
71
  user_input = st.text_input(
72
+ "user >> ",
73
  # value="Hey Homer! How is it going?",
74
  on_change=on_input(),
75
  key="user_input",
app_api.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ import requests
4
+
5
+ import time
6
+
7
+
8
+ API_TOKEN = st.secrets["hf_api_token"]
9
+ headers = {"Authorization": f"Bearer {API_TOKEN}"}
10
+ API_URL = "https://api-inference.huggingface.co/models/arampacha/DialoGPT-medium-simpsons"
11
+
12
+ def query(payload):
13
+ data = json.dumps(payload)
14
+ response = requests.request("POST", API_URL, headers=headers, data=data)
15
+ return json.loads(response.content.decode("utf-8"))
16
+
17
+ def fake_query(payload):
18
+ user_input = payload["inputs"]["text"]
19
+ time.sleep(1)
20
+ res = {
21
+ "generated_text": user_input[::-1],
22
+ "conversation":{
23
+ "past_user_inputs": st.session_state.past_user_inputs + [user_input],
24
+ "generated_responses": st.session_state.generated_responses + [user_input[::-1]],
25
+ },
26
+ }
27
+ return res
28
+
29
+ parameters = {
30
+ "min_length":None,
31
+ "max_length":100,
32
+ "top_p":0.92,
33
+ "temperature":1.0,
34
+ "repetition_penalty":None,
35
+ }
36
+
37
+ options = {
38
+ "use_cache":False,
39
+ "wait_for_model":False
40
+ }
41
+
42
+ def on_input():
43
+ # st.write("Input changed")
44
+ if st.session_state.count > 0:
45
+ user_input = st.session_state.user_input
46
+ st.session_state.full_text += f"_User_ >>> {user_input}\n\n"
47
+ dialog_output.markdown(st.session_state.full_text)
48
+ st.session_state.user_input = ""
49
+
50
+ payload = {
51
+ "inputs": {
52
+ "text": user_input,
53
+ "past_user_inputs": st.session_state.past_user_inputs,
54
+ "generated_responses": st.session_state.generated_responses,
55
+ },
56
+ "parameters": parameters,
57
+ "options":options,
58
+ }
59
+ # result = fake_query(payload)
60
+ result = query(payload)
61
+ try:
62
+ st.session_state.update(result["conversation"])
63
+ st.session_state.full_text += f'_Chatbot_ > {result["generated_text"]}\n\n'
64
+ except:
65
+ st.write("D'oh! Something went wrong. Try to rerun the app.")
66
+ st.write(result)
67
+ st.session_state.count += 1
68
+
69
+
70
+
71
+ # init session state
72
+ if "past_user_inputs" not in st.session_state:
73
+ st.session_state["past_user_inputs"] = []
74
+ if "generated_responses" not in st.session_state:
75
+ st.session_state["generated_responses"] = []
76
+ if "full_text" not in st.session_state:
77
+ st.session_state["full_text"] = ""
78
+ if "user_input" not in st.session_state:
79
+ st.session_state["user_input"] = ""
80
+ if "count" not in st.session_state:
81
+ st.session_state["count"] = 0
82
+
83
+ # body
84
+ st.title("Chat with Simpsons")
85
+
86
+ st.image(
87
+ "https://raw.githubusercontent.com/arampacha/chat-with-simpsons/main/the-simpsons.png",
88
+ caption="(c) 20th Century Fox Television",
89
+ )
90
+ if st.session_state.count == 0:
91
+ st.write("Start dialog by inputing some text:")
92
+
93
+ dialog_output = st.empty()
94
+
95
+ if st.session_state.count > 0:
96
+ dialog_output.markdown(st.session_state.full_text)
97
+
98
+ user_input = st.text_input(
99
+ "> User: ",
100
+ # value="Hey Homer! How is it going?",
101
+ on_change=on_input(),
102
+ key="user_input",
103
+ )
104
+
105
+ dialog_text = st.session_state.full_text
106
+ dialog_output.markdown(dialog_text)
107
+
108
+ def restart():
109
+ st.session_state.clear()
110
+
111
+ st.button("Restart", on_click=st.session_state.clear)
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- streamlit==0.84.1
 
 
 
 
1
+ -f https://download.pytorch.org/whl/cpu/torch_stable.html
2
+ torch==1.10.2+cpu
3
+ streamlit==0.84.1
4
+ transformers