ahmed792002 commited on
Commit
3a4c449
·
verified ·
1 Parent(s): 93ff22f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -4,18 +4,22 @@ import re
4
  import torch
5
  from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
 
7
  tokenizer = T5Tokenizer.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
8
  model = T5ForConditionalGeneration.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
9
 
 
10
  def clean_text(text):
11
  text = re.sub(r'\r\n', ' ', text) # Remove carriage returns and line breaks
12
  text = re.sub(r'\s+', ' ', text) # Remove extra spaces
13
  text = re.sub(r'<.*?>', '', text) # Remove any XML tags
14
  text = text.strip().lower() # Strip and convert to lower case
15
  return text
 
 
16
  def chatbot(query):
17
  query = clean_text(query)
18
- input_ids = tokenizer(query,return_tensors="pt",max_length=256,truncation=True)
19
  inputs = {key: value.to(device) for key, value in input_ids.items()}
20
  outputs = model.generate(
21
  input_ids["input_ids"],
@@ -24,7 +28,7 @@ def chatbot(query):
24
  early_stopping=True
25
  )
26
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
27
-
28
  """
29
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
30
  """
@@ -35,6 +39,5 @@ demo = gr.ChatInterface(
35
  ],
36
  )
37
 
38
-
39
  if __name__ == "__main__":
40
  demo.launch()
 
4
  import torch
5
  from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
7
+ # Load pre-trained model and tokenizer
8
  tokenizer = T5Tokenizer.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
9
  model = T5ForConditionalGeneration.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
10
 
11
+ # Function to clean input text
12
  def clean_text(text):
13
  text = re.sub(r'\r\n', ' ', text) # Remove carriage returns and line breaks
14
  text = re.sub(r'\s+', ' ', text) # Remove extra spaces
15
  text = re.sub(r'<.*?>', '', text) # Remove any XML tags
16
  text = text.strip().lower() # Strip and convert to lower case
17
  return text
18
+
19
+ # Chatbot function
20
  def chatbot(query):
21
  query = clean_text(query)
22
+ input_ids = tokenizer(query, return_tensors="pt", max_length=256, truncation=True)
23
  inputs = {key: value.to(device) for key, value in input_ids.items()}
24
  outputs = model.generate(
25
  input_ids["input_ids"],
 
28
  early_stopping=True
29
  )
30
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
31
+
32
  """
33
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
34
  """
 
39
  ],
40
  )
41
 
 
42
  if __name__ == "__main__":
43
  demo.launch()