import os
import torch
import json
import argparse
import numpy as np
import re
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
import math
from safetensors.torch import save_file, load_file
from tqdm import tqdm
import faiss
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS as LangchainFAISS
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from typing import List, Dict, Any, Optional, Callable
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gc
import warnings

# Ignore specific HuggingFace warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*The model doesn't have tied token embeddings.*")

# Tokenizer wrapper class - same as in original code
class SentencePieceTokenizerWrapper:
    def __init__(self, sp_model_path):
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(sp_model_path)
        self.vocab_size = self.sp_model.GetPieceSize()
        
        # Special token IDs from tokenizer training
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3
        
        # Set special tokens
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token = "<unk>"
        self.mask_token = "<mask>"
    
    def __call__(self, text, padding=False, truncation=False, max_length=None, return_tensors=None):
        # Handle both string and list inputs
        if isinstance(text, str):
            # Encode a single string
            ids = self.sp_model.EncodeAsIds(text)
            
            # Handle truncation
            if truncation and max_length and len(ids) > max_length:
                ids = ids[:max_length]
                
            attention_mask = [1] * len(ids)
            
            # Handle padding
            if padding and max_length:
                padding_length = max(0, max_length - len(ids))
                ids = ids + [self.pad_token_id] * padding_length
                attention_mask = attention_mask + [0] * padding_length
            
            result = {
                'input_ids': ids,
                'attention_mask': attention_mask
            }
            
            # Convert to tensors if requested
            if return_tensors == 'pt':
                import torch
                result = {k: torch.tensor([v]) for k, v in result.items()}
            
            return result
        
        # Process a batch of texts
        batch_encoded = [self.sp_model.EncodeAsIds(t) for t in text]
        
        # Apply truncation if needed
        if truncation and max_length:
            batch_encoded = [ids[:max_length] for ids in batch_encoded]
        
        # Create attention masks
        batch_attention_mask = [[1] * len(ids) for ids in batch_encoded]
        
        # Apply padding if needed
        if padding:
            if max_length:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in batch_encoded)
            
            # Pad sequences to max_len
            batch_encoded = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in batch_encoded]
            batch_attention_mask = [mask + [0] * (max_len - len(mask)) for mask in batch_attention_mask]
        
        result = {
            'input_ids': batch_encoded,
            'attention_mask': batch_attention_mask
        }
        
        # Convert to tensors if requested
        if return_tensors == 'pt':
            import torch
            result = {k: torch.tensor(v) for k, v in result.items()}
        
        return result
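
# Illustrative usage of the wrapper (a sketch; actual token IDs depend on the trained
# SentencePiece model, and the shapes below assume the sentence encodes to fewer than
# 8 pieces):
#   tok = SentencePieceTokenizerWrapper("tokenizer.model")
#   out = tok(["नमस्ते दुनिया"], padding=True, max_length=8, return_tensors="pt")
#   out["input_ids"].shape    -> torch.Size([1, 8])   (padded up to max_length)
#   out["attention_mask"][0]  -> 1 for real tokens, 0 for padded positions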

# Model architecture definitions for inference

class MultiHeadAttention(nn.Module):
    """Advanced multi-headed attention with relative positional encoding"""
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config["num_attention_heads"]
        self.attention_head_size = config["hidden_size"] // config["num_attention_heads"]
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        
        # Query, Key, Value projections
        self.query = nn.Linear(config["hidden_size"], self.all_head_size)
        self.key = nn.Linear(config["hidden_size"], self.all_head_size)
        self.value = nn.Linear(config["hidden_size"], self.all_head_size)
        
        # Output projection
        self.output = nn.Sequential(
            nn.Linear(self.all_head_size, config["hidden_size"]),
            nn.Dropout(config["attention_probs_dropout_prob"])
        )
        
        # Simplified relative position bias approach
        self.max_position_embeddings = config["max_position_embeddings"]
        self.relative_attention_bias = nn.Embedding(
            2 * config["max_position_embeddings"] - 1, 
            config["num_attention_heads"]
        )
        
    def transpose_for_scores(self, x):
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)
    
    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length = hidden_states.size()[:2]
        
        # Project inputs to queries, keys, and values
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        
        # Take the dot product between query and key to get the raw attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        
        # Generate relative position matrix
        position_ids = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device)
        relative_position = position_ids.unsqueeze(1) - position_ids.unsqueeze(0)  # [seq_len, seq_len]
        # Shift values to be >= 0
        relative_position = relative_position + self.max_position_embeddings - 1  
        # Ensure indices are within bounds
        relative_position = torch.clamp(relative_position, 0, 2 * self.max_position_embeddings - 2)  
        
        # Get relative position embeddings [seq_len, seq_len, num_heads]
        rel_attn_bias = self.relative_attention_bias(relative_position)  # [seq_len, seq_len, num_heads]
        
        # Reshape to add to attention heads [1, num_heads, seq_len, seq_len]
        rel_attn_bias = rel_attn_bias.permute(2, 0, 1).unsqueeze(0)
        
        # Add to attention scores - now dimensions will match
        attention_scores = attention_scores + rel_attn_bias
        
        # Scale attention scores
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        
        # Apply attention mask
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask
        
        # Normalize the attention scores to probabilities
        attention_probs = F.softmax(attention_scores, dim=-1)
        
        # Apply dropout
        attention_probs = F.dropout(attention_probs, p=0.1, training=self.training)
        
        # Apply attention to values
        context_layer = torch.matmul(attention_probs, value_layer)
        
        # Reshape back to [batch_size, seq_length, hidden_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)
        
        # Final output projection
        output = self.output(context_layer)
        
        return output
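
# Worked example of the relative position bias indexing above, for seq_length = 3
# and max_position_embeddings = 4 (index arithmetic only, not learned values):
#   relative_position[i][j] = i - j   ->  [[ 0, -1, -2],
#                                          [ 1,  0, -1],
#                                          [ 2,  1,  0]]
#   shifted by (max_position_embeddings - 1) = 3 and clamped to [0, 6]:
#                                         [[ 3,  2,  1],
#                                          [ 4,  3,  2],
#                                          [ 5,  4,  3]]
# Each index selects a per-head learned bias that is added to the raw attention scores.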

class EnhancedTransformerLayer(nn.Module):
    """Advanced transformer layer with pre-layer norm and enhanced attention"""
    def __init__(self, config):
        super().__init__()
        self.attention_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.attention = MultiHeadAttention(config)
        
        self.ffn_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        
        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(config["hidden_size"], config["intermediate_size"]),
            nn.GELU(),
            nn.Dropout(config["hidden_dropout_prob"]),
            nn.Linear(config["intermediate_size"], config["hidden_size"]),
            nn.Dropout(config["hidden_dropout_prob"])
        )
        
    def forward(self, hidden_states, attention_mask=None):
        # Pre-layer norm for attention
        attn_norm_hidden = self.attention_pre_norm(hidden_states)
        
        # Self-attention
        attention_output = self.attention(attn_norm_hidden, attention_mask)
        
        # Residual connection
        hidden_states = hidden_states + attention_output
        
        # Pre-layer norm for feed-forward
        ffn_norm_hidden = self.ffn_pre_norm(hidden_states)
        
        # Feed-forward
        ffn_output = self.ffn(ffn_norm_hidden)
        
        # Residual connection
        hidden_states = hidden_states + ffn_output
        
        return hidden_states

class AdvancedTransformerModel(nn.Module):
    """Advanced Transformer model for inference"""
    
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Embeddings
        self.word_embeddings = nn.Embedding(
            config["vocab_size"], 
            config["hidden_size"], 
            padding_idx=config["pad_token_id"]
        )
        
        # Position embeddings
        self.position_embeddings = nn.Embedding(config["max_position_embeddings"], config["hidden_size"])
        
        # Embedding dropout
        self.embedding_dropout = nn.Dropout(config["hidden_dropout_prob"])
        
        # Transformer layers
        self.layers = nn.ModuleList([
            EnhancedTransformerLayer(config) for _ in range(config["num_hidden_layers"])
        ])
        
        # Final layer norm
        self.final_layer_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        
    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        
        # Get position ids
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
        
        # Get embeddings
        word_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        
        # Sum embeddings
        embeddings = word_embeds + position_embeds
        
        # Apply dropout
        embeddings = self.embedding_dropout(embeddings)
        
        # Default attention mask
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=input_ids.device)
        
        # Extended attention mask for the transformer layers: an additive mask with
        # 0.0 for tokens to attend to and -10000.0 for padded/masked positions
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        
        # Apply transformer layers
        hidden_states = embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_attention_mask)
        
        # Final layer norm
        hidden_states = self.final_layer_norm(hidden_states)
        
        return hidden_states

class AdvancedPooling(nn.Module):
    """Advanced pooling module supporting multiple pooling strategies"""
    def __init__(self, config):
        super().__init__()
        self.pooling_mode = config["pooling_mode"]  # 'mean', 'max', 'cls', 'attention'
        self.hidden_size = config["hidden_size"]
        
        # For attention pooling
        if self.pooling_mode == 'attention':
            self.attention_weights = nn.Linear(config["hidden_size"], 1)
            
        # For weighted pooling
        elif self.pooling_mode == 'weighted':
            self.weight_layer = nn.Linear(config["hidden_size"], 1)
            
    def forward(self, token_embeddings, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(token_embeddings[:, :, 0])
            
        mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        
        if self.pooling_mode == 'cls':
            # Use [CLS] token (first token)
            pooled = token_embeddings[:, 0]
            
        elif self.pooling_mode == 'max':
            # Max pooling
            token_embeddings = token_embeddings.clone()
            # Set padding tokens to large negative value to exclude them from max
            token_embeddings[mask_expanded == 0] = -1e9
            pooled = torch.max(token_embeddings, dim=1)[0]
            
        elif self.pooling_mode == 'attention':
            # Attention pooling
            weights = self.attention_weights(token_embeddings).squeeze(-1)
            # Mask out padding tokens
            weights = weights.masked_fill(attention_mask == 0, -1e9)
            weights = F.softmax(weights, dim=1).unsqueeze(-1)
            pooled = torch.sum(token_embeddings * weights, dim=1)
            
        elif self.pooling_mode == 'weighted':
            # Weighted average pooling
            weights = torch.sigmoid(self.weight_layer(token_embeddings)).squeeze(-1)
            # Apply mask
            weights = weights * attention_mask
            # Normalize weights
            sum_weights = torch.sum(weights, dim=1, keepdim=True)
            sum_weights = torch.clamp(sum_weights, min=1e-9)
            weights = weights / sum_weights
            # Apply weights
            pooled = torch.sum(token_embeddings * weights.unsqueeze(-1), dim=1)
            
        else:  # Default to mean pooling
            # Mean pooling
            sum_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            pooled = sum_embeddings / sum_mask
            
        # L2 normalize
        pooled = F.normalize(pooled, p=2, dim=1)
        
        return pooled
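
# In the default 'mean' mode the pooling above computes, per example,
#   pooled = sum_t(mask_t * h_t) / max(sum_t(mask_t), 1e-9)
# followed by L2 normalization, so padded positions contribute nothing and the
# cosine similarity between two pooled outputs reduces to a plain dot product.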

class SentenceEmbeddingModel(nn.Module):
    """Complete sentence embedding model for inference"""
    def __init__(self, config):
        super(SentenceEmbeddingModel, self).__init__()
        self.config = config
        
        # Create transformer model
        self.transformer = AdvancedTransformerModel(config)
        
        # Create pooling module
        self.pooling = AdvancedPooling(config)
        
        # Build projection module if needed
        if "projection_dim" in config and config["projection_dim"] > 0:
            self.use_projection = True
            self.projection = nn.Sequential(
                nn.Linear(config["hidden_size"], config["hidden_size"]),
                nn.GELU(),
                nn.Linear(config["hidden_size"], config["projection_dim"]),
                nn.LayerNorm(config["projection_dim"], eps=config["layer_norm_eps"])
            )
        else:
            self.use_projection = False
            
    def forward(self, input_ids, attention_mask=None):
        # Get token embeddings from transformer
        token_embeddings = self.transformer(input_ids, attention_mask)
        
        # Pool token embeddings
        pooled_output = self.pooling(token_embeddings, attention_mask)
        
        # Apply projection if enabled
        if self.use_projection:
            pooled_output = self.projection(pooled_output)
            pooled_output = F.normalize(pooled_output, p=2, dim=1)
        
        return pooled_output
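
# Minimal usage sketch for the embedding model (assumes a config dict providing the
# keys referenced above, e.g. "vocab_size", "hidden_size", "num_hidden_layers",
# "pooling_mode"):
#   model = SentenceEmbeddingModel(config).eval()
#   with torch.no_grad():
#       emb = model(batch["input_ids"], batch["attention_mask"])
#   emb has shape [batch, hidden_size], or [batch, projection_dim] when the optional
#   projection head is enabled.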

def convert_to_safetensors(model_path, output_path):
    """Convert PyTorch model to safetensors format"""
    print(f"Converting model from {model_path} to safetensors format...")
    
    try:
        # First try with weights_only=False to handle PyTorch 2.6+ checkpoints
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        print("Successfully loaded checkpoint with weights_only=False")
    except TypeError:
        # For older PyTorch versions that don't have weights_only parameter
        print("Falling back to default torch.load behavior for older PyTorch versions")
        checkpoint = torch.load(model_path, map_location="cpu")
    
    # Get model state dict
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        print("Extracted model_state_dict from checkpoint")
    else:
        state_dict = checkpoint
        print("Using entire checkpoint as state_dict")
    
    # Save as safetensors
    save_file(state_dict, output_path)
    print(f"Model converted and saved to {output_path}")

def load_model_and_tokenizer(model_dir, tokenizer_dir="/home/ubuntu/hindi_tokenizer"):
    """Load the model and tokenizer for inference"""
    
    # Load the config
    config_path = os.path.join(model_dir, "config.json")
    with open(config_path, "r") as f:
        config = json.load(f)
    
    # Load the tokenizer - use specified tokenizer directory
    tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.model")
    if not os.path.exists(tokenizer_path):
        # Fall back to the model directory
        tokenizer_path = os.path.join(model_dir, "tokenizer.model")
        if not os.path.exists(tokenizer_path):
            raise FileNotFoundError(
                f"Could not find tokenizer.model in {tokenizer_dir} or {model_dir}"
            )
    
    tokenizer = SentencePieceTokenizerWrapper(tokenizer_path)
    print(f"Loaded tokenizer from {tokenizer_path} with vocabulary size: {tokenizer.vocab_size}")
    
    # Load the model
    safetensors_path = os.path.join(model_dir, "embedding_model.safetensors")
    
    if not os.path.exists(safetensors_path):
        print(f"Safetensors model not found at {safetensors_path}, converting from PyTorch checkpoint...")
        
        # Convert from PyTorch checkpoint
        pytorch_path = os.path.join(model_dir, "embedding_model.pt")
        if not os.path.exists(pytorch_path):
            raise FileNotFoundError(f"Could not find PyTorch model at {pytorch_path}")
        
        convert_to_safetensors(pytorch_path, safetensors_path)
    
    # Load state dict from safetensors
    state_dict = load_file(safetensors_path)
    
    # Create model
    model = SentenceEmbeddingModel(config)
    
    # Load state dict
    try:
        # Try direct loading
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"Loaded model with missing keys: {missing_keys[:10]}{'...' if len(missing_keys) > 10 else ''}")
        print(f"Unexpected keys: {unexpected_keys[:10]}{'...' if len(unexpected_keys) > 10 else ''}")
    except Exception as e:
        print(f"Error loading state dict: {e}")
        print("Model will be initialized with random weights")
    
    model.eval()
    
    return model, tokenizer, config

# LangChain Custom Embeddings Class
class HindiSentenceEmbeddings(Embeddings):
    """
    Custom Langchain Embeddings class for Hindi sentence embeddings model
    """
    def __init__(self, model, tokenizer, device="cuda", batch_size=32, max_length=128):
        """Initialize with model, tokenizer, and inference parameters"""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.batch_size = batch_size
        self.max_length = max_length
        
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents/texts"""
        embeddings = []
        
        with torch.no_grad():
            for i in range(0, len(texts), self.batch_size):
                batch = texts[i:i+self.batch_size]
                
                # Tokenize
                inputs = self.tokenizer(
                    batch, 
                    padding="max_length", 
                    truncation=True, 
                    max_length=self.max_length, 
                    return_tensors="pt"
                )
                
                # Move to device
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)
                
                # Get embeddings
                batch_embeddings = self.model(input_ids, attention_mask)
                
                # Move to CPU and convert to numpy
                batch_embeddings = batch_embeddings.cpu().numpy()
                embeddings.append(batch_embeddings)
        
        return np.vstack(embeddings).tolist()
    
    def embed_query(self, text: str) -> List[float]:
        """Embed a single query/text"""
        return self.embed_documents([text])[0]
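
# Because this class implements LangChain's Embeddings interface, it can be passed
# straight to the FAISS wrapper. A hedged sketch (the texts are placeholders):
#   embeddings = HindiSentenceEmbeddings(model, tokenizer, device="cuda")
#   store = LangchainFAISS.from_texts(["पहला वाक्य", "दूसरा वाक्य"], embeddings)
#   store.similarity_search("पहला वाक्य", k=1)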

def extract_relevant_sentences(text, query, window_size=2):
    """
    Extract the most relevant sentences from text based on query keywords
    
    Args:
        text: The full text content
        query: The user's query
        window_size: Number of sentences to include before and after matched sentence
        
    Returns:
        String containing the most relevant portion of the text
    """
    # Clean and normalize query and text for matching
    query = query.strip().lower()
    
    # Remove question marks and other punctuation from query for matching
    query = re.sub(r'[?।॥!,.:]', '', query)
    
    # Extract keywords from the query (remove common Hindi stop words)
    stop_words = ['और', 'का', 'के', 'को', 'में', 'से', 'है', 'हैं', 'था', 'थे', 'की', 'कि', 'पर', 'एक', 'यह', 'वह', 'जो', 'ने', 'हो', 'कर']
    query_terms = [word for word in query.split() if word not in stop_words]
    
    if not query_terms:
        return text  # If no meaningful terms left, return the full text
    
    # Split text into sentences (using Hindi sentence terminators)
    sentences = re.split(r'([।॥!?.])', text)
    
    # Rejoin sentences with their terminators
    complete_sentences = []
    for i in range(0, len(sentences)-1, 2):
        if i+1 < len(sentences):
            complete_sentences.append(sentences[i] + sentences[i+1])
        else:
            complete_sentences.append(sentences[i])
    
    # If the above didn't work properly, try simpler approach
    if len(complete_sentences) <= 1:
        complete_sentences = re.split(r'[।॥!?.]', text)
        complete_sentences = [s.strip() for s in complete_sentences if s.strip()]
    
    # Score each sentence based on how many query terms it contains
    sentence_scores = []
    for i, sentence in enumerate(complete_sentences):
        sentence_lower = sentence.lower()
        # Calculate score based on number of query terms found
        score = sum(1 for term in query_terms if term in sentence_lower)
        sentence_scores.append((i, score))
    
    # Find the best matching sentence
    if not sentence_scores:
        return text[:500] + "..."  # Fallback
    
    # Get the index of sentence with highest score
    best_match_idx, best_score = max(sentence_scores, key=lambda x: x[1])
    
    # If no good match found, return the whole text (up to a limit)
    if best_score == 0:
        # Try partial word matching as a fallback
        for i, sentence in enumerate(complete_sentences):
            sentence_lower = sentence.lower()
            partial_score = sum(1 for term in query_terms if any(term in word.lower() for word in sentence_lower.split()))
            if partial_score > 0:
                best_match_idx = i
                break
        else:
            # If still no match, just return the first part of the text
            if len(text) > 1000:
                return text[:1000] + "..."
            return text
    
    # Get window of sentences around the best match
    start_idx = max(0, best_match_idx - window_size)
    end_idx = min(len(complete_sentences), best_match_idx + window_size + 1)
    
    # Create excerpt
    relevant_text = ' '.join(complete_sentences[start_idx:end_idx])
    
    # If the excerpt is short, return more context
    if len(relevant_text) < 100 and len(text) > len(relevant_text):
        # Add more context
        if end_idx < len(complete_sentences):
            relevant_text += ' ' + ' '.join(complete_sentences[end_idx:end_idx+2])
        if start_idx > 0:
            relevant_text = ' '.join(complete_sentences[max(0, start_idx-2):start_idx]) + ' ' + relevant_text
    
    # If the excerpt is too short or the whole text is small anyway, return whole text
    if len(relevant_text) < 50 or len(text) < 1000:
        return text
    
    return relevant_text
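
# Note on the heuristic above: sentences are scored by how many non-stop-word query
# terms they contain; the best-scoring sentence is returned with up to window_size
# sentences of context on each side, falling back to the full text for short inputs
# or when no term matches at all.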

# Text processing and indexing functions
def load_and_process_text_file(file_path, chunk_size=500, chunk_overlap=100):
    """
    Load a text file and split it into semantically meaningful chunks
    """
    print(f"Loading and processing text file: {file_path}")
    
    # Read the file content
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    
    # For small files, just keep the whole content as a single chunk
    if len(content) <= chunk_size * 2:
        print(f"File content is small, keeping as a single chunk")
        return [Document(
            page_content=content,
            metadata={
                "source": file_path,
                "chunk_id": 0
            }
        )]
    
    # Split by paragraphs first
    paragraphs = re.split(r'\n\s*\n', content)
    chunks = []
    
    current_chunk = ""
    current_size = 0
    
    for para in paragraphs:
        if not para.strip():
            continue
            
        # If adding this paragraph would exceed the chunk size, save current chunk and start new one
        if current_size + len(para) > chunk_size and current_size > 0:
            chunks.append(current_chunk)
            current_chunk = para
            current_size = len(para)
        else:
            # Add paragraph to current chunk with a newline if not empty
            if current_size > 0:
                current_chunk += "\n\n" + para
            else:
                current_chunk = para
            current_size = len(current_chunk)
    
    # Add the last chunk if not empty
    if current_chunk:
        chunks.append(current_chunk)
    
    print(f"Split text into {len(chunks)} chunks")
    
    # Convert to LangChain documents with metadata
    documents = [
        Document(
            page_content=chunk,
            metadata={
                "source": file_path,
                "chunk_id": i
            }
        ) for i, chunk in enumerate(chunks)
    ]
    
    return documents

def create_vector_store(documents, embeddings, store_path=None):
    """
    Create a FAISS vector store from documents using the given embeddings
    """
    print("Creating FAISS vector store...")
    
    # Create vector store
    vector_store = LangchainFAISS.from_documents(documents, embeddings)
    
    # Save if path is provided
    if store_path:
        print(f"Saving vector store to {store_path}")
        vector_store.save_local(store_path)
    
    return vector_store

def load_vector_store(store_path, embeddings):
    """
    Load a FAISS vector store from disk
    """
    print(f"Loading vector store from {store_path}")
    return LangchainFAISS.load_local(store_path, embeddings, allow_dangerous_deserialization=True)

def perform_similarity_search(vector_store, query, k=6):
    """
    Perform basic similarity search on the vector store
    """
    print(f"Searching for: {query}")
    return vector_store.similarity_search_with_score(query, k=k)

# Llama model loading function
def load_llama_model(model_name="unsloth/Llama-3.2-1B-Instruct", device="cuda"):
    """
    Load and prepare Llama model for text generation
    """
    print(f"Loading LLM: {model_name}")
    
    # Check if CUDA is available
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = "cpu"

    # Quantization config for 4-bit precision to save memory
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    ) if device == "cuda" else None

    # Standard HuggingFace loading
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            quantization_config=quantization
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name)
        model = model.to(device)
    
    print("Successfully loaded model")
    
    return model, tokenizer
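
# Note: on CUDA the LLM is loaded with 4-bit NF4 quantization and float16 compute,
# which sharply reduces GPU memory use compared with full precision; on CPU the
# model is loaded in full precision instead.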

# NEW FUNCTIONS FOR COMBINED RESULTS APPROACH

def combine_top_results(results, query, max_results=4):
    """
    Combine the top search results into a single coherent context
    
    Args:
        results: List of (Document, score) tuples from retrieval
        query: Original user query
        max_results: Maximum number of results to combine
        
    Returns:
        String containing combined context from top results
    """
    # LangChain's FAISS similarity_search_with_score returns L2 distances, so a
    # lower score means a closer match: sort ascending and take the top N
    sorted_results = sorted(results, key=lambda x: x[1])[:max_results]
    
    combined_texts = []
    seen_content = set()  # To avoid duplicates
    
    for doc, score in sorted_results:
        # Extract relevant sentences to keep context focused
        relevant_text = extract_relevant_sentences(doc.page_content, query, window_size=3)
        
        # Skip if this exact text has been seen before
        if relevant_text in seen_content:
            continue
            
        # Add source information to the text
        source_name = os.path.basename(doc.metadata["source"])
        text_with_source = f"{relevant_text} [Source: {source_name}]"
        
        combined_texts.append(text_with_source)
        seen_content.add(relevant_text)
    
    # Combine all texts with clear separation
    combined_context = "\n\n".join(combined_texts)
    
    return combined_context

def setup_enhanced_qa_system(model, tokenizer, vector_store):
    """
    Set up an enhanced QA system using the model and retriever with result combination
    """
    
    # Create a function to generate answers with combined context
    def generate_enhanced_answer(query):
        # Get raw documents and scores
        docs = vector_store.similarity_search_with_score(query, k=6)
        
        # Combine the top results into a single context
        combined_context = combine_top_results(docs, query, max_results=4)
        
        # Create prompt with the combined context
        prompt = f"""
आपको निम्नलिखित संदर्भ से जानकारी के आधार पर एक प्रश्न का उत्तर देना है। 
यदि आप उत्तर नहीं जानते हैं, तो बस "मुझे नहीं पता" कहें। अपने उत्तर में सभी प्रासंगिक जानकारी का उपयोग करें।

संदर्भ:
{combined_context}

प्रश्न: {query}

उत्तर:
"""
        
        # Generate text
        inputs = tokenizer(prompt, return_tensors="pt")
        
        # Move to the same device as the model
        for k, v in inputs.items():
            if hasattr(v, "to") and callable(v.to):
                inputs[k] = v.to(model.device)
        
        with torch.no_grad():
            try:
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=512,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    # Llama tokenizers define no pad token; reuse EOS to avoid warnings
                    pad_token_id=tokenizer.eos_token_id
                )
            except Exception as e:
                return f"Error generating response: {str(e)}", None
        
        # Decode the generated text
        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract just the answer part (after the prompt)
        answer = full_response.split("उत्तर:")[-1].strip()
        
        return answer, combined_context
    
    return generate_enhanced_answer

# Main RAG functions
def index_text_files(model, tokenizer, data_dir, output_dir, device="cuda", chunk_size=500):
    """
    Index text files from a directory and create a FAISS vector store
    """
    print(f"Indexing text files from {data_dir} with chunk size ({chunk_size}) for fine-grained retrieval")
    
    # Create embedding model
    embeddings = HindiSentenceEmbeddings(model, tokenizer, device=device)
    
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    
    # Get all text files
    text_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.txt')]
    print(f"Found {len(text_files)} text files")
    
    # Process all text files
    all_documents = []
    for file_path in text_files:
        documents = load_and_process_text_file(file_path, chunk_size=chunk_size)
        all_documents.extend(documents)
    
    print(f"Total documents: {len(all_documents)}")
    
    # If we don't have enough chunks, reduce chunk size and try again
    if len(all_documents) < 10 and chunk_size > 50:
        print(f"Not enough chunks created. Reducing chunk size and trying again...")
        return index_text_files(model, tokenizer, data_dir, output_dir, device, chunk_size=chunk_size//2)
    
    # Create and save vector store
    vector_store_path = os.path.join(output_dir, "faiss_index")
    vector_store = create_vector_store(all_documents, embeddings, vector_store_path)
    
    return vector_store, embeddings

def query_text_corpus(model, tokenizer, vector_store_path, query, k=6, device="cuda"):
    """
    Query the text corpus using the indexed vector store
    """
    # Create embedding model
    embeddings = HindiSentenceEmbeddings(model, tokenizer, device=device)
    
    # Load vector store
    vector_store = load_vector_store(vector_store_path, embeddings)
    
    # Perform similarity search
    results = perform_similarity_search(vector_store, query, k=k)
    
    return results, vector_store

def main():
    parser = argparse.ArgumentParser(description="Hindi RAG System with Combined Results")
    parser.add_argument("--model_dir", type=str, default="/home/ubuntu/output/hindi-embeddings-custom-tokenizer/final",
                        help="Directory containing the model and tokenizer")
    parser.add_argument("--tokenizer_dir", type=str, default="/home/ubuntu/hindi_tokenizer",
                        help="Directory containing the tokenizer")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to run inference on ('cuda' or 'cpu')")
    parser.add_argument("--index", action="store_true",
                        help="Index text files from data directory")
    parser.add_argument("--query", type=str, default=None,
                        help="Query to search in the indexed corpus")
    parser.add_argument("--data_dir", type=str, default="./data",
                        help="Directory containing text files for indexing")
    parser.add_argument("--output_dir", type=str, default="./output",
                        help="Directory to save the indexed vector store")
    parser.add_argument("--top_k", type=int, default=6,
                        help="Number of top results to return")
    parser.add_argument("--chunk_size", type=int, default=500,
                        help="Size of text chunks for indexing")
    parser.add_argument("--interactive", action="store_true",
                        help="Run in interactive mode for querying")
    parser.add_argument("--reindex", action="store_true",
                        help="Force reindexing even if index exists")
    parser.add_argument("--llm_name", type=str, default="unsloth/Llama-3.2-1B-Instruct",
                        help="HuggingFace model name for the LLM")
    parser.add_argument("--show_context", action="store_true",
                        help="Show the combined context sent to the LLM")
    parser.add_argument("--show_raw_results", action="store_true",
                        help="Show the raw search results before combination")
    args = parser.parse_args()
    
    # Load embedding model and tokenizer
    embed_model, embed_tokenizer, config = load_model_and_tokenizer(args.model_dir, args.tokenizer_dir)
    
    # Move embedding model to device
    embed_model = embed_model.to(args.device)
    
    # Create vector store path
    vector_store_path = os.path.join(args.output_dir, "faiss_index")
    
    # Load LLM
    try:
        # Load LLM
        llm_model, llm_tokenizer = load_llama_model(args.llm_name, args.device)
        print("LLM loaded successfully for QA")
    except Exception as e:
        print(f"Error loading LLM: {e}")
        print("Cannot proceed without LLM for this combined results approach")
        return
    
    if args.index or args.reindex:
        # Index text files
        vector_store, _ = index_text_files(
            embed_model, embed_tokenizer, args.data_dir, args.output_dir, args.device, args.chunk_size
        )
        print(f"Indexing complete. Vector store saved to {vector_store_path}")
    
    # Load vector store for querying
    embeddings = HindiSentenceEmbeddings(embed_model, embed_tokenizer, device=args.device)
    vector_store = load_vector_store(vector_store_path, embeddings)
    
    # Set up enhanced QA system
    qa_generator = setup_enhanced_qa_system(llm_model, llm_tokenizer, vector_store)
    
    if args.query:
        # Process the query with the enhanced system
        print(f"\nProcessing query: {args.query}")
        
        # Show raw results if requested
        if args.show_raw_results:
            results, _ = query_text_corpus(
                embed_model, embed_tokenizer, vector_store_path, args.query, args.top_k, args.device
            )
            
            print("\nRaw Search Results:")
            for i, (doc, score) in enumerate(results):
                print(f"\nResult {i+1} (Score: {score:.4f}):")
                print(f"Source: {doc.metadata['source']}, Chunk: {doc.metadata['chunk_id']}")
                print(f"Content: {doc.page_content[:200]}...")
        
        # Generate enhanced answer
        answer, context = qa_generator(args.query)
        
        if args.show_context:
            print("\nCombined Context:")
            print(context)
            
        print("\nEnhanced LLM Answer:")
        print(answer)
    
    if args.interactive:
        print("\nInteractive mode. Enter queries (or type 'quit' to exit).")
        
        while True:
            print("\nEnter query:")
            query = input()
            
            if not query.strip():
                continue
                
            if query.lower() == 'quit':
                break
            
            # Show raw results if requested
            if args.show_raw_results:
                results, _ = query_text_corpus(
                    embed_model, embed_tokenizer, vector_store_path, query, args.top_k, args.device
                )
                
                print("\nRaw Search Results:")
                for i, (doc, score) in enumerate(results):
                    print(f"\nResult {i+1} (Score: {score:.4f}):")
                    print(f"Source: {doc.metadata['source']}, Chunk: {doc.metadata['chunk_id']}")
                    print(f"Content: {doc.page_content[:200]}...")
            
            # Process the query
            print(f"\nProcessing query: {query}")
            answer, context = qa_generator(query)
            
            if args.show_context:
                print("\nCombined Context:")
                print(context)
                
            print("\nEnhanced LLM Answer:")
            print(answer)
    
    # Clean up GPU memory
    if args.device == "cuda":
        gc.collect()
        torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
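
# Example invocations (illustrative; the script name and paths are placeholders for
# your environment):
#   python hindi_rag_combined.py --index --data_dir ./data --output_dir ./output
#   python hindi_rag_combined.py --query "यहाँ अपना प्रश्न लिखें" --show_context
#   python hindi_rag_combined.py --interactive --show_raw_results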