ChijoTheDatascientist committed on
Commit
39ec79e
1 Parent(s): 4a54e93

Loading model once and error handling
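This commit moves the BART summarizer behind Streamlit's @st.cache_resource so the weights load once per process rather than on every rerun (Streamlit re-executes app.py from the top on each widget interaction), and it wraps both the summarization and the chat-completion calls in try/except blocks that return user-facing messages. A minimal sketch of the caching pattern, with names taken from app.py below (@st.cache_resource is meant for unhashable shared resources such as models, while @st.cache_data, kept on summarize_review, caches serializable return values):

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

bart_model_path = "ChijoTheDatascientist/summarization-model"

@st.cache_resource  # body runs once per process; reruns reuse the cached objects
def load_summarization_model():
    tokenizer = AutoTokenizer.from_pretrained(bart_model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(bart_model_path).to(torch.device("cpu"))
    return tokenizer, model

bart_tokenizer, bart_model = load_summarization_model()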

Files changed (1)
  1. app.py +59 -44
app.py CHANGED
@@ -3,8 +3,6 @@ import torch
 from huggingface_hub import InferenceClient
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from langchain_core.prompts import PromptTemplate
-from langchain_core.output_parsers import StrOutputParser
 
 # Load HF_TOKEN securely
 hf_token = os.getenv("HF_TOKEN")
@@ -16,38 +14,57 @@ client = InferenceClient(api_key=hf_token)
 model_id = "mistralai/Mistral-7B-Instruct-v0.3"
 bart_model_path = "ChijoTheDatascientist/summarization-model"
 
-# Load BART model for summarization
-device = torch.device('cpu')
-bart_tokenizer = AutoTokenizer.from_pretrained(bart_model_path)
-bart_model = AutoModelForSeq2SeqLM.from_pretrained(bart_model_path).to(device)
+# Cache the BART model and tokenizer
+@st.cache_resource
+def load_summarization_model():
+    device = torch.device('cpu')
+    tokenizer = AutoTokenizer.from_pretrained(bart_model_path)
+    model = AutoModelForSeq2SeqLM.from_pretrained(bart_model_path).to(device)
+    return tokenizer, model
 
+# Load the model and tokenizer
+bart_tokenizer, bart_model = load_summarization_model()
+
+# Summarize reviews
 @st.cache_data
 def summarize_review(review_text):
-    inputs = bart_tokenizer(review_text, max_length=1024, truncation=True, return_tensors="pt")
-    summary_ids = bart_model.generate(inputs["input_ids"], max_length=40, min_length=10, length_penalty=2.0, num_beams=8, early_stopping=True)
-    summary = bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-    return summary
+    try:
+        if len(review_text) > 1000:
+            return "The review is too long for summarization. Please limit your text to about 1,000 characters, thank you!"
+        inputs = bart_tokenizer(review_text, max_length=1024, truncation=True, return_tensors="pt")
+        summary_ids = bart_model.generate(
+            inputs["input_ids"],
+            max_length=40,
+            min_length=10,
+            length_penalty=2.0,
+            num_beams=8,
+            early_stopping=True
+        )
+        summary = bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+        return f"Your review has been successfully summarized! Check the result below:\n\n{summary}"
+    except Exception as e:
+        return f"Something went wrong during the summarization process. Please try again. Error: {e}"
+
 
+# Generate response
 def generate_response(system_message, user_input, chat_history, max_new_tokens=128):
     try:
         # Prepare the messages for the Hugging Face Inference API
         messages = [{"role": "user", "content": user_input}]
-
-        # Call the Inference API
         completion = client.chat.completions.create(
             model=model_id,
             messages=messages,
             max_tokens=max_new_tokens,
         )
-
-        # Get the response from the API
         response = completion.choices[0].message["content"]
         return response
-
+    except ConnectionError:
+        return "We're having trouble connecting to the server. Please try again later."
     except Exception as e:
-        return f"Error generating response: {e}"
+        return f"Oops! Something went wrong: {e}"
+
 
-# Streamlit app configuration
+# App configuration
 st.set_page_config(page_title="Insight Snap & Summarizer")
 st.title("Insight Snap & Summarizer")
 
@@ -66,34 +83,32 @@ user_input = st.text_area("Enter customer reviews or a question:")
 
 if st.button("Submit"):
     if user_input:
-        # Summarize if the query is feedback-related
-        if "summarize" in user_input.lower():
-            summary = summarize_review(user_input)
-            st.markdown(f"**Summary:** \n{summary}")
-        elif "insight" in user_input.lower() or "feedback" in user_input.lower():
-            system_message = (
-                "You are a helpful assistant providing actionable insights "
-                "from customer feedback to help businesses improve their services."
-            )
-            # Use the last summarized text if available
-            last_summary = st.session_state.get("last_summary", "")
-            query_input = last_summary if last_summary else user_input
-            response = generate_response(system_message, query_input, st.session_state.chat_history)
-
-            if response:
-                # Update chat history
-                st.session_state.chat_history.append({"role": "user", "content": user_input})
-                st.session_state.chat_history.append({"role": "assistant", "content": response})
-                st.markdown(f"**Insight:** \n{response}")
+        # Show a loading spinner while processing
+        with st.spinner("Processing..."):
+            # Summarize if the query is feedback-related
+            if "summarize" in user_input.lower():
+                summary = summarize_review(user_input)
+                st.markdown(f"**Summary:** \n{summary}")
+            elif "insight" in user_input.lower() or "feedback" in user_input.lower():
+                system_message = (
+                    "You are a helpful assistant providing actionable insights "
+                    "from customer feedback to help businesses improve their services."
+                )
+                last_summary = st.session_state.get("last_summary", "")
+                query_input = last_summary if last_summary else user_input
+                response = generate_response(system_message, query_input, st.session_state.chat_history)
+
+                if response:
+                    st.session_state.chat_history.append({"role": "user", "content": user_input})
+                    st.session_state.chat_history.append({"role": "assistant", "content": response})
+                    st.markdown(f"**Insight:** \n{response}")
+                else:
+                    st.warning("No response generated. Please try again later.")
             else:
-                st.warning("No response generated. Please try again later.")
-        else:
-            st.warning("Please specify if you want to 'summarize' or get 'insights'.")
+                st.warning("Please specify if you want to 'summarize' or get 'insights'.")
 
-        # Store the last summary for insights
-        if "summarize" in user_input.lower():
-            st.session_state["last_summary"] = summary
+
+            if "summarize" in user_input.lower():
+                st.session_state["last_summary"] = summary
     else:
         st.warning("Please enter customer reviews or ask for insights.")
-
-
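One thing this commit leaves as-is: generate_response still accepts system_message and chat_history but forwards only the latest user turn to the Inference API, so the model sees neither the system prompt nor earlier messages. A hedged sketch of a possible follow-up (not part of this commit) that threads both through, assuming client and model_id are the objects defined in app.py and chat_history holds the {"role": ..., "content": ...} dicts the Submit handler appends:

def generate_response(system_message, user_input, chat_history, max_new_tokens=128):
    # Sketch only: client and model_id come from app.py's module scope.
    messages = [{"role": "system", "content": system_message}]
    messages.extend(chat_history)  # replay earlier turns
    messages.append({"role": "user", "content": user_input})
    completion = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_new_tokens,
    )
    return completion.choices[0].message["content"]

Some instruct checkpoints reject a system role in their chat template; if the endpoint errors, prepending the instructions to the first user message is a common workaround.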
 
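A caveat on the new error handling: a network failure inside InferenceClient typically surfaces as requests.exceptions.ConnectionError, which is not a subclass of Python's built-in ConnectionError, so the except ConnectionError branch above is unlikely to fire and most failures will fall through to the generic handler. A hedged sketch of a tighter variant (also not part of this commit), reusing app.py's client, model_id, messages, and max_new_tokens:

import requests
from huggingface_hub.utils import HfHubHTTPError

try:
    completion = client.chat.completions.create(
        model=model_id,
        messages=messages,
        max_tokens=max_new_tokens,
    )
    response = completion.choices[0].message["content"]
except requests.exceptions.ConnectionError:
    # Network-level failure raised by the HTTP layer.
    response = "We're having trouble connecting to the server. Please try again later."
except HfHubHTTPError as e:
    # The API answered with an error status (bad token, rate limit, model still loading, ...).
    response = f"Inference API error: {e}"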