Zeeshan42 committed on
Commit
bd41b38
·
verified ·
1 Parent(s): ed64e5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -0
app.py CHANGED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
2
+ from datasets import Dataset
3
+ from groq import Groq
4
+ import os
5
+
6
# Initialize the Groq client. The API key is read from the GROQ_API_KEY
# environment variable instead of being hard-coded: a key committed to source
# control is readable by anyone with access to the repository and must be
# treated as compromised (revoke the old key and issue a new one).
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Paths to the source books (title -> file path)
book_paths = {
    "DSM": "/content/Diagnostic and statistical manual of mental disorders _ DSM-5 ( PDFDrive.com ).pdf",
    "Personality": "/content/b6c3v8_Theories_of_Personality_10.pdf",
    "SearchForMeaning": "/content/Mans-Search-For-Meaning.pdf",
}
15
+
16
+ # Function to load and preprocess the data from books
17
def load_data(paths):
    """Build a paragraph-level ``Dataset`` from the files listed in *paths*.

    Each file is read as UTF-8 text (undecodable bytes are dropped), split on
    double newlines, and every chunk that is non-empty after stripping
    whitespace becomes one ``{"text": ...}`` record.

    Args:
        paths: Mapping of a title to a file path. Only the paths are used.

    Returns:
        ``Dataset.from_list`` applied to all collected records, in file order.
    """
    # NOTE(review): files are opened in text mode. If the paths really are
    # binary PDFs (the ".pdf" extensions suggest so), this yields raw PDF
    # bytes decoded as text rather than the book's prose -- confirm and
    # extract the text beforehand if needed.
    records = []
    for book_file in paths.values():
        with open(book_file, "r", encoding="utf-8", errors='ignore') as handle:
            contents = handle.read()
        stripped_chunks = (chunk.strip() for chunk in contents.split("\n\n"))
        records.extend({"text": chunk} for chunk in stripped_chunks if chunk)
    return Dataset.from_list(records)
27
+
28
# The dataset is intentionally not built here: this first section points at
# the "/content" paths, and `dataset` is assigned again from the "/app" paths
# further down. Loading here would only read every book a second time, or
# fail outright when the "/content" files are not present.
30
+
31
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
32
+ from datasets import Dataset
33
+ from groq import Groq
34
+ import os
35
+
36
# Initialize the Groq client. The API key is read from the GROQ_API_KEY
# environment variable instead of being hard-coded: a key committed to source
# control is readable by anyone with access to the repository and must be
# treated as compromised (revoke the old key and issue a new one).
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Paths to the source books (title -> file path)
book_paths = {
    "DSM": "/app/Diagnostic and statistical manual of mental disorders _ DSM-5 ( PDFDrive.com ).pdf",
    "Personality": "/app/b6c3v8_Theories_of_Personality_10.pdf",
    "SearchForMeaning": "/app/Mans-Search-For-Meaning.pdf",
}
45
+
46
+ # Function to load and preprocess the data from books
47
def load_data(paths):
    """Build a paragraph-level ``Dataset`` from the files listed in *paths*.

    Each file is read as UTF-8 text (undecodable bytes are dropped), split on
    double newlines, and every chunk that is non-empty after stripping
    whitespace becomes one ``{"text": ...}`` record.

    Args:
        paths: Mapping of a title to a file path. Only the paths are used.

    Returns:
        ``Dataset.from_list`` applied to all collected records, in file order.
    """
    # NOTE(review): files are opened in text mode. If the paths really are
    # binary PDFs (the ".pdf" extensions suggest so), this yields raw PDF
    # bytes decoded as text rather than the book's prose -- confirm and
    # extract the text beforehand if needed.
    records = []
    for book_file in paths.values():
        with open(book_file, "r", encoding="utf-8", errors='ignore') as handle:
            contents = handle.read()
        stripped_chunks = (chunk.strip() for chunk in contents.split("\n\n"))
        records.extend({"text": chunk} for chunk in stripped_chunks if chunk)
    return Dataset.from_list(records)
57
+
58
# Build the fine-tuning corpus from the book files
dataset = load_data(book_paths)

# Pretrained Hugging Face checkpoint to start from
# (replace with a larger model if needed and feasible)
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# The tokenizer has no padding token of its own, so reuse the
# end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)
69
+
70
# Tokenize data and create labels for causal language modeling
def tokenize_function(examples):
    """Tokenize ``examples["text"]`` and attach loss labels.

    The labels are a copy of ``input_ids`` in which every padding position is
    replaced by ``-100`` so that it is ignored by the loss. No manual shift is
    applied here; the causal-LM model is expected to shift labels internally.

    Padding positions are taken from ``attention_mask`` rather than by
    comparing token ids with ``tokenizer.pad_token_id``: the pad token is the
    same as the EOS token in this script, so an id comparison would also mask
    genuine end-of-sequence tokens.

    Args:
        examples: A single example or a batch (as passed by
            ``Dataset.map(..., batched=True)``) containing a ``"text"`` entry.

    Returns:
        The tokenizer output with an added ``"labels"`` entry that has the
        same nesting as ``"input_ids"``.
    """
    encodings = tokenizer(examples["text"], truncation=True, padding=True, max_length=512)

    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]

    def _mask_padding(ids, mask):
        # Keep real tokens, replace padded positions with the ignore index.
        return [token if keep else -100 for token, keep in zip(ids, mask)]

    if input_ids and isinstance(input_ids[0], (list, tuple)):
        # Batched call: one list of token ids per example.
        labels = [_mask_padding(ids, mask) for ids, mask in zip(input_ids, attention_mask)]
    else:
        # Single example: a flat list of token ids.
        labels = _mask_padding(input_ids, attention_mask)

    encodings["labels"] = labels
    return encodings
82
+
83
# Tokenize the whole corpus in batches
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Hold out 10% of the examples for evaluation
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset, eval_dataset = train_test_split["train"], train_test_split["test"]

# Hyperparameters for the fine-tuning run
training_args = TrainingArguments(
    output_dir="./results",           # where checkpoints and logs are written
    eval_strategy="epoch",            # evaluate once per epoch
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)

# Wire the model, data and settings together
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

# Run fine-tuning
trainer.train()

# Persist the fine-tuned weights together with the tokenizer
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")
116
+
117
+ # Step 4: Define response function with emergency keyword check
118
def get_response(user_input):
    """Return the Groq model's reply to *user_input*.

    The input is first scanned (case-insensitively, by substring) for a fixed
    list of distress keywords. The reply itself always comes from a single
    chat-completion request; when a keyword was found, an emergency notice is
    appended to it.

    Args:
        user_input: The user's message text.

    Returns:
        The reply text, with the emergency notice appended when a distress
        keyword was detected.
    """
    # Scan for distress keywords before contacting the API.
    flagged = False
    for keyword in ("hopeless", "emergency", "help", "crisis", "urgent"):
        if keyword in user_input.lower():
            flagged = True
            break

    # Ask the Groq-hosted model for a reply.
    completion = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[{"role": "user", "content": user_input}],
    )
    reply = completion.choices[0].message.content

    if not flagged:
        return reply

    # A distress keyword was present: add the emergency notice.
    return reply + "\n\nThis seems serious. Please consider reaching out to an emergency contact immediately. In case of an emergency, call [emergency number]."
135
+
136
+ # Step 5: Set up Gradio Interface
137
+ import gradio as gr
138
+
139
def chatbot_interface(input_text):
    """Gradio callback: pass *input_text* to ``get_response`` and return its reply."""
    reply = get_response(input_text)
    return reply
141
+
142
# Build the text-in / text-out Gradio UI and start serving it
demo = gr.Interface(
    fn=chatbot_interface,
    inputs="text",
    outputs="text",
    title="Virtual Psychiatrist Chatbot",
)
demo.launch()