ChijoTheDatascientist committed
Commit 39ec79e
1 Parent(s): 4a54e93

Loading model once and error handling

app.py CHANGED
@@ -3,8 +3,6 @@ import torch
 from huggingface_hub import InferenceClient
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from langchain_core.prompts import PromptTemplate
-from langchain_core.output_parsers import StrOutputParser
 
 # Load HF_TOKEN securely
 hf_token = os.getenv("HF_TOKEN")
@@ -16,38 +14,57 @@ client = InferenceClient(api_key=hf_token)
 model_id = "mistralai/Mistral-7B-Instruct-v0.3"
 bart_model_path = "ChijoTheDatascientist/summarization-model"
 
-# …
-…
+# Cache the BART model and tokenizer
+@st.cache_resource
+def load_summarization_model():
+    device = torch.device('cpu')
+    tokenizer = AutoTokenizer.from_pretrained(bart_model_path)
+    model = AutoModelForSeq2SeqLM.from_pretrained(bart_model_path).to(device)
+    return tokenizer, model
 
+# Load the model and tokenizer
+bart_tokenizer, bart_model = load_summarization_model()
+
+# Summarize reviews
 @st.cache_data
 def summarize_review(review_text):
-    …
+    try:
+        if len(review_text) > 1000:
+            return "The review is too long for summarization. Please limit your text to about 1,000 characters, thank you!"
+        inputs = bart_tokenizer(review_text, max_length=1024, truncation=True, return_tensors="pt")
+        summary_ids = bart_model.generate(
+            inputs["input_ids"],
+            max_length=40,
+            min_length=10,
+            length_penalty=2.0,
+            num_beams=8,
+            early_stopping=True
+        )
+        summary = bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+        return f"Your review has been successfully summarized! Check the result below:\n\n{summary}"
+    except Exception as e:
+        return f"Something went wrong during the summarization process. Please try again. Error: {e}"
+
 
+# Generate response
 def generate_response(system_message, user_input, chat_history, max_new_tokens=128):
     try:
         # Prepare the messages for the Hugging Face Inference API
         messages = [{"role": "user", "content": user_input}]
-
-        # Call the Inference API
         completion = client.chat.completions.create(
             model=model_id,
             messages=messages,
             max_tokens=max_new_tokens,
         )
-
-        # Get the response from the API
         response = completion.choices[0].message["content"]
         return response
-
+    except ConnectionError:
+        return "We're having trouble connecting to the server. Please try again later."
     except Exception as e:
-        return f"…
+        return f"Oops! Something went wrong: {e}"
+
 
-# …
+# App configuration
 st.set_page_config(page_title="Insight Snap & Summarizer")
 st.title("Insight Snap & Summarizer")
 
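Note: the "loading model once" half of this commit comes from @st.cache_resource. Streamlit re-executes the whole script on every interaction, so without the decorator from_pretrained() would run on each rerun; cached as a resource, the BART model is loaded once per server process and shared across reruns and sessions, while @st.cache_data memoizes summarize_review per input text. A minimal standalone sketch of the two-decorator pattern (hypothetical names, not the app's code):

import streamlit as st

@st.cache_resource
def load_model():
    # Executes only on the first call; later reruns receive the same object.
    print("loading model...")  # appears once in the server log, not once per rerun
    return object()  # stand-in for AutoModelForSeq2SeqLM.from_pretrained(...)

@st.cache_data
def summarize(text: str) -> str:
    # Re-executed only for unseen `text` values; repeated inputs hit the cache.
    load_model()
    return text[:40]  # stand-in for model.generate(...) + decode

st.write(summarize("some review text"))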
@@ -66,34 +83,32 @@ user_input = st.text_area("Enter customer reviews or a question:")
 
 if st.button("Submit"):
     if user_input:
-        # …
-        …
+        # Show a loading spinner while processing
+        with st.spinner("Processing..."):
+            # Summarize if the query is feedback-related
+            if "summarize" in user_input.lower():
+                summary = summarize_review(user_input)
+                st.markdown(f"**Summary:** \n{summary}")
+            elif "insight" in user_input.lower() or "feedback" in user_input.lower():
+                system_message = (
+                    "You are a helpful assistant providing actionable insights "
+                    "from customer feedback to help businesses improve their services."
+                )
+                last_summary = st.session_state.get("last_summary", "")
+                query_input = last_summary if last_summary else user_input
+                response = generate_response(system_message, query_input, st.session_state.chat_history)
+
+                if response:
+                    st.session_state.chat_history.append({"role": "user", "content": user_input})
+                    st.session_state.chat_history.append({"role": "assistant", "content": response})
+                    st.markdown(f"**Insight:** \n{response}")
+                else:
+                    st.warning("No response generated. Please try again later.")
             else:
-                st.warning("…
-        else:
-            st.warning("Please specify if you want to 'summarize' or get 'insights'.")
+                st.warning("Please specify if you want to 'summarize' or get 'insights'.")
 
-
+
+        if "summarize" in user_input.lower():
+            st.session_state["last_summary"] = summary
     else:
         st.warning("Please enter customer reviews or ask for insights.")
-
-
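Note on the new error handling: the bare `except ConnectionError` catches only Python's built-in exception. If InferenceClient surfaces network failures as requests.exceptions.ConnectionError (which inherits from RequestException/OSError, not from the builtin ConnectionError), those would fall through to the generic `except Exception` branch instead. A hedged sketch that catches both, assuming a requests-based transport underneath:

import requests

def generate_response_safe(client, model_id, user_input, max_new_tokens=128):
    try:
        completion = client.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": user_input}],
            max_tokens=max_new_tokens,
        )
        return completion.choices[0].message["content"]
    except (ConnectionError, requests.exceptions.ConnectionError):
        # Covers both the builtin and the requests variant.
        return "We're having trouble connecting to the server. Please try again later."
    except Exception as e:
        return f"Oops! Something went wrong: {e}"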
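Note: the insight branch reads st.session_state.chat_history, whose initialization is not visible in these hunks (it presumably sits in the unchanged lines 70-82 of app.py). A typical guard, shown as an assumption rather than the app's actual code:

import streamlit as st

# Hypothetical initialization -- the real version lives in the unchanged
# region of the diff.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
if "last_summary" not in st.session_state:
    st.session_state["last_summary"] = ""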