OpenCALM-LARGE
Model Description
OpenCALM is a suite of decoder-only language models pre-trained on Japanese datasets, developed by CyberAgent, Inc.
This model is a LoRA fine-tune of open-calm-large, trained with peft.
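LoRA keeps the pretrained weights frozen and learns a small low-rank correction for selected layers. That is why the adapter is published separately from the base model. The NumPy sketch below illustrates the idea on a single linear layer, using the alpha / r scaling that peft applies by default. It is not the configuration used to train this adapter, which this card does not state. The function name `lora_forward` and all shapes and values are hypothetical.

```python
import numpy as np

def lora_forward(x, W, A, B, alpha):
    """Frozen linear layer W plus a low-rank LoRA update (hypothetical helper).

    x: (batch, d_in), W: (d_out, d_in) frozen base weight,
    A: (r, d_in) and B: (d_out, r) trainable low-rank factors.
    The effective weight is W + (alpha / r) * B @ A.
    """
    r = A.shape[0]
    scaling = alpha / r
    base = x @ W.T                # output of the frozen pretrained layer
    update = (x @ A.T) @ B.T      # low-rank path: d_in -> r -> d_out
    return base + scaling * update

rng = np.random.default_rng(0)
d_in, d_out, r = 16, 8, 4
W = rng.normal(size=(d_out, d_in))
A = rng.normal(size=(r, d_in)) * 0.01
B = np.zeros((d_out, r))          # B starts at zero, so training starts from the base model
x = rng.normal(size=(2, d_in))

y = lora_forward(x, W, A, B, alpha=8)
print(np.allclose(y, x @ W.T))    # True: zero-initialised B leaves the base output unchanged
```

Because B starts at zero, the adapted layer initially reproduces the base model exactly. Training then updates only the small A and B matrices, which is what keeps the adapter lightweight.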
Usage
Install PyTorch, transformers and peft (pip install torch transformers peft), then run the code below.
The code below is based on npaka's article (https://note.com/npaka/n/na5b8e6f749ce). Many thanks.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "cyberagent/open-calm-large"
lora_weights = "Mizuiro-sakura/open-calm-large-finetuned-databricks-dolly"

# Load the base model
model = AutoModelForCausalLM.from_pretrained(model_name)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Attach the LoRA adapter to the base model
model = PeftModel.from_pretrained(
    model,
    lora_weights,
    adapter_name=lora_weights,
)

# Switch to evaluation mode
model.eval()

# Build the prompt from the instruction template
def generate_prompt(data_point):
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{data_point["instruction"]}
### Input:
{data_point["input"]}
### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{data_point["instruction"]}
### Response:"""

# Generate a response for an instruction (with optional input) and print it
def generate(instruction, input=None, max_tokens=256):
    # Inference
    prompt = generate_prompt({"instruction": instruction, "input": input})
    input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.75,
        top_k=40,
        no_repeat_ngram_size=2,
    )
    outputs = outputs[0].tolist()

    # Stop decoding at the first EOS token
    if tokenizer.eos_token_id in outputs:
        eos_index = outputs.index(tokenizer.eos_token_id)
    else:
        eos_index = len(outputs)
    decoded = tokenizer.decode(outputs[:eos_index])

    # Keep only the text after the response marker
    sentinel = "### Response:"
    sentinel_loc = decoded.find(sentinel)
    if sentinel_loc >= 0:
        print(decoded[sentinel_loc + len(sentinel):])
    else:
        print("Warning: Expected prompt template to be emitted. Ignoring output.")

# Example prompt: "What is natural language processing?"
generate("自然言語処理とは?")
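The generate call above samples with temperature=0.7, top_k=40 and top_p=0.75. The NumPy sketch below shows what those three settings do to the next-token distribution at a single decoding step. It applies them in the order temperature, then top-k, then top-p. It is a simplified illustration, not the transformers implementation, and edge cases such as ties may be handled differently there. `sample_next_token` is a hypothetical helper.

```python
import numpy as np

def sample_next_token(logits, temperature=0.7, top_k=40, top_p=0.75, rng=None):
    """Sketch of temperature + top-k + top-p (nucleus) sampling for one step."""
    if rng is None:
        rng = np.random.default_rng()
    # Temperature < 1 sharpens the distribution, > 1 flattens it.
    logits = np.asarray(logits, dtype=np.float64) / temperature

    # Top-k: mask out everything below the k-th highest logit.
    k = min(top_k, logits.size)
    kth_best = np.sort(logits)[-k]
    logits = np.where(logits < kth_best, -np.inf, logits)

    # Softmax over the surviving tokens (masked tokens get probability 0).
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()

    # Top-p: keep the smallest set of most likely tokens whose total probability reaches top_p.
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    cutoff = np.searchsorted(cumulative, top_p) + 1  # number of tokens to keep
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    filtered /= filtered.sum()

    # Draw one token id from the filtered, renormalised distribution.
    return int(rng.choice(logits.size, p=filtered)), filtered

token_id, dist = sample_next_token([2.0, 1.0, 0.5, -1.0], rng=np.random.default_rng(0))
print(token_id, dist.round(3))
```

Lowering temperature, top_k or top_p makes the output more conservative. The no_repeat_ngram_size=2 setting is a separate constraint that this sketch does not model.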
Model Details
Model | Params | Layers | Dim | Heads | Dev ppl |
---|---|---|---|---|---|
cyberagent/open-calm-small | 160M | 12 | 768 | 12 | 19.7 |
cyberagent/open-calm-medium | 400M | 24 | 1024 | 16 | 13.8 |
cyberagent/open-calm-large | 830M | 24 | 1536 | 16 | 11.3 |
cyberagent/open-calm-1b | 1.4B | 24 | 2048 | 16 | 10.3 |
cyberagent/open-calm-3b | 2.7B | 32 | 2560 | 32 | 9.7 |
cyberagent/open-calm-7b | 6.8B | 32 | 4096 | 32 | 8.2 |
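The Params and Dev ppl columns can be sanity-checked with simple arithmetic. The sketch below assumes Dev ppl is the usual per-token perplexity: exp of the mean negative log-likelihood on a held-out set. It also estimates parameter counts with the common rule of thumb of 12 · layers · dim² for the transformer blocks, which assumes a 4x MLP expansion, plus untied input and output embeddings. The vocabulary size of 52,000 is an assumption, not a value from this card. Both helpers are hypothetical.

```python
import math

def perplexity(token_nlls):
    """Perplexity = exp(mean negative log-likelihood per token, in nats)."""
    return math.exp(sum(token_nlls) / len(token_nlls))

def approx_params(layers, dim, vocab_size=52_000):
    """Rough decoder-only parameter count (rule of thumb, not an exact count).

    Per layer: ~4*dim^2 for attention (Q, K, V, output) + ~8*dim^2 for a 4x MLP.
    Plus separate input and output embedding matrices of size vocab_size x dim.
    vocab_size=52_000 is an assumed value.
    """
    blocks = 12 * layers * dim ** 2
    embeddings = 2 * vocab_size * dim
    return blocks + embeddings

# A model that assigns probability 1/2 to every token has perplexity 2.
print(perplexity([math.log(2)] * 4))

for name, layers, dim, listed in [
    ("open-calm-small", 12, 768, 160e6),
    ("open-calm-large", 24, 1536, 830e6),
    ("open-calm-7b", 32, 4096, 6.8e9),
]:
    estimate = approx_params(layers, dim)
    print(f"{name}: estimate {estimate / 1e6:,.0f}M vs listed {listed / 1e6:,.0f}M")
```

Under these assumptions, the estimates fall within about 3% of the Params column for the three sizes checked.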
- Developed by: CyberAgent, Inc.
- Model type: Transformer-based Language Model
- Language: Japanese
- Library: GPT-NeoX
- License: OpenCALM is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License (CC BY-SA 4.0). When using this model, please provide appropriate credit to CyberAgent, Inc.
- Example (en): This model is a fine-tuned version of OpenCALM-XX developed by CyberAgent, Inc. The original model is released under the CC BY-SA 4.0 license, and this model is also released under the same CC BY-SA 4.0 license. For more information, please visit: https://creativecommons.org/licenses/by-sa/4.0/
- Example (ja): 本モデルは、株式会社サイバーエージェントによるOpenCALM-XXをファインチューニングしたものです。元のモデルはCC BY-SA 4.0ライセンスのもとで公開されており、本モデルも同じくCC BY-SA 4.0ライセンスで公開します。詳しくはこちらをご覧ください: https://creativecommons.org/licenses/by-sa/4.0/
Training Dataset
- Wikipedia (ja)
- Common Crawl (ja)
Author
Citations
@software{gpt-neox-library,
title = {{GPT-NeoX: Large Scale Autoregressive Language Modeling in PyTorch}},
author = {Andonian, Alex and Anthony, Quentin and Biderman, Stella and Black, Sid and Gali, Preetham and Gao, Leo and Hallahan, Eric and Levy-Kramer, Josh and Leahy, Connor and Nestler, Lucas and Parker, Kip and Pieler, Michael and Purohit, Shivanshu and Songz, Tri and Phil, Wang and Weinbach, Samuel},
url = {https://www.github.com/eleutherai/gpt-neox},
doi = {10.5281/zenodo.5879544},
month = {8},
year = {2021},
version = {0.0.1},
}