File size: 4,931 Bytes
d17e1c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54c4de8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
import torch
import re, os, warnings
from langchain import PromptTemplate, LLMChain
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
warnings.filterwarnings("ignore")

# initialize and load PEFT model and tokenizer
def init_model_and_tokenizer(PEFT_MODEL):
    """Load the quantized base model, wrap it with the PEFT adapter, and build its tokenizer.

    Args:
        PEFT_MODEL: Adapter identifier handed to ``PeftConfig.from_pretrained``
            and ``PeftModel.from_pretrained``.

    Returns:
        A ``(peft_model, peft_tokenizer)`` tuple.
    """
    # The adapter config names the base checkpoint the adapter was trained on.
    adapter_config = PeftConfig.from_pretrained(PEFT_MODEL)
    base_checkpoint = adapter_config.base_model_name_or_path

    # 4-bit NF4 quantization with double quantization; compute runs in float16.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    base_model = AutoModelForCausalLM.from_pretrained(
        base_checkpoint,
        return_dict=True,
        quantization_config=quantization,
        device_map="auto",
        trust_remote_code=True,
    )
    adapted_model = PeftModel.from_pretrained(base_model, PEFT_MODEL)

    tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
    # Reuse the end-of-sequence token as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    return adapted_model, tokenizer

# custom LLM chain to generate answer from PEFT model for each query
def init_llm_chain(peft_model, peft_tokenizer):
    """Build an ``LLMChain`` that answers a query with the given PEFT model.

    Args:
        peft_model: Model whose ``generate`` method produces the answer tokens.
        peft_tokenizer: Tokenizer used to encode the prompt and decode the output.

    Returns:
        An ``LLMChain`` combining the prompt template below (single input
        variable ``query``) with a custom LLM wrapper around ``peft_model``.
    """
    class CustomLLM(LLM):
        """LangChain LLM adapter that delegates text generation to ``peft_model``."""

        def _call(self, prompt: str, stop=None, run_manager=None) -> str:
            # Send the inputs to wherever the model's weights actually live
            # rather than assuming "cuda:0"; the model is loaded elsewhere with
            # device_map="auto", so its placement is not fixed.
            device = next(peft_model.parameters()).device
            peft_encoding = peft_tokenizer(prompt, return_tensors="pt").to(device)

            # NOTE(review): temperature/top_p only influence decoding when
            # do_sample=True; sampling is left disabled here so the output stays
            # the same as before. Confirm whether sampling was intended.
            generation_config = GenerationConfig(
                max_new_tokens=128,
                pad_token_id=peft_tokenizer.eos_token_id,
                eos_token_id=peft_tokenizer.eos_token_id,
                temperature=0.4,
                top_p=0.6,
                repetition_penalty=1.3,
                num_return_sequences=1,
            )

            # attention_mask is a model input, not a generation setting, so it
            # is passed to generate() directly instead of via GenerationConfig.
            peft_outputs = peft_model.generate(
                input_ids=peft_encoding.input_ids,
                attention_mask=peft_encoding.attention_mask,
                generation_config=generation_config,
            )

            # The decoded text contains the prompt followed by the completion.
            peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)
            return peft_text_output

        @property
        def _llm_type(self) -> str:
            return "custom"

    llm = CustomLLM()

    # NOTE(review): the speaker labels in front of ':' appear to be empty;
    # confirm the template matches the format the model was fine-tuned on.
    template = """Answer the following question truthfully.
    If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
    If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.

    Example Format:
    : question here
    : answer here

    Begin!

    : {query}
    :"""

    prompt = PromptTemplate(template=template, input_variables=["query"])
    llm_chain = LLMChain(prompt=prompt, llm=llm)

    return llm_chain

def user(user_message, history):
    """Append the user's message to the chat history as an unanswered turn.

    Args:
        user_message: Text the user submitted.
        history: List of ``[user_text, bot_text]`` pairs; it is not mutated.

    Returns:
        A tuple of an empty string and a new history list that ends with
        ``[user_message, None]``.
    """
    pending_turn = [user_message, None]
    extended_history = history + [pending_turn]
    return "", extended_history

def bot(history):
    """Generate the reply for the latest user turn and store it in ``history``.

    When a previous turn exists, its question and answer are prepended to the
    new question so the model receives one turn of context.

    Args:
        history: List of ``[user_text, bot_text]`` pairs; the last pair holds
            the question that still needs an answer. Mutated in place.

    Returns:
        The same ``history`` list with the last pair's answer filled in.
    """
    if not history:
        # Nothing to answer; avoid indexing into an empty list.
        return history

    latest_question = history[-1][0]
    if len(history) >= 2:
        previous_question = history[-2][0]
        # The previous answer stays None if that turn's generation failed;
        # treat it as empty so one failure does not break every later turn.
        previous_answer = history[-2][1] or ""
        query = previous_question + "\n" + previous_answer + "\nHere, is the next QUESTION: " + latest_question
    else:
        query = latest_question

    # llm_chain is a module-level object created during UI setup.
    bot_message = llm_chain.run(query)
    bot_message = post_process_chat(bot_message, query)

    history[-1][1] = bot_message
    return history

def post_process_chat(bot_message, query):
    """Strip the echoed prompt from a model response and trim it to whole sentences.

    The text after the first occurrence of ``": {query}"`` is kept and then cut
    at its last period, dropping a trailing unfinished sentence.

    Args:
        bot_message: Raw text returned by the model (prompt plus completion).
        query: The query string that was substituted into the prompt.

    Returns:
        The trimmed answer. ``bot_message`` is returned unchanged when the
        ``": {query}"`` marker is absent or when the text after it contains
        no period.
    """
    marker = f": {query}"
    query_position = bot_message.find(marker)

    if query_position == -1:
        # Marker not found: there is nothing to strip, so return the message
        # as-is instead of reading a variable that was never assigned.
        return bot_message

    # Keep only what follows the echoed query.
    response_part = bot_message[query_position + len(marker):].strip()

    last_period_position = response_part.rfind(".")
    if last_period_position != -1:
        # Keep everything up to and including the last period.
        return response_part[:last_period_position + 1].strip()

    # No complete sentence after the marker: fall back to the original message.
    return bot_message

# Identifier of the PEFT adapter; init_model_and_tokenizer() passes it to
# PeftConfig.from_pretrained and PeftModel.from_pretrained.
model = "uzairsiddiqui/falcon-7b-sharded-bf16-finetuned-mental-health-conversational"
# Executed at import time: loads the model and tokenizer once, before the UI is built.
peft_model, peft_tokenizer = init_model_and_tokenizer(PEFT_MODEL = model)

# Build the Gradio UI: a header, a description, the chat window, a query box
# and a button that clears the chat.
with gr.Blocks() as demo:
    gr.HTML("""Welcome to Mental Health Conversational AI""")
    gr.Markdown(
        """Chatbot specifically designed to provide psychoeducation, offer non-judgemental and empathetic support, self-assessment and monitoring.
        Get instant response for any mental health related queries. If the chatbot seems you need external support, then it will respond appropriately."""
    )

    chatbot = gr.Chatbot()
    # Note: this `query` is the Textbox component, not the string named `query` inside bot().
    query = gr.Textbox(label="Type your query here, then press 'enter' and scroll up for response")
    clear = gr.Button(value="Clear Chat History!")

    # A `with` block does not open a new scope, so llm_chain is a module-level
    # name; bot() reads it as a global.
    llm_chain = init_llm_chain(peft_model, peft_tokenizer)

    # On submit: user() empties the textbox and appends the pending turn
    # (unqueued), then bot() fills in the answer for that turn.
    query.submit(user, [query, chatbot], [query, chatbot], queue=False).then(bot, chatbot, chatbot)
    # Clearing sets the chatbot component's value to None.
    clear.click(lambda: None, None, chatbot, queue=False)

# Launch the app with request queuing enabled when run as a script.
if __name__ == "__main__":
    demo.queue().launch()