---
license: apache-2.0
datasets:
- squarelike/sharegpt_deepl_ko_translation
language:
- en
- ko
pipeline_tag: translation
---

# Gugugo-koen-7B-V1.1

Detailed repository: [https://github.com/jwj7140/Gugugo](https://github.com/jwj7140/Gugugo)

![Gugugo](./logo.png)

**Base Model**: [Llama-2-ko-7b](https://huggingface.co/beomi/llama-2-ko-7b)

**Training Dataset**: [sharegpt_deepl_ko_translation](https://huggingface.co/datasets/squarelike/sharegpt_deepl_ko_translation)

I trained the model on a single NVIDIA A6000 GPU for 90 hours.
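
## **Prompt Template**

**KO->EN**

```
### 한국어: {sentence}</끝>
### 영어:
```

**EN->KO**

```
### 영어: {sentence}</끝>
### 한국어:
```

In these templates, `한국어` means "Korean", `영어` means "English", and `</끝>` ("end") marks the end of the input sentence. The model writes the translation after the final label.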

## **Implementation Code**

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
)
import torch

repo = "squarelike/Gugugo-koen-7B-V1.1"

# Load the model in 4-bit (requires the `bitsandbytes` package and a CUDA GPU).
model = AutoModelForCausalLM.from_pretrained(
    repo,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(repo)


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the output ends with any of the given token-ID sequences."""

    def __init__(self, stops=()):
        super().__init__()
        self.stops = list(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            stop = stop.to(input_ids.device)  # avoid a CPU/GPU device mismatch
            if torch.all(stop == input_ids[0][-len(stop):]).item():
                return True
        return False


# Token-ID sequences that end generation; they appear to encode variants of the `</끝>` marker.
stop_words_ids = torch.tensor([[829, 45107, 29958], [1533, 45107, 29958], [21106, 45107, 29958]])
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])


def gen(lan="en", x=""):
    """Translate `x`. `lan` is the input language: "ko" translates Korean to English,
    any other value translates English to Korean."""
    if lan == "ko":
        prompt = f"### 한국어: {x}</끝>\n### 영어:"
    else:
        prompt = f"### 영어: {x}</끝>\n### 한국어:"
    inputs = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False).to(model.device)
    gened = model.generate(
        **inputs,
        max_new_tokens=1000,
        temperature=0.1,
        no_repeat_ngram_size=10,
        do_sample=True,
        eos_token_id=2,
        stopping_criteria=stopping_criteria,
    )
    # Decode only the newly generated tokens, then strip the end marker.
    new_tokens = gened[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).replace("</끝>", "").strip()


print(gen(lan="en", x="Hello, world!"))
```
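
`StoppingCriteriaSub` ends generation as soon as the last few generated tokens match one of the stop sequences. The same tail-matching check can be written in plain Python, which makes it easy to test candidate stop sequences without loading the model. This is a minimal sketch: `ends_with_any` and `STOP_IDS` are hypothetical names, and the IDs are copied from the example's `stop_words_ids`.

```python
from typing import Sequence

# Stop sequences copied from the example's stop_words_ids (duplicate removed).
STOP_IDS = [
    [829, 45107, 29958],
    [1533, 45107, 29958],
    [21106, 45107, 29958],
]


def ends_with_any(token_ids: Sequence[int], stops: Sequence[Sequence[int]]) -> bool:
    """Return True if `token_ids` ends with any non-empty sequence in `stops`."""
    for stop in stops:
        n = len(stop)
        # Skip empty patterns and sequences shorter than the pattern.
        if n and len(token_ids) >= n and list(token_ids[-n:]) == list(stop):
            return True
    return False


print(ends_with_any([1, 2, 829, 45107, 29958], STOP_IDS))  # True
print(ends_with_any([829, 45107], STOP_IDS))  # False
```

The explicit length check keeps the helper well-defined for sequences shorter than a stop pattern, where a plain slice would return fewer items than the pattern.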