amaltese committed on
Commit
ffe6783
·
verified ·
1 Parent(s): f7717dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -14
app.py CHANGED
@@ -1,6 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def get_response(user_input):
2
- """Generate a thoughtful response that includes a follow-up question."""
3
- history = "\n".join(st.session_state.conversation[-5:]) # Keep only the last 5 turns
 
4
  prompt = (
5
  f"You are a knowledgeable study coach. Engage the student in conversation. "
6
  f"Ask open-ended questions to deepen understanding. Provide feedback and encourage explanations.\n\n"
@@ -8,19 +39,42 @@ def get_response(user_input):
8
  f"Student: {user_input}\n"
9
  f"Coach: "
10
  )
11
-
12
- # Tokenize input with padding and attention mask
13
- inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
14
- input_ids = inputs.input_ids.to(model.device)
15
- attention_mask = inputs.attention_mask.to(model.device)
16
-
17
  with torch.no_grad():
18
  output = model.generate(
19
- input_ids,
20
- attention_mask=attention_mask,
21
- max_length=300,
22
- pad_token_id=tokenizer.eos_token_id # Ensures correct token handling
 
 
23
  )
24
 
25
- response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
26
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
+ import torch
4
+
5
# Page header.
st.title("📚 Study Buddy Chatbot")
st.write("Ask a question or type a topic, and I'll help you learn interactively!")

# The transcript is kept in session state; create the empty list the first
# time this session runs the script.
if "conversation" not in st.session_state:
    st.session_state["conversation"] = []
11
+
12
# Model loading is wrapped in st.cache_resource so the loaded objects are reused.
@st.cache_resource
def load_model():
    """Load the tokenizer and causal language model used by the chatbot.

    The checkpoint is ``HuggingFaceH4/zephyr-7b-alpha``. Weights are requested
    in float16 with ``device_map="auto"`` and ``low_cpu_mem_usage=True``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "HuggingFaceH4/zephyr-7b-alpha"
    model_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "low_cpu_mem_usage": True,
    }

    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint, **model_options)
    return tok, lm
24
+
25
# Bind the model on every script run. Streamlit re-executes this file from the
# top on each interaction, so module-level names do not survive a rerun. The
# previous gate skipped the assignment once "model_loaded" was set, which left
# `tokenizer` and `model` undefined for get_response() on later runs.
# load_model() is wrapped in st.cache_resource, so repeat calls reuse the
# objects that are already loaded.
if "model_loaded" not in st.session_state:
    # First run of this session: loading can be slow, so show a spinner.
    with st.spinner("Loading AI model (this may take a minute)..."):
        tokenizer, model = load_model()
    st.session_state.model_loaded = True
else:
    tokenizer, model = load_model()
30
+
31
  def get_response(user_input):
32
+ # Format conversation history for context
33
+ history = "\n".join(st.session_state.conversation[-6:]) # Last 6 exchanges
34
+
35
  prompt = (
36
  f"You are a knowledgeable study coach. Engage the student in conversation. "
37
  f"Ask open-ended questions to deepen understanding. Provide feedback and encourage explanations.\n\n"
 
39
  f"Student: {user_input}\n"
40
  f"Coach: "
41
  )
42
+
43
+ # Better generation parameters
44
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
 
 
 
45
  with torch.no_grad():
46
  output = model.generate(
47
+ input_ids,
48
+ max_new_tokens=250,
49
+ temperature=0.7,
50
+ top_p=0.9,
51
+ do_sample=True,
52
+ repetition_penalty=1.2
53
  )
54
 
55
+ response = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
56
+ return response
57
+
58
# User interface
user_input = st.text_input("Type your question or topic:")

# The text box keeps its value across reruns, so without a guard every rerun
# (including the one triggered by the clear button below) would generate and
# append the same exchange again. Only respond when the submitted text differs
# from the last one that was answered.
if user_input and user_input != st.session_state.get("last_input"):
    with st.spinner("Thinking..."):
        response = get_response(user_input)

    # Add to conversation history
    st.session_state.conversation.append(f"Student: {user_input}")
    st.session_state.conversation.append(f"Coach: {response}")
    st.session_state.last_input = user_input

# Show the most recent messages (up to 10, i.e. five exchanges).
if st.session_state.conversation:
    st.subheader("Conversation History")
    for message in st.session_state.conversation[-10:]:
        # Pick the speaker from the stored prefix rather than list position,
        # and strip only that leading prefix: str.replace would also delete
        # "Student: " / "Coach: " wherever it appears inside the message text.
        if message.startswith("Student: "):
            st.markdown(f"**You**: {message[len('Student: '):]}")
        else:
            text = message[len("Coach: "):] if message.startswith("Coach: ") else message
            st.markdown(f"**Coach**: {text}")

# Add a clear conversation button
if st.button("Clear Conversation"):
    st.session_state.conversation = []
    # Prefer st.rerun, which replaces the deprecated st.experimental_rerun;
    # fall back to the old name on Streamlit versions that lack st.rerun.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()