ShieldX commited on
Commit
804edcf
·
verified ·
1 Parent(s): 2eb4b74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -34
app.py CHANGED
@@ -1,53 +1,68 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, GenerationConfig
 
4
 
5
  tokenizer = AutoTokenizer.from_pretrained("ShieldX/manovyadh-1.1B-v1")
6
  model = AutoModelForCausalLM.from_pretrained("ShieldX/manovyadh-1.1B-v1")
7
- model = model.to('cpu')
 
 
 
 
 
 
 
 
8
 
9
  title = "🌱 ManoVyadh 🌱"
10
  description = "Mental Health Counselling Chatbot"
11
  examples = ["I have been feeling more and more down for over a month. I have started having trouble sleeping due to panic attacks, but they are almost never triggered by something that I know of.", "I self-harm, and I stop for a while. Then when I see something sad or depressing, I automatically want to self-harm.", "I am feeling sad for my friend's divorce"]
12
 
 
 
 
 
 
 
 
 
13
  def predict(message, history):
14
- def formatted_prompt(question)-> str:
15
- sysp = "You are an AI assistant that helps people cope with stress and improve their mental health. User will tell you about their feelings and challenges. Your task is to listen empathetically and offer helpful suggestions. While responding, think about the user’s needs and goals and show compassion and support."
16
- return f"<|im_start|>system\n{sysp}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant:"
17
 
18
- # history_transformer_format = history + [[formatted_prompt(message), "", message]]
 
19
 
20
- # messages = "".join(["".join([f"\n user: {item[2]}, \n assistant: {item[1]}"]) #curr_system_message +
21
- # for item in history_transformer_format])
22
 
23
- messages = formatted_prompt(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- inputs = tokenizer([messages], return_tensors="pt").to("cpu")
 
 
 
 
 
 
 
 
26
 
27
- # streamer = TextStreamer(tokenizer)
28
-
29
- generation_config = GenerationConfig(
30
- penalty_alpha=0.6,
31
- early_stopping=True,
32
- num_beams=2,
33
- do_sample=True,
34
- top_k=5,
35
- temperature=0.7,
36
- repetition_penalty=1.2,
37
- max_new_tokens=64,
38
- eos_token_id=tokenizer.eos_token_id,
39
- pad_token_id=tokenizer.eos_token_id
40
- )
41
-
42
- outputs = model.generate(**inputs, generation_config=generation_config)
43
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
44
- return response[len(formatted_prompt(message)):]
45
-
46
- # partial_message = ""
47
- # for i in response:
48
- # if response!="" or i!="":
49
- # partial_message+=i
50
- # yield partial_message
51
 
52
  gr.ChatInterface(
53
  predict,
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
4
+ from threading import Thread
5
 
6
  tokenizer = AutoTokenizer.from_pretrained("ShieldX/manovyadh-1.1B-v1")
7
  model = AutoModelForCausalLM.from_pretrained("ShieldX/manovyadh-1.1B-v1")
8
+
9
+ # Check for GPU availability
10
+ if torch.cuda.is_available():
11
+ device = "cuda"
12
+ else:
13
+ device = "cpu"
14
+
15
+ # Move model and inputs to the GPU (if available)
16
+ model.to(device)
17
 
18
  title = "🌱 ManoVyadh 🌱"
19
  description = "Mental Health Counselling Chatbot"
20
  examples = ["I have been feeling more and more down for over a month. I have started having trouble sleeping due to panic attacks, but they are almost never triggered by something that I know of.", "I self-harm, and I stop for a while. Then when I see something sad or depressing, I automatically want to self-harm.", "I am feeling sad for my friend's divorce"]
21
 
22
+ class StopOnTokens(StoppingCriteria):
23
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
24
+ stop_ids = [29, 0]
25
+ for stop_id in stop_ids:
26
+ if input_ids[0][-1] == stop_id:
27
+ return True
28
+ return False
29
+
30
  def predict(message, history):
 
 
 
31
 
32
+ history_transformer_format = history + [[message, ""]]
33
+ stop = StopOnTokens()
34
 
35
+ sys_msg = """###SYSTEM: You are an AI assistant that helps people cope with stress and improve their mental health. User will tell you about their feelings and challenges. Your task is to listen empathetically and offer helpful suggestions. While responding, think about the user’s needs and goals and show compassion and support"""
 
36
 
37
+ messages = "".join(["".join([sys_msg + "\n###USER:"+item[0], "\n###ASSISTANT:"+item[1]]) #curr_system_message +
38
+ for item in history_transformer_format])
39
+
40
+ model_inputs = tokenizer([messages], return_tensors="pt").to(device)
41
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
42
+ generate_kwargs = dict(
43
+ model_inputs,
44
+ streamer=streamer,
45
+ max_new_tokens=256,
46
+ do_sample=True,
47
+ top_p=0.95,
48
+ top_k=1000,
49
+ temperature=1.0,
50
+ num_beams=1,
51
+ stopping_criteria=StoppingCriteriaList([stop])
52
+ )
53
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
54
+ t.start()
55
 
56
+ partial_message = ""
57
+ for new_token in streamer:
58
+ if new_token != '#':
59
+ partial_message += new_token
60
+ yield partial_message
61
+ else:
62
+ print("new token = #")
63
+ partial_message += new_token
64
+ yield partial_message
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  gr.ChatInterface(
68
  predict,