|
import os |
|
import torch |
|
from huggingface_hub import InferenceClient |
|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
from langchain_core.prompts import PromptTemplate |
|
from langchain_core.output_parsers import StrOutputParser |
|
|
|
|
|
# Hugging Face access token. If HF_TOKEN is unset this is None, and
# InferenceClient then falls back to the locally saved login.
hf_token = os.getenv("HF_TOKEN")

# Pass the raw token: InferenceClient builds the "Authorization: Bearer <token>"
# header itself, so adding "Bearer " here would produce "Bearer Bearer <token>".
client = InferenceClient(api_key=hf_token)

# Hosted chat model used for the "insights" requests.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
# Seq2seq checkpoint used locally for review summarization.
bart_model_path = "ChijoTheDatascientist/summarization-model"

device = torch.device('cpu')


@st.cache_resource(show_spinner=False)
def _load_bart(model_path):
    """Load the summarization tokenizer and model once per server process.

    Streamlit re-runs this script on every interaction; caching the loaded
    objects avoids re-reading the checkpoint from disk on each rerun.
    show_spinner=False keeps this from rendering any element before the
    page configuration call further down the script.

    Args:
        model_path: Hub id or local path of the seq2seq checkpoint.

    Returns:
        A (tokenizer, model) tuple, with the model placed on ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
    return tokenizer, model


bart_tokenizer, bart_model = _load_bart(bart_model_path)
|
|
|
@st.cache_data
def summarize_review(review_text):
    """Summarize *review_text* with the local seq2seq summarization model.

    Input is truncated to 1024 tokens. The summary is produced with 8-beam
    search, bounded by generate's min_length=10 / max_length=40, and
    memoized by st.cache_data per distinct review_text.

    Args:
        review_text: Raw review text to summarize.

    Returns:
        The decoded summary string with special tokens removed.
    """
    inputs = bart_tokenizer(review_text, max_length=1024, truncation=True, return_tensors="pt")
    # Put the tensors wherever the model lives instead of assuming CPU.
    inputs = inputs.to(bart_model.device)
    summary_ids = bart_model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly rather than letting generate infer it.
        attention_mask=inputs.get("attention_mask"),
        max_length=40,
        min_length=10,
        length_penalty=2.0,
        num_beams=8,
        early_stopping=True,
    )
    return bart_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
|
def generate_response(system_message, user_input, chat_history, max_new_tokens=128):
    """Ask the hosted chat model for a reply to *user_input*.

    The request is sent as: the system message (if non-empty), then the
    prior turns from *chat_history* in order, then *user_input* as the
    newest user turn.

    Args:
        system_message: Instruction text sent with role "system"; skipped if empty.
        user_input: The new user message.
        chat_history: Prior turns as {"role": ..., "content": ...} dicts,
            oldest first. May be None or empty. Not modified.
        max_new_tokens: Upper bound on tokens generated for the reply.

    Returns:
        The assistant's reply text, or None if the request failed. Failures
        are logged rather than returned as text, so an error message is never
        mistaken for a model reply (e.g. stored in the chat history).
    """
    import logging

    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    # Copy only role/content so the caller's history list is never mutated.
    for turn in chat_history or []:
        messages.append({"role": turn["role"], "content": turn["content"]})
    messages.append({"role": "user", "content": user_input})

    try:
        completion = client.chat.completions.create(
            model=model_id,
            messages=messages,
            max_tokens=max_new_tokens,
        )
        return completion.choices[0].message.content
    except Exception:
        # Boundary with a remote service: any failure (network, auth, quota,
        # malformed response) is logged and reported to the caller as None.
        logging.getLogger(__name__).exception("Chat completion request failed")
        return None
|
|
|
|
|
_APP_TITLE = "Insight Snap & Summarizer"

# The browser-tab title and the on-page heading use the same text.
st.set_page_config(page_title=_APP_TITLE)
st.title(_APP_TITLE)

# Usage hint: the submit handler routes each request by these keywords.
st.markdown("""
- Use specific keywords in your queries to get targeted responses:
- **"summarize"**: To summarize customer reviews.
- **"Feedback or insights"**: Get actionable business insights based on feedback.
""")

# Conversation turns ({"role", "content"} dicts) are kept in session state
# so they survive Streamlit reruns; create the list only on first run.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
|
|
|
|
|
user_input = st.text_area("Enter customer reviews or a question:")

if st.button("Submit"):
    # Whitespace-only input is treated the same as empty input.
    if user_input.strip():
        # Route on keywords, case-insensitively; "summarize" takes priority.
        lowered = user_input.lower()

        if "summarize" in lowered:
            summary = summarize_review(user_input)
            # Remember the latest summary so a later "insights" request can use it.
            st.session_state["last_summary"] = summary
            st.markdown(f"**Summary:** \n{summary}")

        elif "insight" in lowered or "feedback" in lowered:
            system_message = (
                "You are a helpful assistant providing actionable insights "
                "from customer feedback to help businesses improve their services."
            )

            # Always send the user's actual request. If a summary exists, attach
            # it as context instead of letting it replace the request.
            last_summary = st.session_state.get("last_summary", "")
            if last_summary:
                query_input = (
                    f"{user_input}\n\n"
                    f"Summary of the customer reviews:\n{last_summary}"
                )
            else:
                query_input = user_input

            response = generate_response(system_message, query_input, st.session_state.chat_history)

            if response:
                # Record the turn as the user typed it, paired with the reply.
                st.session_state.chat_history.append({"role": "user", "content": user_input})
                st.session_state.chat_history.append({"role": "assistant", "content": response})
                st.markdown(f"**Insight:** \n{response}")
            else:
                st.warning("No response generated. Please try again later.")

        else:
            st.warning("Please specify if you want to 'summarize' or get 'insights'.")
    else:
        st.warning("Please enter customer reviews or ask for insights.")
|
|
|
|
|
|