Haseeb-001 committed
Commit f7f0da7 · verified · 1 Parent(s): f929853

Update app.py

Files changed (1)
  1. app.py +56 -38
app.py CHANGED
@@ -4,72 +4,90 @@ import faiss
  import pickle
  from groq import Groq
  from datasets import load_dataset
- from transformers import pipeline

  # Initialize Groq API
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

  # Load datasets
- healthcare_ds = load_dataset("harishnair04/mtsamples")
- education_ds = load_dataset("ehovy/race", "all")
- finance_ds = load_dataset("warwickai/financial_phrasebank_mirror")

- # Load chat model
- chat_pipe = pipeline("text-generation", model="rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog")

- # FAISS Index Setup
- index = faiss.IndexFlatL2(768)
  chat_history = []

  # Streamlit UI Setup
  st.set_page_config(page_title="AI Chatbot", layout="wide")
  st.title("🤖 AI Chatbot (Healthcare, Education & Finance)")

- # Sidebar for chat history
- st.sidebar.title("📜 Chat History")
- if st.sidebar.button("Download Chat History"):
-     with open("chat_history.txt", "w") as file:
-         file.write("\n".join(chat_history))
-     st.sidebar.success("Chat history saved!")

  # Chat Interface
  user_input = st.text_input("💬 Ask me anything:", placeholder="Type your query here...")
  if st.button("Send"):
      if user_input:
-         # Determine dataset based on user query (Basic CAG Implementation)
          dataset = healthcare_ds if "health" in user_input.lower() else \
                    education_ds if "education" in user_input.lower() else \
-                   finance_ds

-         # RAG: Retrieve relevant data
-         retrieved_data = dataset['train'][0]  # Simplified retrieval
-
-         # Generate response using Llama via Groq API
-         chat_completion = client.chat.completions.create(
-             messages=[{"role": "user", "content": f"{user_input} {retrieved_data}"}],
-             model="llama-3.3-70b-versatile"
-         )
-         response = chat_completion.choices[0].message.content
-
-         # Save chat to FAISS and display
          chat_history.append(f"User: {user_input}\nBot: {response}")
          st.text_area("🤖 AI Response:", value=response, height=200)

- # Display past chats
- st.sidebar.write("\n".join(chat_history))

- # Save chat history using pickle for persistence
  def save_chat_history():
-     with open("chat_history.pkl", "wb") as file:
-         pickle.dump(chat_history, file)

  def load_chat_history():
      global chat_history
-     if os.path.exists("chat_history.pkl"):
-         with open("chat_history.pkl", "rb") as file:
-             chat_history = pickle.load(file)

- load_chat_history()
  if st.sidebar.button("Save Chat History"):
-     save_chat_history()
-     st.sidebar.success("Chat history saved permanently!")
 
  import pickle
  from groq import Groq
  from datasets import load_dataset
+ from transformers import AutoTokenizer, pipeline

  # Initialize Groq API
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

  # Load datasets
+ try:  # Handle potential dataset loading errors
+     healthcare_ds = load_dataset("harishnair04/mtsamples")
+     education_ds = load_dataset("ehovy/race", "all")
+     finance_ds = load_dataset("warwickai/financial_phrasebank_mirror")
+ except Exception as e:
+     st.error(f"Error loading datasets: {e}")
+     st.stop()  # Stop execution if datasets fail to load

+ # Load chat model and tokenizer (with error handling and cache)
+ try:
+     tokenizer = AutoTokenizer.from_pretrained("rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", cache_dir="./.cache")
+     chat_pipe = pipeline("text-generation", model="rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", tokenizer=tokenizer, cache_dir="./.cache")
+ except Exception as e:
+     st.error(f"Error loading chat model: {e}")
+     st.stop()

+ # FAISS Index Setup (Simplified)
+ index = faiss.IndexFlatL2(768)  # Adjust dimension if needed
  chat_history = []

  # Streamlit UI Setup
  st.set_page_config(page_title="AI Chatbot", layout="wide")
  st.title("🤖 AI Chatbot (Healthcare, Education & Finance)")

+ # ... (rest of your Streamlit UI code - sidebar, input, buttons)

  # Chat Interface
  user_input = st.text_input("💬 Ask me anything:", placeholder="Type your query here...")
  if st.button("Send"):
      if user_input:
+         # Determine dataset (Basic CAG)
          dataset = healthcare_ds if "health" in user_input.lower() else \
                    education_ds if "education" in user_input.lower() else \
+                   finance_ds if "finance" in user_input.lower() else None  # Handle no dataset match

+         if dataset is None:
+             st.warning("No relevant dataset found for your query.")
+             st.stop()
+
+         # RAG: Retrieve (Simplified)
+         retrieved_data = dataset['train'][0] if dataset and len(dataset['train']) > 0 else "No relevant data retrieved."  # Check dataset is not empty
+
+         try:
+             # Generate response (Groq)
+             chat_completion = client.chat.completions.create(
+                 messages=[{"role": "user", "content": f"{user_input} {retrieved_data}"}],
+                 model="llama-3.3-70b-versatile"  # Ensure model name is correct
+             )
+             response = chat_completion.choices[0].message.content
+         except Exception as e:
+             st.error(f"Error generating response: {e}")
+             response = "Error generating response."  # Provide default response in case of error
+
+         # Save and display
          chat_history.append(f"User: {user_input}\nBot: {response}")
          st.text_area("🤖 AI Response:", value=response, height=200)

+ # ... (rest of your Streamlit code - chat history display, save/load)

+ # Persistence functions (pickle)
  def save_chat_history():
+     try:
+         with open("chat_history.pkl", "wb") as file:
+             pickle.dump(chat_history, file)
+         st.sidebar.success("Chat history saved permanently!")
+     except Exception as e:
+         st.sidebar.error(f"Error saving chat history: {e}")

  def load_chat_history():
      global chat_history
+     try:
+         if os.path.exists("chat_history.pkl"):
+             with open("chat_history.pkl", "rb") as file:
+                 chat_history = pickle.load(file)
+     except Exception as e:
+         st.sidebar.warning(f"Error loading chat history (may be corrupted): {e}")

+ load_chat_history()  # Load on startup
  if st.sidebar.button("Save Chat History"):
+     save_chat_history()
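
The commit keeps `index = faiss.IndexFlatL2(768)` but never adds vectors to it or searches it, and the "RAG: Retrieve (Simplified)" step still returns `dataset['train'][0]`. Below is a minimal sketch, not part of this commit, of how the declared index could back real retrieval; it assumes the `sentence-transformers` package and its `all-mpnet-base-v2` model, whose 768-dimensional embeddings match the index dimension.

```python
# Sketch: populate and query the 768-dim FAISS index declared in app.py.
# Assumption: sentence-transformers is installed; "all-mpnet-base-v2"
# emits 768-dimensional embeddings, so it matches faiss.IndexFlatL2(768).
import faiss
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-mpnet-base-v2")
index = faiss.IndexFlatL2(768)
documents = []  # keeps each passage aligned with its row in the index

def add_documents(texts):
    """Embed a batch of passages (e.g. records from dataset['train']) and index them."""
    vectors = embedder.encode(texts, convert_to_numpy=True).astype("float32")
    index.add(vectors)
    documents.extend(texts)

def retrieve(query, k=3):
    """Return the k passages whose embeddings are closest to the query."""
    query_vec = embedder.encode([query], convert_to_numpy=True).astype("float32")
    _, ids = index.search(query_vec, k)
    return [documents[i] for i in ids[0] if i != -1]
```

Passages from the selected dataset could be indexed once at startup with `add_documents`, after which `retrieve(user_input)` would replace the hard-coded `dataset['train'][0]` lookup before the prompt is sent to Groq.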
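
Because Streamlit re-executes app.py from the top on every interaction, the `load_dataset` and `pipeline` calls added here still run on each rerun; `cache_dir` only avoids re-downloading files, the objects are rebuilt every time. A sketch of process-level caching with Streamlit's resource cache, assuming Streamlit ≥ 1.18 where `st.cache_resource` is available (the keyword-to-dataset mapping mirrors the commit's "Basic CAG" routing):

```python
# Sketch: load heavy resources once per process instead of on every rerun.
# st.cache_resource memoizes the return value across Streamlit reruns.
import streamlit as st
from datasets import load_dataset
from transformers import pipeline

@st.cache_resource
def get_datasets():
    # Same datasets as in the commit, keyed by the routing keyword.
    return {
        "health": load_dataset("harishnair04/mtsamples"),
        "education": load_dataset("ehovy/race", "all"),
        "finance": load_dataset("warwickai/financial_phrasebank_mirror"),
    }

@st.cache_resource
def get_chat_pipe():
    return pipeline("text-generation",
                    model="rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog")

datasets = get_datasets()
chat_pipe = get_chat_pipe()
```

With this in place, the routing step could reduce to `dataset = next((ds for key, ds in datasets.items() if key in user_input.lower()), None)`, keeping the commit's "no dataset match" warning path unchanged.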