Izza-shahzad-13 committed
Commit 4363122 · verified · 1 parent: b89f69f

Update app.py

Files changed (1)
  1. app.py +48 -40
app.py CHANGED
@@ -1,49 +1,57 @@
 import streamlit as st
-from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
-# Load the fine-tuned FLAN-T5 model and tokenizer
-rag_tokenizer = RagTokenizer.from_pretrained("fine_tuned_flan_t5") # Your fine-tuned model path
-rag_retriever = RagRetriever.from_pretrained("facebook/rag-token-nq") # Pre-trained retriever
-rag_model = RagSequenceForGeneration.from_pretrained("Izza-shahzad-13/fine-tuned-flan-t5") # Your fine-tuned model
+# Hugging Face Token for Authentication
+HUGGINGFACE_TOKEN = "your_hugging_face_token_here" # Replace with your token
 
-# Setup device
+# Function to load model and tokenizer (local or Hugging Face with token)
+def load_model(model_path):
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_path, use_auth_token=HUGGINGFACE_TOKEN)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, use_auth_token=HUGGINGFACE_TOKEN)
+        return tokenizer, model
+    except Exception as e:
+        st.error(f"Error loading model: {e}")
+        return None, None
+
+# Set device (use GPU if available)
 device = "cuda" if torch.cuda.is_available() else "cpu"
-rag_model.to(device)
-
-# Function to generate RAG response
-def generate_rag_response(input_text):
-    # Tokenize the input
-    inputs = rag_tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
-
-    # Retrieve relevant documents based on input
-    input_ids = inputs['input_ids'].to(device)
-    retrieved_docs = rag_retriever(input_ids)
-
-    # Generate the response from the retrieved context
-    generated_ids = rag_model.generate(
-        input_ids=input_ids,
-        context_input_ids=retrieved_docs['context_input_ids'].to(device),
-        context_attention_mask=retrieved_docs['context_attention_mask'].to(device),
-        max_length=200,
-        num_beams=4,
-        top_p=0.9,
-        top_k=50,
-        temperature=0.7,
-        no_repeat_ngram_size=3,
-        early_stopping=True
-    )
-
-    # Decode the generated response
-    return rag_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-
-# Streamlit UI setup
-st.title("Mental Health Counseling Assistant with RAG")
-
-# Input for user query
+
+# Path to your model (either a local path or a Hugging Face model name)
+model_path = "Izza-shahzad-13/fine-tuned-flan-t5" # Use your Hugging Face model identifier
+
+# Load tokenizer and model
+tokenizer, model = load_model(model_path)
+if model:
+    model.to(device)
+
+# Function to generate response from the model
+def generate_response(input_text):
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
+    with torch.no_grad():
+        output = model.generate(
+            inputs['input_ids'],
+            max_length=500,
+            num_beams=4,
+            top_p=0.9,
+            top_k=50,
+            temperature=0.7,
+            do_sample=True,
+            no_repeat_ngram_size=3,
+            early_stopping=True
+        )
+    response = tokenizer.decode(output[0], skip_special_tokens=True)
+    return response
+
+# Streamlit app interface
+st.title("FLAN-T5 Mental Health Counseling Assistant")
+st.write("Type your thoughts or feelings, and let the model respond.")
+
+# User input for interaction
 user_input = st.text_input("How are you feeling today?")
 
+# Generate and display model response when input is provided
 if user_input:
-    # Generate and display the response using the RAG model
-    response = generate_rag_response(user_input)
+    response = generate_response(user_input)
     st.write("Model Response:", response)
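
Note on the new loading code: hardcoding HUGGINGFACE_TOKEN in app.py exposes the token to anyone who can read the repo, and use_auth_token is deprecated in recent transformers releases in favor of token. Streamlit also re-runs the whole script on every interaction, so the uncached load_model reloads the model each time. A minimal sketch of an alternative, assuming the token is stored in Streamlit secrets under a hypothetical HF_TOKEN key:

    import streamlit as st
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    @st.cache_resource  # cache across Streamlit reruns: the model loads once per process
    def load_model(model_path):
        # Hypothetical secret name; set it in .streamlit/secrets.toml or the Space's settings
        token = st.secrets.get("HF_TOKEN")
        tokenizer = AutoTokenizer.from_pretrained(model_path, token=token)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, token=token)
        return tokenizer, model

On Streamlit versions before 1.18, st.experimental_singleton played the same caching role.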
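Note on the generate call: in transformers, temperature, top_p, and top_k only take effect when do_sample=True, so the added do_sample=True is what actually activates those settings (the removed RAG version passed them with sampling disabled). Combining num_beams=4 with do_sample=True selects beam-search multinomial sampling; for a chat-style app, plain nucleus sampling is the more common choice, and passing the attention mask avoids pad-token ambiguity. A sketch of that variant of generate_response, under those assumptions, not the committed behavior:

    def generate_response(input_text):
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(device)
        with torch.no_grad():
            output = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],  # explicit mask for padded inputs
                max_new_tokens=200,   # bounds only newly generated tokens, unlike max_length
                do_sample=True,       # nucleus sampling, no beam search
                top_p=0.9,
                temperature=0.7,
                no_repeat_ngram_size=3,
            )
        return tokenizer.decode(output[0], skip_special_tokens=True)

early_stopping applies only to beam search, so it is dropped here as well.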