from typing import Any, Dict, List

import numpy as np
from llama_cpp import Llama
from transformers import AutoTokenizer, LogitsProcessorList

class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the model handler using llama_cpp.
        """
        self.model = Llama.from_pretrained(
            repo_id="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
            filename="Meta-Llama-3.1-8B-Instruct-Q6_K.gguf"
        )
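        # The GGUF weights run under llama_cpp, while a separate Hugging Face
        # tokenizer from the same Llama 3 family is loaded below purely to build
        # token-id masks (the Llama 3.1 and 3.2 models share one vocabulary).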
        self.tokenizer = AutoTokenizer.from_pretrained("taylorj94/Llama-3.2-1B")

    def get_allowed_token_ids(self, vocab_list: List[str]) -> set[int]:
        """
        Generate a set of token IDs for a given list of allowed words.
        Includes the plain and capitalized form of each word, with and
        without a leading space.
        """
        allowed_ids = set()
        for word in vocab_list:
            # Llama's BPE tokenizer maps "word" and " word" to different tokens,
            # so both forms (and their capitalized variants) are collected.
            variations = {word, " " + word, word.capitalize(), " " + word.capitalize()}

            # A word that encodes to several tokens contributes each sub-word ID
            # individually, so the mask is permissive rather than a strict
            # whole-word whitelist.
            for variation in variations:
                for token_id in self.tokenizer.encode(variation, add_special_tokens=False):
                    allowed_ids.add(token_id)

        return allowed_ids
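
    # Illustrative example (actual token IDs depend on the loaded tokenizer):
    # get_allowed_token_ids(["yes"]) collects the IDs produced by encoding
    # "yes", " yes", "Yes", and " Yes".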

    def filter_allowed_tokens(self, input_ids: np.ndarray, scores: np.ndarray, allowed_token_ids: set[int]) -> np.ndarray:
        """
        Modify scores so that only tokens in allowed_token_ids keep their logits;
        all other entries are set to -inf. Handles both 1D and 2D score arrays.
        Note that llama_cpp passes NumPy arrays here, not torch tensors.
        """
        # Build the boolean keep-mask once over the vocabulary axis.
        mask = np.isin(np.arange(scores.shape[-1]), list(allowed_token_ids))
        if scores.ndim == 1:
            scores[~mask] = -np.inf
        elif scores.ndim == 2:
            # The same mask applies to every row.
            scores[:, ~mask] = -np.inf
        else:
            raise ValueError(f"Unsupported scores dimension: {scores.ndim}")
        return scores
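
    # Illustrative example (a sketch with a toy 4-token vocabulary):
    # scores = np.array([0.5, 1.2, -0.3, 2.0], dtype=np.float32)
    # filter_allowed_tokens(ids, scores, {1, 3}) -> [-inf, 1.2, -inf, 2.0]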


    def __call__(self, data: Any) -> List[Dict[str, str]]:
        """
        Handle the request, performing inference with a restricted vocabulary.
        """
        # Extract inputs and parameters
        inputs = data.get("inputs", None)
        parameters = data.get("parameters", {})
        vocab_list = data.get("vocab_list", None)

        if not inputs:
            raise ValueError("The 'inputs' field is required.")

        # Prepare the logits processor (only needed when a restricted vocabulary is given)
        logits_processors = None
        allowed_token_ids: set[int] = set()

        if vocab_list:
            # Build the set of permitted token IDs from the allowed words
            allowed_token_ids = self.get_allowed_token_ids(vocab_list)

            # Wrap the filtering function; llama_cpp invokes it with
            # (input_ids, scores) as NumPy arrays on every decoding step
            logits_processors = LogitsProcessorList([
                lambda input_ids, scores: self.filter_allowed_tokens(input_ids, scores, allowed_token_ids)
            ])

        # Run chat-style inference; the logits processor constrains sampling when set
        response = self.model.create_chat_completion(
            messages=[
                {"role": "user", "content": inputs}
            ],
            max_tokens=parameters.get("max_length", 30),
            logits_processor=logits_processors,
            temperature=parameters.get("temperature", 1.0),
            repeat_penalty=parameters.get("repeat_penalty", 1.0)
        )

        # Extract the generated text from the chat completion response
        generated_text = response["choices"][0]["message"]["content"]

        return [{"generated_text": generated_text, "allowed_token_ids": list(allowed_token_ids)}]
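

# Minimal local smoke test (a sketch, not part of the Inference Endpoints
# request contract; the payload below is a hypothetical example, and the first
# run downloads the GGUF weights).
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": "Name one color of the rainbow.",
        "parameters": {"max_length": 10, "temperature": 0.7},
        "vocab_list": ["red", "orange", "yellow", "green", "blue", "violet"],
    }
    print(handler(payload))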