Update app.py
Browse files
app.py
CHANGED
|
@@ -57,8 +57,6 @@ def generate_interactive(
|
|
| 57 |
):
|
| 58 |
inputs = tokenizer([prompt], padding=True, return_tensors='pt')
|
| 59 |
input_length = len(inputs['input_ids'][0])
|
| 60 |
-
for k, v in inputs.items():
|
| 61 |
-
inputs[k] = v.cuda()
|
| 62 |
input_ids = inputs['input_ids']
|
| 63 |
_, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
| 64 |
if generation_config is None:
|
|
@@ -184,7 +182,7 @@ def load_model():
|
|
| 184 |
model_dir = 'internlm/internlm2-chat-1_8b'
|
| 185 |
model = (AutoModelForCausalLM.from_pretrained(
|
| 186 |
model_dir,
|
| 187 |
-
trust_remote_code=True).to(torch.bfloat16)
|
| 188 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 189 |
model_dir,
|
| 190 |
trust_remote_code=True)
|
|
@@ -232,7 +230,6 @@ def combine_history(prompt):
|
|
| 232 |
|
| 233 |
|
| 234 |
def main():
|
| 235 |
-
# torch.cuda.empty_cache()
|
| 236 |
print('load model begin.')
|
| 237 |
model, tokenizer = load_model()
|
| 238 |
print('load model end.')
|
|
@@ -278,7 +275,7 @@ def main():
|
|
| 278 |
'role': 'robot',
|
| 279 |
'content': cur_response, # pylint: disable=undefined-loop-variable
|
| 280 |
})
|
| 281 |
-
|
| 282 |
|
| 283 |
|
| 284 |
if __name__ == '__main__':
|
|
|
|
| 57 |
):
|
| 58 |
inputs = tokenizer([prompt], padding=True, return_tensors='pt')
|
| 59 |
input_length = len(inputs['input_ids'][0])
|
|
|
|
|
|
|
| 60 |
input_ids = inputs['input_ids']
|
| 61 |
_, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
| 62 |
if generation_config is None:
|
|
|
|
| 182 |
model_dir = 'internlm/internlm2-chat-1_8b'
|
| 183 |
model = (AutoModelForCausalLM.from_pretrained(
|
| 184 |
model_dir,
|
| 185 |
+
trust_remote_code=True).to(torch.bfloat16))
|
| 186 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 187 |
model_dir,
|
| 188 |
trust_remote_code=True)
|
|
|
|
| 230 |
|
| 231 |
|
| 232 |
def main():
|
|
|
|
| 233 |
print('load model begin.')
|
| 234 |
model, tokenizer = load_model()
|
| 235 |
print('load model end.')
|
|
|
|
| 275 |
'role': 'robot',
|
| 276 |
'content': cur_response, # pylint: disable=undefined-loop-variable
|
| 277 |
})
|
| 278 |
+
|
| 279 |
|
| 280 |
|
| 281 |
if __name__ == '__main__':
|