ahmed792002 commited on
Commit
915193a
·
verified ·
1 Parent(s): f1cb50f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -25
app.py CHANGED
@@ -1,45 +1,55 @@
1
- import sentencepiece
2
  import gradio as gr
3
- import re
4
  import torch
5
- from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
7
  # Load pre-trained model and tokenizer
8
- tokenizer = T5Tokenizer.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
9
- model = T5ForConditionalGeneration.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
11
 
12
  # Function to clean input text
13
  def clean_text(text):
14
- text = re.sub(r'\r\n', ' ', text) # Remove carriage returns and line breaks
15
- text = re.sub(r'\s+', ' ', text) # Remove extra spaces
16
- text = re.sub(r'<.*?>', '', text) # Remove any XML tags
17
- text = text.strip().lower() # Strip and convert to lower case
18
- return text
19
 
20
  # Chatbot function
21
- def chatbot(query, history, system_message):
 
 
 
 
22
  query = clean_text(query)
23
- input_ids = tokenizer(query, return_tensors="pt", max_length=256, truncation=True)
24
- inputs = {key: value.to(device) for key, value in input_ids.items()}
25
- outputs = model.generate(
26
- input_ids["input_ids"],
27
- max_length=1024, # Adjust this as needed for your use case
28
- num_beams=5,
29
- temperature=0.7, # Adjust this as needed
30
- top_p=0.95, # Adjust this as needed
31
- early_stopping=True
 
 
 
 
32
  )
33
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
34
 
35
- """
36
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
37
- """
38
  demo = gr.ChatInterface(
39
  chatbot,
40
  additional_inputs=[
41
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
 
 
 
 
42
  ],
 
 
43
  )
44
 
45
  if __name__ == "__main__":
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import gradio as gr
 
3
  import torch
 
4
 
5
  # Load pre-trained model and tokenizer
6
+ tokenizer = AutoTokenizer.from_pretrained("ahmed792002/alzheimers_memory_support_ai")
7
+ model = AutoModelForCausalLM.from_pretrained("ahmed792002/alzheimers_memory_support_ai")
8
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ model.to(device) # Send the model to the correct device
10
 
11
  # Function to clean input text
12
  def clean_text(text):
13
+ return text.strip() # Simply remove leading/trailing spaces
 
 
 
 
14
 
15
  # Chatbot function
16
+ def chatbot(query, history, system_message, max_length, temperature, top_k, top_p):
17
+ """
18
+ Processes a user query through the specified model to generate a response.
19
+ """
20
+ # Clean the input query
21
  query = clean_text(query)
22
+
23
+ # Tokenize input query
24
+ input_ids = tokenizer.encode(query, return_tensors="pt").to(device)
25
+
26
+ # Generate text using the model
27
+ final_outputs = model.generate(
28
+ input_ids,
29
+ do_sample=True,
30
+ max_length=int(max_length), # Convert max_length to integer
31
+ temperature=float(temperature), # Convert temperature to float
32
+ top_k=int(top_k), # Convert top_k to integer
33
+ top_p=float(top_p), # Convert top_p to float
34
+ pad_token_id=tokenizer.pad_token_id,
35
  )
36
+
37
+ # Decode generated text
38
+ response = tokenizer.decode(final_outputs[0], skip_special_tokens=True)
39
+ return response
40
 
41
+ # Gradio ChatInterface
 
 
42
  demo = gr.ChatInterface(
43
  chatbot,
44
  additional_inputs=[
45
+ gr.Textbox(value="You are a friendly chatbot.", label="System message"),
46
+ gr.Slider(128, 1024, value=256, step=64, label="Max Length"), # Slider for max_length
47
+ gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature"), # Slider for temperature
48
+ gr.Slider(1, 100, value=50, step=1, label="Top-K"), # Slider for top_k
49
+ gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P"), # Slider for top_p
50
  ],
51
+ title="Custom Alzheimer's Memory Support AI",
52
+ description="This chatbot uses the fine-tuned model 'ahmed792002/alzheimers_memory_support_ai'. Customize settings like max length, temperature, top-k, and top-p for better results.",
53
  )
54
 
55
  if __name__ == "__main__":