adi2606 commited on
Commit
bfeb2cf
·
verified ·
1 Parent(s): 4d5f240

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -45
app.py CHANGED
@@ -1,49 +1,69 @@
1
  import streamlit as st
 
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
-
5
- # Set up the device to use CPU only
6
- device = torch.device("cpu")
7
-
8
- # Load model and tokenizer, then move the model to the appropriate device
9
- model = AutoModelForCausalLM.from_pretrained("adi2606/MenstrualQA").to(device)
10
- tokenizer = AutoTokenizer.from_pretrained("adi2606/MenstrualQA")
11
-
12
- # Function to generate a response from the chatbot
13
- def generate_response(message: str, temperature: float = 0.4, repetition_penalty: float = 1.1) -> str:
14
- # Apply the chat template and convert to PyTorch tensors
15
- messages = [
16
- {"role": "system", "content": "You are a helpful assistant."},
17
- {"role": "user", "content": message}
18
- ]
19
- input_ids = tokenizer.apply_chat_template(
20
- messages, add_generation_prompt=True, return_tensors="pt"
21
- ).to(device)
22
-
23
- # Generate the response
24
- output = model.generate(
25
- input_ids,
26
- max_length=512,
27
- temperature=temperature,
28
- repetition_penalty=repetition_penalty,
29
- do_sample=True
30
- )
31
-
32
- # Decode the generated output
33
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
34
- return generated_text
35
-
36
- # Streamlit app layout
37
- st.title("Menstrual QA Chatbot")
38
- st.write("Ask any question related to menstrual health.")
39
-
40
- # User input
41
- user_input = st.text_input("You:", "")
42
-
43
- if st.button("Send"):
44
- if user_input:
45
- with st.spinner("Generating response..."):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  response = generate_response(user_input)
47
- st.write(f"Chatbot: {response}")
 
48
  else:
49
- st.write("Please enter a question.")
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
+ from PIL import Image
5
+ import base64
6
+
7
+
8
+ st.set_page_config(page_title="Menstrual Health Chatbot 💬", layout="centered")
9
+
10
+
11
+ @st.cache_resource
12
+ def load_model():
13
+ tokenizer = AutoTokenizer.from_pretrained("adi2606/Menstrual-Health-Awareness-Chatbot")
14
+ model = AutoModelForSeq2SeqLM.from_pretrained("adi2606/Menstrual-Health-Awareness-Chatbot").to("cpu")
15
+ return tokenizer, model
16
+
17
+ tokenizer, model = load_model()
18
+
19
+
20
+ def generate_response(input_text):
21
+ inputs = tokenizer(input_text, return_tensors="pt")
22
+ outputs = model.generate(**inputs, max_length=128)
23
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
+ return response
25
+
26
+
27
+ def set_background(image_path):
28
+ with open(image_path, "rb") as f:
29
+ encoded_string = base64.b64encode(f.read()).decode()
30
+ st.markdown(
31
+ f"""
32
+ <style>
33
+ .stApp {{
34
+ background-image: url("data:image/png;base64,{encoded_string}");
35
+ background-size: cover;
36
+ }}
37
+ </style>
38
+ """,
39
+ unsafe_allow_html=True
40
+ )
41
+
42
+
43
+ set_background("background.jpg")
44
+
45
+
46
+ banner = Image.open("banner.png")
47
+ st.image(banner, use_column_width=True)
48
+
49
+
50
+ st.markdown("<h1 style='text-align: center;'>🩸 Menstrual Health Awareness Chatbot 💬</h1>", unsafe_allow_html=True)
51
+ st.markdown("<h4 style='text-align: center;'>Ask anything about periods, PMS, hygiene, and more!</h4>", unsafe_allow_html=True)
52
+
53
+
54
+ st.markdown("### 🤔 Your Question")
55
+ user_input = st.text_input("", placeholder="E.g., What are the symptoms of PMS?")
56
+
57
+
58
+ if st.button("🚀 Ask"):
59
+ if user_input.strip():
60
+ with st.spinner("Generating a helpful response..."):
61
  response = generate_response(user_input)
62
+ st.success("✅ Here's what I found:")
63
+ st.markdown(f"**💬 Chatbot:** {response}")
64
  else:
65
+ st.warning("⚠️ Please enter a question to get started.")
66
+
67
+
68
+ st.markdown("---")
69
+ st.markdown("<small style='color: gray;'>Made with ❤️ to spread awareness.</small>", unsafe_allow_html=True)