nafisneehal committed · Commit 76d6bf4 · verified · 1 Parent(s): 8312e83

Update app.py

Files changed (1)
  1. app.py +20 -18
app.py CHANGED
@@ -90,32 +90,34 @@ def generate_response(model_name, system_instruction, user_input):
     prompt = f"""### Instruction:
 {system_instruction}
 Remember to ALWAYS format your response as valid JSON.
-
 ### Input:
 {user_input}
-
 ### Response:
 {{"""  # Note the opening curly brace to hint JSON response
 
-    inputs = model_manager.current_tokenizer([prompt], return_tensors="pt").to(model_manager.device)
-
-    # Generation configuration optimized for JSON output
-    meta_config = {
-        "do_sample": False,
-        "temperature": 0.0,
-        "max_new_tokens": 512,
-        "repetition_penalty": 1.1,
-        "use_cache": True,
-        "pad_token_id": model_manager.current_tokenizer.eos_token_id,
-        "eos_token_id": model_manager.current_tokenizer.eos_token_id
-    }
-    generation_config = GenerationConfig(**meta_config)
-
-    # Generate response
     try:
+        # Ensure inputs are on the correct device
+        inputs = model_manager.current_tokenizer([prompt], return_tensors="pt")
+        # Move input_ids and attention_mask to the same device as the model
+        inputs = {k: v.to(model_manager.device) for k, v in inputs.items()}
+
+        # Generation configuration optimized for JSON output
+        meta_config = {
+            "do_sample": False,
+            "temperature": 0.0,
+            "max_new_tokens": 512,
+            "repetition_penalty": 1.2,
+            "use_cache": True,
+            "pad_token_id": model_manager.current_tokenizer.eos_token_id,
+            "eos_token_id": model_manager.current_tokenizer.eos_token_id
+        }
+        generation_config = GenerationConfig(**meta_config)
+
+        # Generate response
         with torch.no_grad():
             outputs = model_manager.current_model.generate(
-                **inputs,
+                input_ids=inputs['input_ids'],
+                attention_mask=inputs['attention_mask'],
                 generation_config=generation_config
             )
         decoded_output = model_manager.current_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
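
For context, here is a minimal, self-contained sketch of the pattern the updated app.py follows: tokenize, move each tensor to the model's device, then generate greedily with an explicit GenerationConfig. The "gpt2" checkpoint, the toy prompt, and max_new_tokens=64 are placeholders for illustration; the app's model_manager wrapper is not reproduced here.

# Sketch of the commit's tokenize-then-move-to-device pattern.
# Assumptions: plain transformers API, "gpt2" standing in for the
# model that model_manager actually loads.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

prompt = '### Instruction:\nReply in JSON.\n### Response:\n{'

# Tokenize first, then move input_ids and attention_mask to the
# model's device, mirroring the change in this commit.
inputs = tokenizer([prompt], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Greedy decoding; with do_sample=False, temperature has no effect,
# so it is omitted in this sketch.
generation_config = GenerationConfig(
    do_sample=False,
    max_new_tokens=64,
    repetition_penalty=1.2,
    use_cache=True,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        generation_config=generation_config,
    )
decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(decoded_output)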
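
The opening brace seeded at the end of the prompt ({{ renders as { in the f-string) means the model's completion continues a JSON object rather than starting one. A hypothetical helper, not part of this commit, could recover that object from decoded_output; the brace matching below is naive and ignores braces inside JSON string values.

import json

def extract_json(decoded_output: str) -> dict:
    # Hypothetical helper (assumption, not in app.py): take the text
    # after the "### Response:" marker; the seeded "{" is part of it.
    tail = decoded_output.split("### Response:", 1)[1]
    candidate = tail[tail.index("{"):]
    # Naive brace matching to trim trailing text after the object;
    # does not handle braces inside JSON string values.
    depth = 0
    for i, ch in enumerate(candidate):
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return json.loads(candidate[: i + 1])
    raise ValueError("no complete JSON object in model output")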