Haseeb-001 commited on
Commit
578ac78
·
verified ·
1 Parent(s): f7f0da7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -22
app.py CHANGED
@@ -5,25 +5,38 @@ import pickle
5
  from groq import Groq
6
  from datasets import load_dataset
7
  from transformers import AutoTokenizer, pipeline
 
8
 
9
  # Initialize Groq API
10
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
11
 
12
- # Load datasets
13
- try: # Handle potential dataset loading errors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  healthcare_ds = load_dataset("harishnair04/mtsamples")
15
  education_ds = load_dataset("ehovy/race", "all")
16
  finance_ds = load_dataset("warwickai/financial_phrasebank_mirror")
17
  except Exception as e:
18
  st.error(f"Error loading datasets: {e}")
19
- st.stop() # Stop execution if datasets fail to load
20
-
21
- # Load chat model and tokenizer (with error handling and cache)
22
- try:
23
- tokenizer = AutoTokenizer.from_pretrained("rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", cache_dir="./.cache")
24
- chat_pipe = pipeline("text-generation", model="rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", tokenizer=tokenizer, cache_dir="./.cache")
25
- except Exception as e:
26
- st.error(f"Error loading chat model: {e}")
27
  st.stop()
28
 
29
  # FAISS Index Setup (Simplified)
@@ -40,28 +53,32 @@ st.title("🤖 AI Chatbot (Healthcare, Education & Finance)")
40
  user_input = st.text_input("💬 Ask me anything:", placeholder="Type your query here...")
41
  if st.button("Send"):
42
  if user_input:
43
- # Determine dataset (Basic CAG)
44
- dataset = healthcare_ds if "health" in user_input.lower() else \
45
- education_ds if "education" in user_input.lower() else \
46
- finance_ds if "finance" in user_input.lower() else None #Handle no dataset match
47
-
 
 
 
 
48
  if dataset is None:
49
- st.warning("No relevant dataset found for your query.")
50
  st.stop()
51
 
52
- # RAG: Retrieve (Simplified)
53
- retrieved_data = dataset['train'][0] if dataset and len(dataset['train']) > 0 else "No relevant data retrieved." #Check dataset is not empty
54
 
55
  try:
56
  # Generate response (Groq)
57
  chat_completion = client.chat.completions.create(
58
  messages=[{"role": "user", "content": f"{user_input} {retrieved_data}"}],
59
- model="llama-3.3-70b-versatile" #Ensure model name is correct
60
  )
61
  response = chat_completion.choices[0].message.content
62
  except Exception as e:
63
  st.error(f"Error generating response: {e}")
64
- response = "Error generating response." #Provide default response in case of error
65
 
66
  # Save and display
67
  chat_history.append(f"User: {user_input}\nBot: {response}")
@@ -87,7 +104,6 @@ def load_chat_history():
87
  except Exception as e:
88
  st.sidebar.warning(f"Error loading chat history (may be corrupted): {e}")
89
 
90
-
91
- load_chat_history() # Load on startup
92
  if st.sidebar.button("Save Chat History"):
93
  save_chat_history()
 
5
  from groq import Groq
6
  from datasets import load_dataset
7
  from transformers import AutoTokenizer, pipeline
8
+ import subprocess # For downloading if needed
9
 
10
  # Initialize Groq API
11
  client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
12
 
13
+ # Download model (if necessary - try requirements.txt first)
14
+ try:
15
+ # Try loading directly (after requirements.txt)
16
+ tokenizer = AutoTokenizer.from_pretrained("rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", cache_dir="./.cache")
17
+ chat_pipe = pipeline("text-generation", model="rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", tokenizer=tokenizer, cache_dir="./.cache")
18
+ print("Model loaded successfully (direct load).") # Check in logs
19
+ except Exception as e:
20
+ try:
21
+ # Fallback: Download using subprocess (less preferred)
22
+ print("Trying to download model...") # Check in logs
23
+ subprocess.run(["transformers-cli", "download", "rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog"], check=True) # Updated download command
24
+ tokenizer = AutoTokenizer.from_pretrained("rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", cache_dir="./.cache")
25
+ chat_pipe = pipeline("text-generation", model="rajkumarrrk/dialogpt-fine-tuned-on-daily-dialog", tokenizer=tokenizer, cache_dir="./.cache")
26
+ print("Model downloaded and loaded successfully (subprocess).") # Check in logs
27
+ except Exception as download_e:
28
+ st.error(f"Error loading/downloading chat model: {e}. Download error: {download_e}")
29
+ st.stop()
30
+
31
+
32
+
33
+ # Load datasets (with error handling)
34
+ try:
35
  healthcare_ds = load_dataset("harishnair04/mtsamples")
36
  education_ds = load_dataset("ehovy/race", "all")
37
  finance_ds = load_dataset("warwickai/financial_phrasebank_mirror")
38
  except Exception as e:
39
  st.error(f"Error loading datasets: {e}")
 
 
 
 
 
 
 
 
40
  st.stop()
41
 
42
  # FAISS Index Setup (Simplified)
 
53
  user_input = st.text_input("💬 Ask me anything:", placeholder="Type your query here...")
54
  if st.button("Send"):
55
  if user_input:
56
+ # Dataset Selection (Improved)
57
+ dataset = None
58
+ if "health" in user_input.lower():
59
+ dataset = healthcare_ds
60
+ elif "education" in user_input.lower():
61
+ dataset = education_ds
62
+ elif "finance" in user_input.lower():
63
+ dataset = finance_ds
64
+
65
  if dataset is None:
66
+ st.warning("No relevant dataset found for your query. Please use keywords like 'health', 'education', or 'finance'.")
67
  st.stop()
68
 
69
+ # RAG: Retrieve (Simplified and safer)
70
+ retrieved_data = dataset['train'][0]['text'] if dataset and len(dataset['train']) > 0 and 'text' in dataset['train'][0] else "No relevant data retrieved."
71
 
72
  try:
73
  # Generate response (Groq)
74
  chat_completion = client.chat.completions.create(
75
  messages=[{"role": "user", "content": f"{user_input} {retrieved_data}"}],
76
+ model="llama-3.3-70b-versatile"
77
  )
78
  response = chat_completion.choices[0].message.content
79
  except Exception as e:
80
  st.error(f"Error generating response: {e}")
81
+ response = "Error generating response."
82
 
83
  # Save and display
84
  chat_history.append(f"User: {user_input}\nBot: {response}")
 
104
  except Exception as e:
105
  st.sidebar.warning(f"Error loading chat history (may be corrupted): {e}")
106
 
107
+ load_chat_history()
 
108
  if st.sidebar.button("Save Chat History"):
109
  save_chat_history()