File size: 5,969 Bytes
2047d88
 
 
d2fde25
2047d88
d2fde25
49b2bf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2047d88
 
 
 
 
49b2bf5
2047d88
 
 
49b2bf5
 
 
 
 
 
 
 
 
2047d88
474ebab
49b2bf5
247aecf
49b2bf5
247aecf
474ebab
49b2bf5
 
141eb85
49b2bf5
 
2047d88
49b2bf5
2047d88
 
 
 
49b2bf5
2047d88
 
 
 
 
ea9f47a
837ecd8
2047d88
 
837ecd8
141eb85
2047d88
 
 
ea9f47a
d2fde25
ea9f47a
9d37c49
 
247aecf
837ecd8
247aecf
 
9d37c49
 
 
 
 
ea9f47a
9d37c49
ea9f47a
474ebab
ea9f47a
9d37c49
ea9f47a
 
474ebab
 
 
 
d2fde25
474ebab
ea9f47a
 
 
2047d88
 
 
 
 
 
 
141eb85
2047d88
 
 
247aecf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
import torch
import torch.nn as nn
import time

# Define the custom model class with detailed layer structures
class Head(nn.Module):
    """One causal (masked) self-attention head.

    Projects the input to keys, queries and values of width ``head_size`` and
    lets each position attend only to itself and earlier positions.

    Args:
        head_size: Width of the key/query/value projections.
        n_embd: Width of the input features (default 64, the width used by the
            rest of this file's model).
        block_size: Largest sequence length the causal mask supports.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, n_embd=64, block_size=32, dropout=0.1):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular causal mask. It is registered as a buffer so it is
        # stored in the state dict and moves with the module across devices.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply masked attention to ``x`` of shape (B, T, n_embd).

        Returns a tensor of shape (B, T, head_size).

        Raises:
            ValueError: if T exceeds the ``block_size`` of the causal mask.
        """
        B, T, C = x.shape
        if T > self.tril.size(0):
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.tril.size(0)}"
            )
        k = self.key(x)
        q = self.query(x)
        # NOTE: scores are scaled by the input width C, not by head_size.
        # Changing the scale would change the outputs of already-trained
        # weights, so it is intentionally left as-is.
        wei = q @ k.transpose(-2, -1) * C**-0.5
        # Block attention to future positions, then normalise per row.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = nn.functional.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        return wei @ v

class MultiHeadAttention(nn.Module):
    """Run several causal ``Head`` modules in parallel and project the result.

    Args:
        num_heads: Number of attention heads.
        head_size: Output width of each head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, num_heads, head_size, dropout=0.1):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. That
        # only equals the 64-wide residual stream when 64 divides evenly by
        # num_heads, so the input width is derived rather than hard-coded.
        self.proj = nn.Linear(num_heads * head_size, 64)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output on the last axis, then project."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))

class FeedForward(nn.Module):
    """Position-wise MLP: widen by 4x, ReLU, project back, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = n_embd * 4
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(0.1),
        ]
        # Kept as a Sequential named ``net`` so parameter keys stay
        # net.0.* / net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        y = self.net(x)
        return y

class Block(nn.Module):
    """Pre-norm transformer layer: attention, then feed-forward, each residual."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Submodule names and creation order are unchanged so that
        # checkpoint keys and initialisation order stay the same.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the residual attention and MLP sub-layers."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

class BigramLanguageModel(nn.Module):
    """Character-level transformer language model used by this app.

    Despite the name, it stacks:
    - token and position embeddings;
    - four ``Block`` layers;
    - a final LayerNorm;
    - a linear head over 61 token ids.

    Layer sizes are fixed so they line up with the checkpoint that
    ``load_model`` loads.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(61, 64)
        self.position_embedding_table = nn.Embedding(32, 64)
        self.blocks = nn.Sequential(*[Block(64, n_head=4) for _ in range(4)])
        self.ln_f = nn.LayerNorm(64)
        self.lm_head = nn.Linear(64, 61)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Token ids of shape (B, T); T must not exceed 32.
            targets: Optional token ids of shape (B, T). When given, the mean
                cross-entropy between the logits and these targets is returned.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (B, T, 61), and ``loss``
            is ``None`` when ``targets`` is ``None``.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            return logits, None
        # Flatten batch and time so cross_entropy sees (N, vocab) vs (N,).
        loss = nn.functional.cross_entropy(
            logits.reshape(B * T, -1), targets.reshape(B * T)
        )
        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=0.7):
        """Autoregressively sample ``max_new_tokens`` ids after ``idx``.

        Args:
            idx: Context ids of shape (B, T) with T >= 1.
            max_new_tokens: Number of ids to append.
            temperature: Softmax temperature; must be > 0. Lower values make
                the sampling more conservative.

        Returns:
            Tensor of shape (B, T + max_new_tokens): the context followed by
            the sampled ids.

        Raises:
            ValueError: if ``temperature`` is not positive or ``idx`` is empty.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if idx.size(1) == 0:
            raise ValueError("generate() needs at least one context token")
        # Only this many positions have embeddings, so condition on at most
        # the latest block_size tokens.
        block_size = self.position_embedding_table.num_embeddings
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # Keep only the prediction for the final position, then sharpen
            # or flatten it with the temperature.
            logits = logits[:, -1, :] / temperature
            probs = nn.functional.softmax(logits, dim=-1)
            # multinomial already returns ids in [0, vocab_size), so no
            # clamping is required.
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# Load the pretrained weights. strict=False tolerates missing or unexpected
# keys, but any mismatch is reported so untrained layers do not go unnoticed.
def load_model():
    """Build ``BigramLanguageModel`` and load the Triptuner checkpoint into it.

    The state dict is fetched with ``torch.hub.load_state_dict_from_url`` from
    the Hugging Face URL below and mapped onto the CPU.

    Because loading is non-strict, any parameter missing from the checkpoint
    keeps its random initialisation. Missing and unexpected keys are therefore
    printed as warnings.

    Returns:
        The model in eval mode (dropout disabled).
    """
    model = BigramLanguageModel()
    model_url = "https://huggingface.co/yoonusajwardapiit/triptuner/resolve/main/pytorch_model.bin"
    model_weights = torch.hub.load_state_dict_from_url(model_url, map_location=torch.device('cpu'), weights_only=True)
    result = model.load_state_dict(model_weights, strict=False)
    if result.missing_keys:
        print(f"Warning: checkpoint is missing keys (left randomly initialised): {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Warning: checkpoint has unexpected keys (ignored): {result.unexpected_keys}")
    model.eval()
    return model

# Build the model and load its weights once, when this module runs.
# generate_text uses this module-level instance for every request.
model = load_model()

# Character vocabulary used to map text to token ids and back.
# NOTE(review): this set has 47 characters, but the model's embedding and
# output layers have 61 entries. It presumably differs from the vocabulary
# used in training (uppercase letters, for example, may be missing), which
# would shift every id. Confirm against the training script.
chars = sorted(set("abcdefghijklmnopqrstuvwxyz0123456789 .,!?-:;'\"\n"))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Map ``s`` to a list of token ids.

    Each character is looked up as-is first, then lowercased. Characters that
    match neither way are dropped, so the result never contains -1.
    """
    ids = []
    for c in s:
        if c in stoi:
            ids.append(stoi[c])
        elif c.lower() in stoi:
            ids.append(stoi[c.lower()])
    return ids


def decode(l):
    """Map token ids back to text, skipping ids that are not in the vocabulary."""
    return ''.join(itos[i] for i in l if i in itos)

# Gradio callback: turn a user prompt into generated text.
def generate_text(prompt):
    """Generate a continuation of ``prompt`` with the loaded model.

    Steps:
    - encode the prompt with ``encode``;
    - keep its first 32 ids;
    - append 20 ids sampled at temperature 0.7;
    - decode everything back to text, with newlines turned into spaces.

    Args:
        prompt: Text entered by the user. ``None`` is treated as empty.

    Returns:
        The decoded text (prompt plus continuation), or a string starting
        with ``"Error:"``. Errors are returned instead of raised so the
        Gradio UI can show them.
    """
    try:
        start_time = time.perf_counter()
        print(f"Received prompt: {prompt}")
        encoded_prompt = encode(prompt or "")

        # Defensive: treat a -1 id as out-of-vocabulary. encode may drop
        # unknown characters instead of emitting -1, in which case this never
        # fires.
        if any(idx == -1 for idx in encoded_prompt):
            return "Error: Input contains characters not in the model vocabulary."

        # With no context tokens generate() cannot pick a "last position" and
        # would fail with an opaque indexing error.
        if not encoded_prompt:
            return "Error: Input contains no characters from the model vocabulary."

        # The model has only 32 position embeddings.
        encoded_prompt = encoded_prompt[:32]

        context = torch.tensor([encoded_prompt], dtype=torch.long)
        print(f"Encoded prompt: {context}")

        with torch.no_grad():
            generated = model.generate(context, max_new_tokens=20, temperature=0.7)
        print(f"Generated tensor: {generated}")

        result = decode(generated[0].tolist())
        print(f"Decoded result: {result}")

        # Flatten newlines so the output reads as one line in the text box.
        cleaned_result = result.replace('\n', ' ').strip()
        print(f"Cleaned result: {cleaned_result}")
        print(f"Processing time: {time.perf_counter() - start_time:.2f}s")
        return cleaned_result
    except Exception as e:
        # Top-level UI boundary: report the failure rather than crash the app.
        print(f"Error during generation: {e}")
        return f"Error: {str(e)}"

# Web UI: a two-line text box feeds generate_text, whose string result is
# shown as plain text.
interface = gr.Interface(
    generate_text,
    gr.Textbox(lines=2, placeholder="Enter a location or prompt..."),
    "text",
    title="Triptuner Model",
    description="Generate itineraries for locations in Sri Lanka's Central Province.",
)

# Start the Gradio server.
interface.launch()