ahmed792002 commited on
Commit
f1cb50f
·
verified ·
1 Parent(s): 3a4c449

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -7,6 +7,7 @@ from transformers import T5Tokenizer, T5ForConditionalGeneration
7
  # Load pre-trained model and tokenizer
8
  tokenizer = T5Tokenizer.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
9
  model = T5ForConditionalGeneration.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
 
10
 
11
  # Function to clean input text
12
  def clean_text(text):
@@ -17,14 +18,16 @@ def clean_text(text):
17
  return text
18
 
19
  # Chatbot function
20
- def chatbot(query):
21
  query = clean_text(query)
22
  input_ids = tokenizer(query, return_tensors="pt", max_length=256, truncation=True)
23
  inputs = {key: value.to(device) for key, value in input_ids.items()}
24
  outputs = model.generate(
25
  input_ids["input_ids"],
26
- max_length=1024,
27
  num_beams=5,
 
 
28
  early_stopping=True
29
  )
30
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -40,4 +43,4 @@ demo = gr.ChatInterface(
40
  )
41
 
42
  if __name__ == "__main__":
43
- demo.launch()
 
7
  # Load pre-trained model and tokenizer
8
  tokenizer = T5Tokenizer.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
9
  model = T5ForConditionalGeneration.from_pretrained("ahmed792002/Finetuning_T5_HealthCare_Chatbot")
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
12
  # Function to clean input text
13
  def clean_text(text):
 
18
  return text
19
 
20
  # Chatbot function
21
+ def chatbot(query, history, system_message):
22
  query = clean_text(query)
23
  input_ids = tokenizer(query, return_tensors="pt", max_length=256, truncation=True)
24
  inputs = {key: value.to(device) for key, value in input_ids.items()}
25
  outputs = model.generate(
26
  input_ids["input_ids"],
27
+ max_length=1024, # Adjust this as needed for your use case
28
  num_beams=5,
29
+ temperature=0.7, # Adjust this as needed
30
+ top_p=0.95, # Adjust this as needed
31
  early_stopping=True
32
  )
33
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
43
  )
44
 
45
  if __name__ == "__main__":
46
+ demo.launch(share=True) # Set `share=True` to create a public link