adi2606 committed on
Commit
dbb3b1f
·
verified ·
1 Parent(s): 22209ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -11
app.py CHANGED
@@ -10,24 +10,23 @@ model = AutoModelForCausalLM.from_pretrained("adi2606/MenstrualQA").to(device)
10
  tokenizer = AutoTokenizer.from_pretrained("adi2606/MenstrualQA")
11
 
12
  # Function to generate a response from the chatbot
13
- def generate_response(message: str, temperature: float = 0.4, repetition_penalty: float = 1.1, max_input_length: int = 256) -> str:
14
- inputs = tokenizer(
15
- message,
16
- return_tensors="pt",
17
- padding=True,
18
- truncation=True,
19
- max_length=max_input_length
 
20
  ).to(device)
21
 
22
  # Generate the response
23
  output = model.generate(
24
- inputs['input_ids'],
25
- attention_mask=inputs['attention_mask'],
26
  max_length=512,
27
  temperature=temperature,
28
  repetition_penalty=repetition_penalty,
29
- do_sample=True,
30
- pad_token_id=tokenizer.eos_token_id
31
  )
32
 
33
  # Decode the generated output
 
10
  tokenizer = AutoTokenizer.from_pretrained("adi2606/MenstrualQA")
11
 
12
  # Function to generate a response from the chatbot
13
+ def generate_response(message: str, temperature: float = 0.4, repetition_penalty: float = 1.1) -> str:
14
+ # Apply the chat template and convert to PyTorch tensors
15
+ messages = [
16
+ {"role": "system", "content": "You are a helpful assistant."},
17
+ {"role": "user", "content": message}
18
+ ]
19
+ input_ids = tokenizer.apply_chat_template(
20
+ messages, add_generation_prompt=True, return_tensors="pt"
21
  ).to(device)
22
 
23
  # Generate the response
24
  output = model.generate(
25
+ input_ids,
 
26
  max_length=512,
27
  temperature=temperature,
28
  repetition_penalty=repetition_penalty,
29
+ do_sample=True
 
30
  )
31
 
32
  # Decode the generated output