import json

import torch
from transformers import PreTrainedModel, PretrainedConfig

from model import GPT, GPTConfig  # your original (non-HF) model and config classes

class CustomGPTConfig(PretrainedConfig):
    """PretrainedConfig that stores arbitrary GPT hyperparameters as attributes."""

    model_type = "gpt"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Keep every hyperparameter as an attribute so the original GPTConfig
        # can later be rebuilt from this config's __dict__.
        for key, value in kwargs.items():
            setattr(self, key, value)

class MatterGPTWrapper(PreTrainedModel):
    """Hugging Face-compatible wrapper around the original GPT implementation."""

    config_class = CustomGPTConfig
    base_model_prefix = "gpt"

    def __init__(self, config):
        super().__init__(config)
        # Rebuild the original GPTConfig from the attributes stored on the HF config.
        self.model = GPT(GPTConfig(**config.__dict__))

    def forward(self, input_ids, attention_mask=None, labels=None, prop=None):
        # The wrapped GPT takes targets/prop directly; attention_mask is unused.
        return self.model(input_ids, targets=labels, prop=prop)

    def generate(self, input_ids, prop, max_length, num_return_sequences=1, **kwargs):
        # Sample only the tokens still needed to reach max_length.
        steps = max_length - input_ids.shape[1]
        return self.model.sample(input_ids, steps, prop=prop, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, *model_args, **kwargs):
        # Rebuild the config from the saved hyperparameters.
        config_file = f"{pretrained_model_path}/config.json"
        with open(config_file, 'r') as f:
            config_dict = json.load(f)

        config = CustomGPTConfig(**config_dict)
        model = cls(config)

        # Load the model weights.
        state_dict = torch.load(f"{pretrained_model_path}/pytorch_model.pt", map_location="cpu")
        model.model.load_state_dict(state_dict)

        return model

    def save_pretrained(self, save_directory):
        # Write config.json plus the raw state dict next to it.
        self.config.save_pretrained(save_directory)
        torch.save(self.model.state_dict(), f"{save_directory}/pytorch_model.pt")

class SimpleTokenizer:
    """Minimal whitespace tokenizer backed by a newline-delimited vocab file."""

    def __init__(self, vocab_file):
        with open(vocab_file, 'r') as f:
            self.vocab = f.read().splitlines()
        # Make sure the special '<' and '>' markers are always in the vocabulary.
        self.vocab = sorted(set(self.vocab + ['<', '>']))
        self.stoi = {ch: i for i, ch in enumerate(self.vocab)}
        self.itos = {i: ch for i, ch in enumerate(self.vocab)}

    def encode(self, text):
        # Split on whitespace; unknown tokens raise KeyError.
        return [self.stoi[token] for token in text.split()]

    def decode(self, ids):
        # Cast to int first so tensor elements pass the membership check,
        # then skip out-of-vocabulary ids and strip '<' padding characters.
        ids = [int(i) for i in ids]
        return " ".join(self.itos[i] for i in ids if i in self.itos).replace("<", "").strip()

    def __call__(self, text, return_tensors=None):
        encoded = self.encode(text)
        if return_tensors == 'pt':
            # Mirror Hugging Face tokenizer output with a batch dimension.
            return {'input_ids': torch.tensor([encoded])}
        return {'input_ids': [encoded]}
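

# Example usage sketch. Assumptions not fixed by this file: a checkpoint
# directory "./mattergpt_hf" written by save_pretrained, a vocab file
# "Voc_prior", ">" as the prompt/start token, a single conditioning property,
# and that GPT.sample returns token-id tensors; adjust these to your setup.
if __name__ == "__main__":
    tokenizer = SimpleTokenizer("Voc_prior")
    model = MatterGPTWrapper.from_pretrained("./mattergpt_hf")
    model.eval()

    inputs = tokenizer(">", return_tensors="pt")
    prop = torch.tensor([[-1.0]])  # hypothetical target property, shape (batch, n_props)

    with torch.no_grad():
        out_ids = model.generate(inputs["input_ids"], prop=prop, max_length=64)

    print(tokenizer.decode(out_ids[0].tolist()))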