File size: 2,736 Bytes
0955f14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

from typing import Optional


# There are a couple of non-optimal parts of this code:
# 1. It doesn't inherit from the Player class in main.py, which causes type-checking errors.
# 2. get_move_from_response() is duplicated from main.py.
# However, I didn't want to add clutter and major dependencies like torch, peft, and transformers
# for those not using this class, so this was my compromise.
class BaseLlamaPlayer:
    """Game-move player backed by a causal language model.

    Holds a tokenizer/model pair and turns a textual game state into a move by
    taking the first whitespace-delimited token of the model's continuation.
    """

    # Upper bound on the number of tokens generated per request.
    MAX_NEW_TOKENS = 10

    # The annotations are quoted so they are not evaluated when the class is defined.
    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        model: "AutoModelForCausalLM",
        model_name: str,
    ):
        """Store the tokenizer, the model and the name reported by get_config().

        Args:
            tokenizer: Callable tokenizer that also provides ``batch_decode``.
            model: Model exposing ``generate``.
            model_name: Label returned under the ``"model"`` key of get_config().
        """
        self.tokenizer = tokenizer
        self.model = model
        self.model_name = model_name

    def get_llama_response(self, game_state: str, temperature: float) -> Optional[str]:
        """Generate a short continuation of ``game_state``.

        Args:
            game_state: Prompt text, passed to the tokenizer unchanged.
            temperature: Values greater than zero enable sampling at that
                temperature; zero, negative or ``None`` selects greedy decoding.

        Returns:
            The decoded newly generated tokens, without the prompt and without
            special tokens.
        """
        # Send the inputs to wherever the model lives; fall back to "cuda" for
        # model objects that do not expose a ``device`` attribute.
        target_device = getattr(self.model, "device", "cuda")
        tokenized_input = self.tokenizer(game_state, return_tensors="pt").to(
            target_device
        )

        # generate() only honours ``temperature`` when sampling is enabled, and
        # sampling does not accept a temperature of zero, so choose the decoding
        # mode explicitly from the requested temperature.
        use_sampling = temperature is not None and temperature > 0
        generation_kwargs = {
            "max_new_tokens": self.MAX_NEW_TOKENS,
            "do_sample": use_sampling,
        }
        if use_sampling:
            generation_kwargs["temperature"] = temperature

        result = self.model.generate(**tokenized_input, **generation_kwargs).to("cpu")

        # transformers generate() returns <s> + prompt + output. This grabs only the output
        prompt_length = tokenized_input["input_ids"].shape[1]
        res_sliced = result[:, prompt_length:]

        # Drop special tokens so an end-of-sequence marker emitted right after
        # the move cannot end up attached to it.
        return self.tokenizer.batch_decode(res_sliced, skip_special_tokens=True)[0]

    def get_move_from_response(self, response: Optional[str]) -> Optional[str]:
        """Return the first whitespace-delimited token of ``response``.

        Returns ``None`` when ``response`` is ``None`` or contains no
        non-whitespace text.
        """
        if response is None:
            return None

        # Parse the response to get only the first move
        moves = response.split()
        return moves[0] if moves else None

    def get_move(
        self, board: str, game_state: str, temperature: float
    ) -> Optional[str]:
        """Ask the model for a continuation of ``game_state`` and extract a move.

        Args:
            board: Accepted as part of the call signature; not used here.
            game_state: Prompt text forwarded to get_llama_response().
            temperature: Forwarded to get_llama_response().

        Returns:
            The first token of the model's continuation, or ``None`` if the
            continuation is empty.
        """
        completion = self.get_llama_response(game_state, temperature)
        return self.get_move_from_response(completion)

    def get_config(self) -> dict:
        """Return a dict holding the model name under the ``"model"`` key."""
        return {"model": self.model_name}


class LocalLlamaPlayer(BaseLlamaPlayer):
    """Player whose tokenizer and model are both loaded from ``model_name``."""

    def __init__(self, model_name: str):
        """Load the tokenizer and a bfloat16 model, then move the model to CUDA.

        Args:
            model_name: Identifier passed to both ``from_pretrained`` calls and
                stored as the player's model name.
        """
        llama_tokenizer = AutoTokenizer.from_pretrained(model_name)

        # NOTE(review): device_map=0 presumably already places the weights on
        # GPU 0, which would make the .to("cuda") below redundant — confirm.
        load_options = {"torch_dtype": torch.bfloat16, "device_map": 0}
        llama_model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
        llama_model = llama_model.to("cuda")

        super().__init__(llama_tokenizer, llama_model, model_name)


class LocalLoraLlamaPlayer(BaseLlamaPlayer):
    """Player built from a base model with a PEFT adapter merged into it."""

    def __init__(self, base_model_id: str, adapter_model_path: str):
        """Load the base model, merge the adapter into it and move it to CUDA.

        Args:
            base_model_id: Identifier used to load both the tokenizer and the
                base model.
            adapter_model_path: Location of the adapter; also stored as the
                player's model name.
        """
        lora_tokenizer = AutoTokenizer.from_pretrained(base_model_id)

        # Wrap the base model with the adapter, then merge and unload it.
        adapted = PeftModel.from_pretrained(
            AutoModelForCausalLM.from_pretrained(base_model_id), adapter_model_path
        )
        merged = adapted.merge_and_unload()

        super().__init__(lora_tokenizer, merged.to("cuda"), adapter_model_path)