# Ultimate_96B_Chat / chat_bot.py
# Provenance: Hugging Face upload ("Upload 4 files", commit 17ffc4f) by sillynugget.
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Load the causal-LM and its tokenizer; run on GPU when one is available.
model_name = "cyberagent/open-calm-large"
# model_name = "cyberagent/open-calm-3b"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# ใƒใƒฃใƒƒใƒˆใฎ้–‹ๅง‹
print("ใƒœใƒƒใƒˆ: ใ“ใ‚“ใซใกใฏ! ไฝ•ใ‹่ณชๅ•ใŒใ‚ใ‚Šใพใ™ใ‹?")
while True:
# ใƒฆใƒผใ‚ถใ‹ใ‚‰ใฎๅ…ฅๅŠ›ใ‚’ๅ—ใ‘ๅ–ใ‚‹
user_input = input("ใ‚ใชใŸ: ")
# ใƒฆใƒผใ‚ถใฎๅ…ฅๅŠ›ใ‚’ใ‚จใƒณใ‚ณใƒผใƒ‰ใ—ใฆใƒ†ใƒณใ‚ฝใƒซใซๅค‰ๆ›
input_ids = tokenizer.encode(user_input, return_tensors='pt')
attention_mask = (input_ids != tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0).int()
# ๅ…ฅๅŠ›ใƒ‡ใƒผใ‚ฟใ‚’้ฉๅˆ‡ใชใƒ‡ใƒใ‚คใ‚นใซ้€ใ‚‹
input_ids = input_ids.to(model.device)
attention_mask = attention_mask.to(model.device)
# ใƒขใƒ‡ใƒซใซใ‚ˆใ‚‹ๅฟœ็ญ”ใฎ็”Ÿๆˆ
output = model.generate(input_ids, attention_mask=attention_mask, max_length=300, num_return_sequences=1, no_repeat_ngram_size=2, pad_token_id=model.config.eos_token_id)
# ็”Ÿๆˆใ•ใ‚ŒใŸใƒ†ใ‚ญใ‚นใƒˆใฎใƒ‡ใ‚ณใƒผใƒ‰
output_text = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
print("ใƒœใƒƒใƒˆ: " + output_text)