diabolic6045 committed on
Commit 4b0778b · verified · 1 Parent(s): 8bd4499

Create inference.py

Files changed (1)
inference.py +101 -0
inference.py ADDED
@@ -0,0 +1,101 @@
+ import torch
+ from transformers import PreTrainedModel, PretrainedConfig
+ from utils import load_config
+ from tokenization import get_tokenizer
+
+ class CustomConfig(PretrainedConfig):
+     """Configuration class for the custom language model."""
+     model_type = "custom_llm"
+
+     def __init__(
+         self,
+         vocab_size: int = 50000,
+         n_embd: int = 640,
+         n_head: int = 10,
+         n_layer: int = 12,
+         n_positions: int = 512,
+         tie_word_embeddings: bool = True,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.n_embd = n_embd
+         self.n_head = n_head
+         self.n_layer = n_layer
+         self.n_positions = n_positions
+         self.tie_word_embeddings = tie_word_embeddings
+         super().__init__(**kwargs)
+
+ def generate_text(
+     prompt: str,
+     model_path: str = "outputs/hf_model",
+     max_length: int = 200,
+     temperature: float = 0.8,
+     top_k: int = 50,
+     top_p: float = 0.9,
+     repetition_penalty: float = 1.2,
+     no_repeat_ngram_size: int = 3
+ ):
+     """Generate text using the model."""
+     # Load config and tokenizer
+     config = load_config()
+     tokenizer = get_tokenizer(config)
+
+     # Load model
+     from inference import CustomModelForCausalLM  # Import here to avoid circular imports
+     model = CustomModelForCausalLM.from_pretrained(model_path)
+
+     # Move model to GPU if available
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device)
+     model.eval()
+
+     # Encode prompt
+     encoded = tokenizer.batch_encode(
+         [prompt],
+         return_tensors="pt"
+     )
+     input_ids = encoded["input_ids"].to(device)
+
+     # Generate
+     with torch.no_grad():
+         output_ids = model.generate(
+             input_ids=input_ids,
+             max_length=max_length,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+             no_repeat_ngram_size=no_repeat_ngram_size
+         )
+
+     # Decode and return
+     generated_text = tokenizer.decode(output_ids[0].tolist())
+     return generated_text
+
+ if __name__ == "__main__":
+     # Example prompts to test
+     prompts = [
+         "Once upon a time",
+         "The meaning of life is",
+         "In the distant future",
+         "The best way to learn programming is",
+         "Today I learned that"
+     ]
+
+     print("\nGenerating text from multiple prompts:")
+     print("=" * 50)
+
+     for prompt in prompts:
+         generated_text = generate_text(
+             prompt=prompt,
+             max_length=200,
+             temperature=0.8,              # Adjust for creativity (higher = more creative)
+             top_k=50,                     # Limit to top 50 tokens
+             top_p=0.9,                    # Nucleus sampling threshold
+             repetition_penalty=1.2,       # Penalize repetition
+             no_repeat_ngram_size=3        # Prevent 3-gram repetition
+         )
+
+         print(f"\nPrompt: {prompt}")
+         print(f"Generated: {generated_text}")
+         print("-" * 50)
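Note: generate_text() loads the model with CustomModelForCausalLM.from_pretrained(model_path), but this commit does not define CustomModelForCausalLM, so the deferred "from inference import CustomModelForCausalLM" only works if that class is provided elsewhere in the project. A minimal sketch of how the custom architecture could be made loadable through the standard transformers Auto-class registration is shown below; the companion module name "model" used here is a hypothetical placeholder, not something this commit establishes.

# Sketch only: assumes CustomModelForCausalLM (a PreTrainedModel subclass)
# lives in a hypothetical companion module named "model".
from transformers import AutoConfig, AutoModelForCausalLM

from inference import CustomConfig               # config class added in this commit
from model import CustomModelForCausalLM         # hypothetical companion module

# Register the custom architecture so from_pretrained() can resolve
# model_type="custom_llm" from the saved config.json.
AutoConfig.register("custom_llm", CustomConfig)
AutoModelForCausalLM.register(CustomConfig, CustomModelForCausalLM)

model = AutoModelForCausalLM.from_pretrained("outputs/hf_model")

One further caveat: if CustomModelForCausalLM inherits transformers' GenerationMixin, the sampling arguments passed in generate_text (temperature, top_k, top_p) only take effect when do_sample=True is also passed to generate(); otherwise decoding falls back to greedy search.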