codewithdark committed
Commit f213378 · verified · 1 Parent(s): 2fc2930

Create modeling_DiffusionLLM.py

Files changed (1)
modeling_DiffusionLLM.py +126 -0
modeling_DiffusionLLM.py ADDED
@@ -0,0 +1,126 @@
+ import torch
+ import torch.nn as nn
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ class DiffusionConfig(PretrainedConfig):
+     """Configuration class for Diffusion-LLM model."""
+     model_type = "diffusionLM"
+
+     def __init__(
+         self,
+         vocab_size: int = 50257,
+         hidden_size: int = 768,
+         num_hidden_layers: int = 12,
+         num_attention_heads: int = 12,
+         intermediate_size: int = 3072,
+         hidden_dropout_prob: float = 0.1,
+         attention_probs_dropout_prob: float = 0.1,
+         max_position_embeddings: int = 1024,
+         initializer_range: float = 0.02,
+         layer_norm_eps: float = 1e-12,
+         pad_token_id: int = 0,
+         mask_token_id: int = 50256,
+         eos_token_id: int = 50256,
+         num_timesteps: int = 100,
+         time_embed_dim: int = 128,
+         **kwargs
+     ):
+         super().__init__(pad_token_id=pad_token_id, **kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.mask_token_id = mask_token_id
+         self.eos_token_id = eos_token_id
+         self.num_timesteps = num_timesteps
+         self.time_embed_dim = time_embed_dim
+
+ class DiffusionLLM(PreTrainedModel):
+     """Main Diffusion-LLM model class"""
+     config_class = DiffusionConfig
+     base_model_prefix = "diffusionLM"
+
+     def __init__(self, config: DiffusionConfig):
+         super().__init__(config)
+         self.model = LLaDAModel(config)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         timesteps=None,
+         labels=None,
+         return_dict=True,
+     ):
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             timesteps=timesteps,
+             labels=labels,
+         )
+
+         return outputs
+
+     def generate(
+         self,
+         prompt=None,
+         max_length=100,
+         num_inference_steps=50,
+         temperature=1.0,
+         strategy='random',
+         top_p=0.9,
+         top_k=50,
+         num_beams=5,
+         return_scores=False,
+         use_streaming=False,
+         callback_fn=None
+     ):
+         """Unified generation interface"""
+         if use_streaming:
+             return self.generate_stream(
+                 prompt=prompt,
+                 max_length=max_length,
+                 num_inference_steps=num_inference_steps,
+                 temperature=temperature,
+                 strategy=strategy,
+                 top_p=top_p,
+                 top_k=top_k,
+                 num_beams=num_beams,
+                 callback_fn=callback_fn
+             )
+         else:
+             return self.model.generate(
+                 prompt=prompt,
+                 max_length=max_length,
+                 num_inference_steps=num_inference_steps,
+                 temperature=temperature,
+                 strategy=strategy,
+                 top_p=top_p,
+                 top_k=top_k,
+                 num_beams=num_beams,
+                 return_scores=return_scores
+             )
+
+     def generate_stream(self, **kwargs):
+         """Streaming generation wrapper"""
+         return self.model.generate_stream(**kwargs)
+
+     def prepare_inputs_for_generation(self, input_ids, **kwargs):
+         """Prepare inputs for generation compatibility"""
+         return {
+             "input_ids": input_ids,
+             "attention_mask": kwargs.get("attention_mask", None),
+             "timesteps": kwargs.get("timesteps", None),
+         }
+
+     @staticmethod
+     def _reorder_cache(past, beam_idx):
+         """Reorder cache for beam search compatibility"""
+         return past
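
DiffusionConfig follows the standard PretrainedConfig pattern: every constructor argument ends up stored on the instance, model_type is registered as "diffusionLM", and the defaults mirror a GPT-2-sized model (vocab_size 50257, with mask_token_id and eos_token_id both set to 50256, GPT-2's end-of-text id) plus the diffusion-specific num_timesteps and time_embed_dim. A minimal usage sketch, assuming modeling_DiffusionLLM.py is importable from the working directory:

from modeling_DiffusionLLM import DiffusionConfig

# All defaults come from the class above.
config = DiffusionConfig()
print(config.model_type)       # "diffusionLM"
print(config.num_timesteps)    # 100 diffusion steps by default

# Override selected fields, then persist them; save_pretrained writes a
# standard config.json that records model_type for later lookup.
small = DiffusionConfig(hidden_size=256, num_hidden_layers=4, num_attention_heads=4)
small.save_pretrained("./diffusionLM-small")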
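
DiffusionLLM is a thin PreTrainedModel wrapper around that config: __init__ builds a LLaDAModel backbone (referenced but not defined or imported in this file, so it must live elsewhere in the repository), forward delegates to it unchanged, and generate is a unified entry point that dispatches either to LLaDAModel.generate or to the streaming path. A hedged usage sketch follows; the checkpoint directory is a placeholder, and whether prompt expects raw text or token ids is decided by LLaDAModel.generate, which is not shown here:

from transformers import AutoConfig, AutoModel
from modeling_DiffusionLLM import DiffusionConfig, DiffusionLLM

# Map the custom model_type to the custom classes so generic tooling can
# resolve them from a saved config.json.
AutoConfig.register("diffusionLM", DiffusionConfig)
AutoModel.register(DiffusionConfig, DiffusionLLM)

# Placeholder path; loading requires a trained checkpoint and an importable LLaDAModel.
model = AutoModel.from_pretrained("./diffusionLM-checkpoint")
model.eval()

# Mirrors the generate() signature above; use_streaming defaults to False,
# so this call routes to LLaDAModel.generate with the sampling knobs shown.
output = model.generate(
    prompt="Diffusion language models",   # assumption: may instead expect token ids
    max_length=64,
    num_inference_steps=50,
    temperature=0.8,
    strategy="random",
    top_p=0.9,
    top_k=50,
)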
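
For streaming, use_streaming=True makes generate forward its arguments (minus return_scores) to generate_stream, which in turn hands everything to LLaDAModel.generate_stream. The callback below is illustrative only; this file does not define how or how often callback_fn is invoked, so its signature is an assumption:

# Assumption: LLaDAModel.generate_stream calls callback_fn with some per-step payload.
def on_step(step_output):
    print("denoising update:", step_output)

model.generate(
    prompt="Diffusion language models",   # same caveat about the expected prompt type
    max_length=64,
    num_inference_steps=50,
    use_streaming=True,                   # dispatches to generate_stream()
    callback_fn=on_step,
)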