|
import os |
|
import torch |
|
from huggingface_hub import InferenceClient |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
# Hub identifiers: the hosted chat model requested through the inference client,
# and the seq2seq checkpoint loaded locally for summarization.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
bart_model_path = "ChijoTheDatascientist/summarization-model"

# Access token read from the environment; this is None when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")

# Shared client used for chat-completion requests.
client = InferenceClient(api_key=hf_token)
|
|
|
|
|
@st.cache_resource
def load_summarization_model():
    """Load the summarization tokenizer and seq2seq model from ``bart_model_path``.

    The model is moved to the CPU device before being returned. The function is
    wrapped in ``st.cache_resource``.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    cpu = torch.device("cpu")
    tok = AutoTokenizer.from_pretrained(bart_model_path)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(bart_model_path)
    seq2seq = seq2seq.to(cpu)
    return tok, seq2seq
|
|
|
|
|
# Module-level tokenizer/model pair used by summarize_review. The loader is
# decorated with st.cache_resource, so this call is expected to reuse the
# already-loaded objects on script reruns rather than reloading them.
bart_tokenizer, bart_model = load_summarization_model()
|
|
|
|
|
@st.cache_data
def _cached_summary(review_text):
    """Run the summarization model on ``review_text`` and return the decoded text.

    Kept separate from ``summarize_review`` so that only successful model
    outputs are cached: an exception raised here propagates to the caller
    instead of being stored as a cached return value.
    """
    inputs = bart_tokenizer(review_text, max_length=1024, truncation=True, return_tensors="pt")
    summary_ids = bart_model.generate(
        inputs["input_ids"],
        max_length=40,
        min_length=10,
        length_penalty=2.0,
        num_beams=8,
        early_stopping=True,
    )
    return bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def summarize_review(review_text):
    """Summarize a customer review and return a user-facing status message.

    Args:
        review_text: Raw review text; inputs longer than 1,000 characters are
            rejected before the model is called.

    Returns:
        str: On success, a confirmation line followed by a blank line and the
        summary. Otherwise a message explaining that the text is too long or
        that summarization failed (including the exception text).
    """
    try:
        if len(review_text) > 1000:
            return "The review is too long for summarization. Please limit your text to about 1,000 characters, thank you!."
        # Only the model call is cached. Previously the whole function was
        # cached, so a transient failure was memoized and "Please try again"
        # kept returning the same error for the same text.
        summary = _cached_summary(review_text)
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the app.
        return f"Something went wrong during the summarization process. Please try again. Error: {e}"
    return f"Your review has been successfully summarized! Check the result below:\n\n{summary}"
|
|
|
|
|
|
|
def generate_response(system_message, user_input, chat_history, max_new_tokens=128):
    """Ask the hosted chat model for a reply to ``user_input``.

    Args:
        system_message: Instruction sent as the leading ``system`` message;
            skipped when empty.
        user_input: Text of the current user turn.
        chat_history: Earlier turns as dicts with ``"role"`` and ``"content"``
            keys (may be empty or ``None``). Only the most recent turns are
            replayed.
        max_new_tokens: Upper bound on generated tokens, passed as
            ``max_tokens``.

    Returns:
        str: The model's reply, or a user-facing error message if the request
        fails (exceptions are not propagated).
    """
    # Bound the replayed context so the prompt cannot grow without limit. An
    # even count keeps the window starting on a user turn when the history is
    # stored as user/assistant pairs.
    max_history_messages = 10
    try:
        messages = []
        # The system prompt and history were previously accepted but never
        # sent, so the model answered without its instructions or any context.
        if system_message:
            messages.append({"role": "system", "content": system_message})
        for turn in (chat_history or [])[-max_history_messages:]:
            messages.append({"role": turn["role"], "content": turn["content"]})
        messages.append({"role": "user", "content": user_input})

        completion = client.chat.completions.create(
            model=model_id,
            messages=messages,
            max_tokens=max_new_tokens,
        )
        return completion.choices[0].message["content"]
    except ConnectionError:
        return "we're having trouble connecting to the server. Please try again later."
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"Oops! Something went wrong: {e}"
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Streamlit page: collects text, then routes it to summarization or insights
# based on keywords found in the input.
# ---------------------------------------------------------------------------

# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command executed; the cached model load earlier in this file
# runs before it - confirm this ordering works on the deployed version.
st.set_page_config(page_title="Insight Snap & Summarizer")
st.title("Insight Snap & Summarizer")

st.markdown("""
- Use specific keywords in your queries to get targeted responses:
- **"summarize"**: To summarize customer reviews.
- **"Feedback or insights"**: Get actionable business insights based on feedback.
""")

# Text that summarize_review places in front of a successful summary. It is used
# to tell a real summary apart from the "too long" / error messages, which are
# returned through the same string channel.
_SUMMARY_OK_PREFIX = "Your review has been successfully summarized! Check the result below:\n\n"

# Conversation turns persist across reruns of the script.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

user_input = st.text_area("Enter customer reviews or a question:")

if st.button("Submit"):
    # Whitespace-only input is treated the same as no input.
    if user_input.strip():
        lowered = user_input.lower()
        with st.spinner("Processing..."):
            if "summarize" in lowered:
                summary = summarize_review(user_input)
                st.markdown(f"**Summary:** \n{summary}")
                # Remember only genuine summaries, without the status line, so a
                # later insight request is never fed an error message as if it
                # were review content.
                if summary.startswith(_SUMMARY_OK_PREFIX):
                    st.session_state["last_summary"] = summary[len(_SUMMARY_OK_PREFIX):]
            elif "insight" in lowered or "feedback" in lowered:
                system_message = (
                    "You are a helpful assistant providing actionable insights "
                    "from customer feedback to help businesses improve their services."
                )
                last_summary = st.session_state.get("last_summary", "")
                # Send the user's current text together with the stored summary.
                # Previously an existing summary replaced the user's text
                # entirely, so new feedback or questions were silently dropped.
                if last_summary:
                    query_input = (
                        f"{user_input}\n\nSummary of the customer review:\n{last_summary}"
                    )
                else:
                    query_input = user_input
                response = generate_response(
                    system_message, query_input, st.session_state.chat_history
                )

                if response:
                    st.session_state.chat_history.append({"role": "user", "content": user_input})
                    st.session_state.chat_history.append({"role": "assistant", "content": response})
                    st.markdown(f"**Insight:** \n{response}")
                else:
                    st.warning("No response generated. Please try again later.")
            else:
                st.warning("Please specify if you want to 'summarize' or get 'insights'.")
    else:
        st.warning("Please enter customer reviews or ask for insights.")
|
|