Mattral committed
Commit 36eb6b1 · verified · 1 Parent(s): ff96349

Update app.py

Files changed (1): app.py (+43 -24)
app.py CHANGED
@@ -1,18 +1,32 @@
+
 import gradio as gr
 from typing import Iterator, List, Tuple
-
-# Mock model function to simulate response generation
-def mock_run(
-    message: str,
-    chat_history: List[Tuple[str, str]],
-    system_prompt: str,
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
-    top_k: int,
-) -> Iterator[str]:
-    response = f"Mock response to: {message}"
-    yield response
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftConfig, PeftModel
+
+base_model = "mistralai/Mistral-7B-Instruct-v0.2"
+adapter = "GRMenon/mental-health-mistral-7b-instructv0.2-finetuned-V2"
+
+# Load tokenizer
+tokenizer = AutoTokenizer.from_pretrained(
+    base_model,
+    add_bos_token=True,
+    trust_remote_code=True,
+    padding_side='left'
+)
+
+# Create PEFT model using base_model and the fine-tuned adapter
+config = PeftConfig.from_pretrained(adapter)
+model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,
+                                             load_in_4bit=True,
+                                             device_map='auto',
+                                             torch_dtype='auto')
+model = PeftModel.from_pretrained(model, adapter)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
+model.eval()
 
 DEFAULT_SYSTEM_PROMPT = "You are Phoenix AI Healthcare. You are professional, you are polite, give only truthful information and are based on the Mistral-7B model from Mistral AI about Healthcare and Wellness. You can communicate in different languages equally well."
 
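A note on the model-loading hunk above: `load_in_4bit=True` only works when the bitsandbytes package is installed (and accelerate, for `device_map='auto'`), and newer transformers releases expect quantization options to be wrapped in a `BitsAndBytesConfig`. A 4-bit model dispatched with `device_map='auto'` is also already placed on the GPU, so the later `model.to(device)` is redundant and can raise an error on recent versions. A minimal sketch of an equivalent load under those assumptions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base_model = "mistralai/Mistral-7B-Instruct-v0.2"
adapter = "GRMenon/mental-health-mistral-7b-instructv0.2-finetuned-V2"

# 4-bit quantization config; assumes bitsandbytes and accelerate are installed.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: bf16 compute on a modern GPU
)

tokenizer = AutoTokenizer.from_pretrained(base_model, add_bos_token=True, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",  # accelerate places the weights; no model.to(device) needed afterwards
)
model = PeftModel.from_pretrained(model, adapter)
model.eval()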
@@ -22,7 +36,7 @@ MAX_INPUT_TOKEN_LENGTH = 4000
 
 DESCRIPTION = """
 # Simple Healthcare Chatbot
-### Powered by a mock model
+### Powered by Mistral-7B with Healthcare Fine-Tuning
 """
 
 def clear_and_save_textbox(message: str) -> tuple[str, str]:
@@ -52,21 +66,26 @@ def generate(
         raise ValueError("Max new tokens exceeded")
 
     history = history_with_input[:-1]
-    generator = mock_run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
-    try:
-        first_response = next(generator)
-        yield history + [(message, first_response)]
-    except StopIteration:
-        yield history + [(message, "")]
-    for response in generator:
-        yield history + [(message, response)]
+    conversation = [{"role": "system", "content": system_prompt}] + \
+                   [{"role": "user", "content": user_input} for user_input, _ in history] + \
+                   [{"role": "user", "content": message}]
+    input_ids = tokenizer.apply_chat_template(conversation=conversation,
+                                              tokenize=True,
+                                              add_generation_prompt=True,
+                                              return_tensors='pt').to(device)
+    output_ids = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens,
+                                do_sample=True, pad_token_id=tokenizer.pad_token_id)
+    response = tokenizer.batch_decode(output_ids.detach().cpu().numpy(), skip_special_tokens=True)
+    response_text = response[0]
+
+    yield history + [(message, response_text)]
 
 def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
-    input_token_length = len(message) + sum(len(msg) for msg, _ in chat_history)
+    input_token_length = len(tokenizer.encode(message)) + sum(len(tokenizer.encode(msg)) for msg, _ in chat_history)
     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
         raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
 
-with gr.Blocks() as demo:
+with gr.Blocks(css="./styles/style.css") as demo:  # Link to CSS file
     gr.Markdown(DESCRIPTION)
     gr.Button("Duplicate Space for private use", elem_id="duplicate-button")
 
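Two review notes on the rewritten generate(): the temperature, top_p, and top_k parameters are still accepted but never forwarded to model.generate (with do_sample=True the model's default generation config applies), and the function now yields a single final response, so the Gradio UI no longer streams tokens. A sketch of a streaming variant built on transformers' TextIteratorStreamer, assuming the tokenizer, model, and device globals from this diff; stream_generate is a hypothetical name, not part of the commit:

from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(message, history, system_prompt, max_new_tokens,
                    temperature, top_p, top_k):
    # Caution: Mistral-Instruct chat templates may reject a "system" role and
    # expect strictly alternating user/assistant turns; as in the diff, only
    # the user side of history is kept here, which may need adjusting.
    conversation = ([{"role": "system", "content": system_prompt}]
                    + [{"role": "user", "content": u} for u, _ in history]
                    + [{"role": "user", "content": message}])
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids, streamer=streamer, max_new_tokens=max_new_tokens,
        do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
    )
    Thread(target=model.generate, kwargs=generate_kwargs).start()  # generate off-thread

    partial = ""
    for new_text in streamer:  # yields decoded text chunks as tokens arrive
        partial += new_text
        yield history + [(message, partial)]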
 
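Finally, the updated check_input_token_length counts real tokens per message, but it still ignores the system_prompt argument it receives and the chat-template overhead added at generation time, so it undercounts the prompt the model actually sees. A sketch of a stricter check, assuming the same module-level tokenizer and MAX_INPUT_TOKEN_LENGTH:

def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    # Measure the exact templated prompt, system prompt included.
    conversation = ([{"role": "system", "content": system_prompt}]
                    + [{"role": "user", "content": u} for u, _ in chat_history]
                    + [{"role": "user", "content": message}])
    input_token_length = len(
        tokenizer.apply_chat_template(conversation, add_generation_prompt=True)
    )
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(
            f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). "
            "Clear your chat history and try again."
        )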