diff --git "a/llmlingua/prompt_compressor.py" "b/llmlingua/prompt_compressor.py"
deleted file mode 100644
--- "a/llmlingua/prompt_compressor.py"
+++ /dev/null
@@ -1,2412 +0,0 @@
-# Copyright (c) 2023 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-
-import bisect
-import re
-from collections import defaultdict
-from typing import List
-
-import numpy as np
-import torch
-
-import nltk
-import tiktoken
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoModelForTokenClassification,
-    AutoTokenizer,
-)
-import torch.nn.functional as F
-import string
-import copy
-from torch.utils.data import DataLoader
-
-from .utils import TokenClfDataset, seed_everything, is_begin_of_new_word, replace_added_token, get_pure_token
-
-
-class PromptCompressor:
-    """
-    PromptCompressor is designed for compressing prompts based on a given language model.
-
-    This class initializes with the language model and its configuration, preparing it for prompt compression tasks.
-    The PromptCompressor class is versatile and can be adapted for various models and specific requirements in prompt processing.
-    Users can specify different model names and configurations as needed for their particular use case. The architecture is
-    based on the paper "LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models". Jiang, Huiqiang, Qianhui Wu,
-    Chin-Yew Lin, Yuqing Yang, and Lili Qiu. "Llmlingua: Compressing prompts for accelerated inference of large language models."
-    arXiv preprint arXiv:2310.05736 (2023).
-
-    Args:
-        model_name (str, optional): The name of the language model to be loaded. Default is "NousResearch/Llama-2-7b-hf".
-        device_map (str, optional): The device to load the model onto, e.g., "cuda" for GPU. Default is "cuda".
-        model_config (dict, optional): A dictionary containing the configuration parameters for the model. Default is an empty dictionary.
-        open_api_config (dict, optional): A dictionary containing configuration for openai APIs that may be used in conjunction with the model. Default is an empty dictionary.
-        use_llmlingua2 (bool, optional): Whether to use the llmlingua-2 compressor, based on the paper
-            "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression".
-            Zhuoshi Pan, Qianhui Wu, Huiqiang Jiang, Menglin Xia, Xufang Luo, Jue Zhang, Qingwei Lin, Victor Ruhle, Yuqing Yang, Chin-Yew Lin, H. Vicky Zhao, Lili Qiu, Dongmei Zhang.
-            arXiv preprint arXiv:2403.12968 (2024).
-            Default is True.
-        llmlingua2_config (dict, optional): A dictionary containing the configuration parameters for llmlingua-2. Default is 
-            {
-                "max_batch_size": 50, 
-                "max_force_token": 100, # max number of the tokens which will be forcely preserved
-            }
-    Example:
-        >>> compress_method = PromptCompressor(model_name="xxx/llmlingua-2-xlm-roberta-large-meetingbank", use_llmlingua2=True)
-        >>> context = ["This is the first context sentence.", "Here is another context sentence."]
-        >>> result = compress_method.compress_prompt(context, use_context_level_filter=True, target_token=5)
-        >>> print(result["compressed_prompt"])
-        # This will print the compressed version of the context.
-
-    Note:
-        The `PromptCompressor` class requires the Hugging Face Transformers library and an appropriate environment to load and run the models.
-    """
-
-    def __init__(
-        self,
-        model_name: str = "NousResearch/Llama-2-7b-hf",
-        device_map: str = "cuda",
-        model_config: dict = {},
-        open_api_config: dict = {},
-        use_llmlingua2: bool = True,
-        llmlingua2_config: dict = {},
-    ):
-        self.model_name = model_name
-        self.use_llmlingua2 = use_llmlingua2
-        self.retrieval_model = None
-        self.retrieval_model_name = None
-        self.open_api_config = open_api_config
-        self.cache_bos_num = 10
-        self.prefix_bos_num = 100
-        self.oai_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
-
-        self.load_model(model_name, device_map, model_config)
-        if use_llmlingua2:
-            self.init_llmlingua2(**llmlingua2_config)
-
-    def init_llmlingua2(
-        self,
-        max_batch_size: int = 50,
-        max_force_token: int = 100,
-    ):
-
-        seed_everything(42)
-        self.max_batch_size = max_batch_size
-        self.max_seq_len = 512
-        self.max_force_token = max_force_token
-        self.special_tokens = set(self.tokenizer.special_tokens_map.values())
-
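-        # Pre-register placeholder special tokens ([NEW0], [NEW1], ...) so that any
-        # entry of `force_tokens` that does not map to a single token id can later
-        # be swapped for one of these placeholders and preserved atomically during
-        # token-level filtering (see compress_prompt_llmlingua2).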
-        self.added_tokens = [f"[NEW{i}]" for i in range(max_force_token)]
-        self.tokenizer.add_special_tokens(
-            {"additional_special_tokens": self.added_tokens}
-        )
-        self.model.resize_token_embeddings(len(self.tokenizer))
-
-    def load_model(
-        self, model_name: str, device_map: str = "cuda", model_config: dict = {}
-    ):
-        trust_remote_code = model_config.get("trust_remote_code", True)
-        if "trust_remote_code" not in model_config:
-            model_config["trust_remote_code"] = trust_remote_code
-        config = AutoConfig.from_pretrained(model_name, **model_config)
-        tokenizer = AutoTokenizer.from_pretrained(model_name, **model_config)
-        if model_config.get("pad_to_left", True):
-            tokenizer.padding_side = "left"
-            tokenizer.pad_token_id = (
-                config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
-            )
-        MODEL_CLASS = (
-            AutoModelForTokenClassification
-            if any("ForTokenClassification" in ar for ar in config.architectures)
-            else AutoModelForCausalLM
-        )
-        self.device = (
-            device_map
-            if any(key in device_map for key in ["cuda", "cpu", "mps"])
-            else "cuda"
-        )
-        if "cuda" in device_map or "cpu" in device_map:
-            model = MODEL_CLASS.from_pretrained(
-                model_name,
-                torch_dtype=model_config.get(
-                    "torch_dtype", "auto" if device_map == "cuda" else torch.float32
-                ),
-                device_map=device_map,
-                config=config,
-                ignore_mismatched_sizes=True,
-                **model_config,
-            )
-        else:
-            model = MODEL_CLASS.from_pretrained(
-                model_name,
-                device_map=device_map,
-                torch_dtype=model_config.get("torch_dtype", "auto"),
-                pad_token_id=tokenizer.pad_token_id,
-                **model_config,
-            )
-        self.tokenizer = tokenizer
-        self.model = model
-        self.context_idxs = []
-        self.max_position_embeddings = config.max_position_embeddings
-
-    def get_ppl(
-        self,
-        text: str,
-        granularity: str = "sentence",
-        input_ids=None,
-        attention_mask=None,
-        past_key_values=None,
-        return_kv=False,
-        end=None,
-        condition_mode: str = "none",
-        condition_pos_id: int = 0,
-    ):
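-        # Computes the per-token cross-entropy of `text` (the building block of
-        # perplexity), optionally continuing from cached `past_key_values`. With
-        # condition_mode="before"/"after", only the losses before/after
-        # `condition_pos_id` are kept, enabling question-conditioned scoring.
-        # granularity="sentence" returns the mean loss, otherwise per-token losses.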
-        if input_ids is None:
-            tokenized_text = self.tokenizer(text, return_tensors="pt")
-            input_ids = tokenized_text["input_ids"].to(self.device)
-            attention_mask = tokenized_text["attention_mask"].to(self.device)
-        if past_key_values is not None:
-            past_length = past_key_values[0][0].shape[2]
-        else:
-            past_length = 0
-        if end is None:
-            end = input_ids.shape[1]
-        end = min(end, past_length + self.max_position_embeddings)
-        with torch.no_grad():
-            response = self.model(
-                input_ids[:, past_length:end],
-                attention_mask=attention_mask[:, :end],
-                past_key_values=past_key_values,
-                use_cache=True,
-            )
-            past_key_values = response.past_key_values
-
-        shift_logits = response.logits[..., :-1, :].contiguous()
-        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
-        # Flatten the tokens
-        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
-        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
-        active_labels = shift_labels.view(-1)[active]
-        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
-        loss = loss_fct(active_logits, active_labels)
-        if condition_mode == "before":
-            loss = loss[:condition_pos_id]
-        elif condition_mode == "after":
-            loss = loss[condition_pos_id:]
-        res = loss.mean() if granularity == "sentence" else loss
-        return (res, past_key_values) if return_kv else res
-
-    def __call__(self, *args, **kwargs):
-        return self.compress_prompt(*args, **kwargs)
-
-    def structured_compress_prompt(
-        self,
-        context: List[str],
-        instruction: str = "",
-        question: str = "",
-        rate: float = 0.5,
-        target_token: float = -1,
-        iterative_size: int = 200,
-        force_context_ids: List[int] = None,
-        force_context_number: int = None,
-        use_sentence_level_filter: bool = False,
-        use_context_level_filter: bool = True,
-        use_token_level_filter: bool = True,
-        keep_split: bool = False,
-        keep_first_sentence: int = 0,
-        keep_last_sentence: int = 0,
-        keep_sentence_number: int = 0,
-        high_priority_bonus: int = 100,
-        context_budget: str = "+100",
-        token_budget_ratio: float = 1.4,
-        condition_in_question: str = "none",
-        reorder_context: str = "original",
-        dynamic_context_compression_ratio: float = 0.0,
-        condition_compare: bool = False,
-        add_instruction: bool = False,
-        rank_method: str = "llmlingua",
-        concate_question: bool = True,
-    ):
-        """
-        Compresses the given prompt context based on a specified structure.
-
-        Each element of context should be segmented using one or more non-nested '<llmlingua></llmlingua>' tags.
-        Each '<llmlingua>' tag can include optional parameters 'rate' and 'compress' (e.g., '<llmlingua, rate=0.3, compress=True>'),
-        indicating the compression rate for that segment. Default values are 'rate=rate' and 'compress=True'.
-        When 'compress' is set to False, it overrides the 'rate' parameter, resulting in no compression for that segment.
-
-        Args:
-            context (List[str]): List of context strings divided by '<llmlingua></llmlingua>' tags with optional compression settings.
-            instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string.
-            question (str, optional): A specific question that the prompt is addressing. Default is an empty string.
-            rate (float, optional): The compression rate is defined the same as in the paper "Language Modeling Is Compression".
-                Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne, Elliot Catt, Tim Genewein, Christopher Mattern,
-                Jordi Grau-Moya et al. "Language modeling is compression." arXiv preprint arXiv:2309.10668 (2023):
-                .. math::\text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}}
-                Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be
-                fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal
-                to 1.0, representing the target compression rate. ``rate`` is applicable only within the context-level filter
-                and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global rate.
-                However, for segments where no specific rate is defined, the global rate serves as the default value. The final
-                compression rate of the entire text is a composite result of multiple compression rates applied across different sections.
-            target_token (float, optional): The global maximum number of tokens to be achieved. Default is -1, indicating no
-                specific target. The actual number of tokens after compression should generally be less than the specified target_token,
-                but there can be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
-                the sole criterion, overriding the ``rate``. ``target_token`` is applicable only within the context-level
-                filter and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global target token.
-                However, for segments where no specific rate is defined, the global rate calculated from global target token serves
-                as the default value. The final target token of the entire text is a composite result of multiple compression rates
-                applied across different sections.
-            iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200.
-            force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
-            force_context_number (int, optional): The number of context sections to forcibly include. Default is None.
-            use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False.
-            use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True.
-            use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True.
-            keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False.
-            keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0.
-            keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0.
-            keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0.
-            high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100.
-            context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100".
-            token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4.
-            condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none".
-            reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original".
-            dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0.
-            condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False.
-            add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False.
-            rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
-            concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
-
-        Returns:
-            dict: A dictionary containing:
-                - "compressed_prompt" (str): The resulting compressed prompt.
-                - "origin_tokens" (int): The original number of tokens in the input.
-                - "compressed_tokens" (int): The number of tokens in the compressed output.
-                - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression.
-                - "rate" (str): The compression rate achieved, in a human-readable format.
-                - "saving" (str): Estimated savings in GPT-4 token usage.
-        """
-        if not context:
-            context = [" "]
-        if isinstance(context, str):
-            context = [context]
-        context = [
-            self.tokenizer.decode(self.tokenizer(c, add_special_tokens=False).input_ids)
-            for c in context
-        ]
-        context_tokens_length = [self.get_token_length(c) for c in context]
-        instruction_tokens_length, question_tokens_length = self.get_token_length(
-            instruction
-        ), self.get_token_length(question)
-        if target_token == -1:
-            target_token = (
-                (
-                    instruction_tokens_length
-                    + question_tokens_length
-                    + sum(context_tokens_length)
-                )
-                * rate
-                - instruction_tokens_length
-                - (question_tokens_length if concate_question else 0)
-            )
-        else:
-            rate = target_token / sum(context_tokens_length)
-        (
-            context,
-            context_segs,
-            context_segs_rate,
-            context_segs_compress,
-        ) = self.segment_structured_context(context, rate)
-        return self.compress_prompt(
-            context,
-            instruction,
-            question,
-            rate,
-            target_token,
-            iterative_size,
-            force_context_ids,
-            force_context_number,
-            use_sentence_level_filter,
-            use_context_level_filter,
-            use_token_level_filter,
-            keep_split,
-            keep_first_sentence,
-            keep_last_sentence,
-            keep_sentence_number,
-            high_priority_bonus,
-            context_budget,
-            token_budget_ratio,
-            condition_in_question,
-            reorder_context,
-            dynamic_context_compression_ratio,
-            condition_compare,
-            add_instruction,
-            rank_method,
-            concate_question,
-            context_segs=context_segs,
-            context_segs_rate=context_segs_rate,
-            context_segs_compress=context_segs_compress,
-        )
-
-    def compress_prompt(
-        self,
-        context: List[str],
-        instruction: str = "",
-        question: str = "",
-        rate: float = 0.5,
-        target_token: float = -1,
-        iterative_size: int = 200,
-        force_context_ids: List[int] = None,
-        force_context_number: int = None,
-        use_sentence_level_filter: bool = False,
-        use_context_level_filter: bool = True,
-        use_token_level_filter: bool = True,
-        keep_split: bool = False,
-        keep_first_sentence: int = 0,
-        keep_last_sentence: int = 0,
-        keep_sentence_number: int = 0,
-        high_priority_bonus: int = 100,
-        context_budget: str = "+100",
-        token_budget_ratio: float = 1.4,
-        condition_in_question: str = "none",
-        reorder_context: str = "original",
-        dynamic_context_compression_ratio: float = 0.0,
-        condition_compare: bool = False,
-        add_instruction: bool = False,
-        rank_method: str = "llmlingua",
-        concate_question: bool = True,
-        context_segs: List[str] = None,
-        context_segs_rate: List[float] = None,
-        context_segs_compress: List[bool] = None,
-        target_context: int = -1,
-        context_level_rate: float = 1.0,
-        context_level_target_token: int = -1,
-        return_word_label: bool = False,
-        word_sep: str = "\t\t|\t\t",
-        label_sep: str = " ",
-        token_to_word: str = "mean",
-        force_tokens: List[str] = [],
-        force_reserve_digit: bool = False,
-        drop_consecutive: bool = False,
-        chunk_end_tokens: List[str] = [".", "\n"],
-    ):
-        """
-        Compresses the given context.
-
-        Args:
-            context (List[str]): List of context strings that form the basis of the prompt.
-            instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string.
-            question (str, optional): A specific question that the prompt is addressing. Default is an empty string.
-            rate (float, optional): The maximum compression rate target to be achieved. The compression rate is defined
-                the same as in the paper "Language Modeling Is Compression". Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne,
-                Elliot Catt, Tim Genewein, Christopher Mattern, Jordi Grau-Moya et al. "Language modeling is compression."
-                arXiv preprint arXiv:2309.10668 (2023):
-                .. math::\text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}}
-                Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be
-                fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal
-                to 1.0, representing the target compression rate.
-            target_token (float, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target.
-                The actual number of tokens after compression should generally be less than the specified target_token, but there can
-                be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
-                the sole criterion, overriding the ``rate``.
-            iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200.
-            force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
-            force_context_number (int, optional): The number of context sections to forcibly include. Default is None.
-            use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False.
-            use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True.
-            use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True.
-            keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False.
-            keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0.
-            keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0.
-            keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0.
-            high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100.
-            context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100".
-            token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4.
-            condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none".
-            reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original".
-            dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0.
-            condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False.
-            add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False.
-            rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
-            concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
-
-            target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
-            context_level_rate (float, optional): The minimum compression rate target to be achieved in context level. Default is 1.0.
-            context_level_target_token (float, optional): The maximum number of tokens to be achieved in context-level compression.
-                Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario.
-            return_word_label (bool, optional): Whether to return each word with its corresponding label. Default is False.
-            word_sep (str, optional): The separator used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
-            label_sep (str, optional): The separator used in fn_labeled_original_prompt to partition a word and its label. Default is " ".
-            token_to_word (str, optional): How to aggregate token probabilities into a word probability. Default is "mean".
-            force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
-            force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0-9). Default is False.
-            drop_consecutive (bool, optional): Whether to drop tokens that are in 'force_tokens' but appear consecutively in the
-                compressed prompt. Default is False.
-            chunk_end_tokens (List[str], optional): The early-stop tokens used to segment chunks. Default is [".", "\n"].
-        Returns:
-            dict: A dictionary containing:
-                - "compressed_prompt" (str): The resulting compressed prompt.
-                - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt. Only used in llmlingua2.
-                - "fn_labeled_original_prompt" (str): original words along with their labels
-                    indicating whether to reserve in compressed prompt, in the format (word label_sep label)
-                    Only used in llmlingua2 when return_word_label = True.
-                - "origin_tokens" (int): The original number of tokens in the input.
-                - "compressed_tokens" (int): The number of tokens in the compressed output.
-                - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression.
-                - "rate" (str): The compression rate achieved, in a human-readable format.
-                - "saving" (str): Estimated savings in GPT-4 token usage.
-        """
-        if self.use_llmlingua2:
-            return self.compress_prompt_llmlingua2(
-                context,
-                rate=rate,
-                target_token=target_token,
-                use_context_level_filter=use_context_level_filter,
-                use_token_level_filter=use_token_level_filter,
-                target_context=target_context,
-                context_level_rate=context_level_rate,
-                context_level_target_token=context_level_target_token,
-                force_context_ids=force_context_ids,
-                return_word_label=return_word_label,
-                word_sep=word_sep,
-                label_sep=label_sep,
-                token_to_word=token_to_word,
-                force_tokens=force_tokens,
-                force_reserve_digit=force_reserve_digit,
-                drop_consecutive=drop_consecutive,
-                chunk_end_tokens=chunk_end_tokens,
-            )
-        assert (
-            rate <= 1.0
-        ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]."
-
-        if not context:
-            context = [" "]
-        if isinstance(context, str):
-            context = [context]
-        assert not (
-            rank_method == "longllmlingua" and not question
-        ), "In the LongLLMLingua, it is necessary to set a question."
-        if condition_compare and "_condition" not in condition_in_question:
-            condition_in_question += "_condition"
-        if rank_method == "longllmlingua":
-            if condition_in_question == "none":
-                condition_in_question = "after"
-        elif rank_method == "llmlingua":
-            condition_in_question = (
-                "none"
-                if "_condition" not in condition_in_question
-                else "none_condition"
-            )
-        origin_tokens = len(
-            self.oai_tokenizer.encode(
-                "\n\n".join([instruction] + context + [question]).strip()
-            )
-        )
-        context_tokens_length = [self.get_token_length(c) for c in context]
-        instruction_tokens_length, question_tokens_length = self.get_token_length(
-            instruction
-        ), self.get_token_length(question)
-        if target_token == -1:
-            target_token = (
-                (
-                    instruction_tokens_length
-                    + question_tokens_length
-                    + sum(context_tokens_length)
-                )
-                * rate
-                - instruction_tokens_length
-                - (question_tokens_length if concate_question else 0)
-            )
-        condition_flag = "_condition" in condition_in_question
-        condition_in_question = condition_in_question.replace("_condition", "")
-
-        if len(context) > 1 and use_context_level_filter:
-            context, dynamic_ratio, context_used = self.control_context_budget(
-                context,
-                context_tokens_length,
-                target_token,
-                force_context_ids,
-                force_context_number,
-                question,
-                condition_in_question,
-                reorder_context=reorder_context,
-                dynamic_context_compression_ratio=dynamic_context_compression_ratio,
-                rank_method=rank_method,
-                context_budget=context_budget,
-                context_segs=context_segs,
-                context_segs_rate=context_segs_rate,
-                context_segs_compress=context_segs_compress,
-            )
-            if context_segs is not None:
-                context_segs = [context_segs[idx] for idx in context_used]
-                context_segs_rate = [context_segs_rate[idx] for idx in context_used]
-                context_segs_compress = [
-                    context_segs_compress[idx] for idx in context_used
-                ]
-        else:
-            dynamic_ratio = [0.0] * len(context)
-
-        segments_info = []
-        if use_sentence_level_filter:
-            context, segments_info = self.control_sentence_budget(
-                context,
-                target_token,
-                keep_first_sentence=keep_first_sentence,
-                keep_last_sentence=keep_last_sentence,
-                keep_sentence_number=keep_sentence_number,
-                high_priority_bonus=high_priority_bonus,
-                token_budget_ratio=token_budget_ratio,
-                question=question,
-                condition_in_question=condition_in_question,
-                rank_method=rank_method,
-                context_segs=context_segs,
-                context_segs_rate=context_segs_rate,
-                context_segs_compress=context_segs_compress,
-            )
-        elif context_segs is not None:
-            for context_idx in range(len(context)):
-                segments_info.append(
-                    [
-                        (len(seg_text), seg_rate, seg_compress)
-                        for seg_text, seg_rate, seg_compress in zip(
-                            context_segs[context_idx],
-                            context_segs_rate[context_idx],
-                            context_segs_compress[context_idx],
-                        )
-                    ]
-                )
-        segments_info = [
-            self.concate_segment_info(segment_info) for segment_info in segments_info
-        ]
-
-        if condition_flag:
-            prefix = question + "\n\n" + instruction if add_instruction else question
-            if (
-                self.get_token_length(prefix + "\n\n") + iterative_size * 2
-                > self.max_position_embeddings
-            ):
-                tokens = self.tokenizer(prefix, add_special_tokens=False).input_ids
-                prefix = self.tokenizer.decode(
-                    tokens[: self.prefix_bos_num]
-                    + tokens[
-                        len(tokens)
-                        - self.max_position_embeddings
-                        + 2
-                        + self.prefix_bos_num
-                        + 2 * iterative_size :
-                    ]
-                )
-            start = self.get_prefix_length(prefix + "\n\n", context[0])
-            context = [prefix] + context
-        else:
-            start = 0
-
-        if use_token_level_filter:
-            context = self.iterative_compress_prompt(
-                context,
-                target_token,
-                iterative_size=iterative_size,
-                keep_split=keep_split,
-                start=start,
-                dynamic_ratio=dynamic_ratio,
-                condition_compare=condition_compare,
-                segments_info=segments_info,
-            )
-            compressed_prompt = (
-                self.tokenizer.batch_decode(context[0])[0]
-                .replace("<s> ", "")
-                .replace("<s>", "")
-            )
-        else:
-            if condition_flag:
-                context = context[1:]
-            compressed_prompt = "\n\n".join(context)
-
-        res = []
-        if instruction:
-            res.append(instruction)
-        if compressed_prompt.strip():
-            res.append(compressed_prompt)
-        if question and concate_question:
-            res.append(question)
-
-        compressed_prompt = "\n\n".join(res)
-
-        compressed_tokens = len(self.oai_tokenizer.encode(compressed_prompt))
-        saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
-        ratio = 1 if compressed_tokens == 0 else origin_tokens / compressed_tokens
-        rate = 1 / ratio
-        return {
-            "compressed_prompt": compressed_prompt,
-            "origin_tokens": origin_tokens,
-            "compressed_tokens": compressed_tokens,
-            "ratio": f"{ratio:.1f}x",
-            "rate": f"{rate * 100:.1f}%",
-            "saving": f", Saving ${saving:.1f} in GPT-4.",
-        }
-
-    def compress_prompt_llmlingua2(
-        self,
-        context: List[str],
-        rate: float = 0.5,
-        target_token: int = -1,
-        use_context_level_filter: bool = False,
-        use_token_level_filter: bool = True,
-        target_context: int = -1,
-        context_level_rate: float = 1.0,
-        context_level_target_token: int = -1,
-        force_context_ids: List[int] = [],
-        return_word_label: bool = False,
-        word_sep: str = "\t\t|\t\t",
-        label_sep: str = " ",
-        token_to_word: str = "mean",
-        force_tokens: List[str] = [],
-        force_reserve_digit: bool = False,
-        drop_consecutive: bool = False,
-        chunk_end_tokens: List[str] = [".", "\n"],
-    ):
-        """
-        Compresses the given context, instruction and question.
-
-        Args:
-            context (List[str]): List of context strings that form the basis of the prompt.
-            rate (float, optional): The minimum compression rate target to be achieved. Default is 0.5. The actual compression rate
-                generally exceeds the specified target, but there can be fluctuations due to differences in tokenizers. If specified,
-                it should be a float less than or equal to 1.0, representing the target compression rate.
-            target_token (int, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target.
-                The actual number of tokens after compression should generally be less than the specified target_token, but there can
-                be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
-                the sole criterion, overriding the rate.
-            target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
-                Only used in the coarse-to-fine compression.
-            context_level_rate (float, optional): The minimum compression rate target to be achieved in context level. Default is 1.0.
-                Only used in the coarse-to-fine compression.
-            context_level_target_token (float, optional): The maximum number of tokens to be achieved in context-level compression.
-                Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario.
-            force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is [].
-            return_word_label (bool, optional): Whether to return each word with its corresponding label. Default is False.
-            word_sep (str, optional): The separator used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
-            label_sep (str, optional): The separator used in fn_labeled_original_prompt to partition a word and its label. Default is " ".
-            token_to_word (str, optional): How to aggregate token probabilities into a word probability. Default is "mean".
-            force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
-            force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0-9). Default is False.
-            drop_consecutive (bool, optional): Whether to drop tokens that are in 'force_tokens' but appear consecutively in the
-                compressed prompt. Default is False.
-            chunk_end_tokens (List[str], optional): The early-stop tokens used to segment chunks. Default is [".", "\n"].
-        Returns:
-            dict: A dictionary containing:
-                - "compressed_prompt" (str): The resulting compressed prompt.
-                - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt.
-                - "fn_labeled_original_prompt" (str): original words along with their labels
-                    indicating whether to reserve in compressed prompt, in the format (word label_sep label)
-                - "origin_tokens" (int): The original number of tokens in the input.
-                - "compressed_tokens" (int): The number of tokens in the compressed output.
-                - "ratio" (str): The compression ratio achieved, in a human-readable format.
-                - "rate" (str): The compression rate achieved, in a human-readable format.
-                - "saving" (str): Estimated savings in GPT-4 token usage.
-
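-        Example:
-            >>> # A minimal sketch; `compressor` is assumed to wrap a llmlingua-2
-            >>> # token-classification checkpoint, and the contexts are illustrative.
-            >>> out = compressor.compress_prompt_llmlingua2(
-            ...     ["First context paragraph.", "Second context paragraph."],
-            ...     target_token=100,
-            ...     use_context_level_filter=True,
-            ...     return_word_label=True,
-            ... )
-            >>> print(out["compressed_prompt"])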
-        """
-        assert len(force_tokens) <= self.max_force_token
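-        # Map each force token that does not tokenize to a single piece onto one of
-        # the pre-registered [NEWi] placeholders, so it survives token-level
-        # filtering as an atomic unit (replace_added_token restores the original text).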
-        token_map = {}
-        for i, t in enumerate(force_tokens):
-            if len(self.tokenizer.tokenize(t)) != 1:
-                token_map[t] = self.added_tokens[i]
-        chunk_end_tokens = copy.deepcopy(chunk_end_tokens)
-        for c in chunk_end_tokens:
-            if c in token_map:
-                chunk_end_tokens.append(token_map[c])
-        chunk_end_tokens = set(chunk_end_tokens)
-
-        if isinstance(context, str):
-            context = [context]
-        context = copy.deepcopy(context)
-
-        if len(context) == 1 and use_context_level_filter:
-            use_context_level_filter = False
-
-        n_original_token = 0
-        context_chunked = []
-        for i in range(len(context)):
-            n_original_token += self.get_token_length(context[i], use_oai_tokenizer=True)
-            for ori_token, new_token in token_map.items():
-                context[i] = context[i].replace(ori_token, new_token)
-            context_chunked.append(self.__chunk_context(context[i], chunk_end_tokens=chunk_end_tokens))
-
-        if use_context_level_filter:
-            # use_context_level_filter is requested but no context-level parameters were given:
-            # derive them from `rate` (context_level_rate = (rate + 1.0) / 2) or from
-            # `target_token` (context_level_target_token = target_token * 2).
-            if (
-                target_context <= 0
-                and context_level_rate >= 1.0
-                and context_level_target_token <= 0
-            ):
-                if target_token < 0 and rate < 1.0:
-                    context_level_rate = (
-                        (rate + 1.0) / 2 if use_token_level_filter else rate
-                    )
-                    print(
-                        f"set context level compression rate to {context_level_rate}."
-                    )
-                if target_token >= 0:
-                    context_level_target_token = (
-                        target_token * 2 if use_token_level_filter else target_token
-                    )
-                    print(
-                        f"set context level target token to {context_level_target_token}."
-                    )
-
-            if target_context >= 0:
-                context_level_rate = min(target_context / len(context), 1.0)
-                # print(f'override context level compression rate to {context_level_rate} because you specified target_context = {target_context}.')
-            if context_level_target_token >= 0:
-                context_level_rate = min(
-                    context_level_target_token / n_original_token, 1.0
-                )
-                # print(f'override context level compression rate to {context_level_rate} because you specified context_level_target_token = {context_level_target_token}.')
-
-            context_probs, context_words = self.__get_context_prob(
-                context_chunked, 
-                token_to_word=token_to_word, 
-                force_tokens=force_tokens,
-                token_map=token_map, 
-                force_reserve_digit=force_reserve_digit, 
-            )
-
-            threshold = np.percentile(
-                context_probs, int(100 * (1 - context_level_rate))
-            )
-
-            reserved_context = []
-            context_label = [False] * len(context_probs)
-            for i, p in enumerate(context_probs):
-                if p >= threshold or (
-                    force_context_ids is not None and i in force_context_ids
-                ):
-                    reserved_context.append(context_chunked[i])
-                    context_label[i] = True
-            n_reserved_token = 0
-            for chunks in reserved_context:
-                for c in chunks:
-                    n_reserved_token += self.get_token_length(c, use_oai_tokenizer=True)
-            if target_token >= 0:
-                rate = min(target_token / n_reserved_token, 1.0)
-                print(
-                    f"override compression rate to {rate} because you specified target_token = {target_token}."
-                )
-
-            if use_token_level_filter:
-                compressed_context, word_list, word_label_list = self.__compress(
-                    reserved_context,
-                    reduce_rate=max(0, 1 - rate),
-                    token_to_word=token_to_word,
-                    force_tokens=force_tokens,
-                    token_map=token_map, 
-                    force_reserve_digit=force_reserve_digit, 
-                    drop_consecutive=drop_consecutive,
-                )
-            else:
-                compressed_context, word_list, word_label_list = self.__compress(
-                    reserved_context, 
-                    reduce_rate=0, 
-                    token_to_word=token_to_word,
-                    force_tokens=force_tokens,
-                    token_map=token_map, 
-                    force_reserve_digit=force_reserve_digit, 
-                    drop_consecutive=drop_consecutive,
-                )
-                print(
-                    "return the original text because you specify use_token_level_filter=False"
-                )
-
-            n_compressed_token = 0
-            for c in compressed_context:
-                n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
-            saving = (n_original_token - n_compressed_token) * 0.06 / 1000
-            ratio = (
-                1 if n_compressed_token == 0 else n_original_token / n_compressed_token
-            )
-            res = {
-                "compressed_prompt": "\n\n".join(compressed_context),
-                "compressed_prompt_list": compressed_context,
-                "origin_tokens": n_original_token,
-                "compressed_tokens": n_compressed_token,
-                "ratio": f"{ratio:.1f}x",
-                "rate": f"{1 / ratio * 100:.1f}%",
-                "saving": f", Saving ${saving:.1f} in GPT-4.",
-            }
-            if return_word_label:
-                words = []
-                labels = []
-                j = 0
-                for i in range(len(context)):
-                    if context_label[i]:
-                        words.extend(word_list[j])
-                        labels.extend(word_label_list[j])
-                        j += 1
-                    else:
-                        words.extend(context_words[i])
-                        labels.extend([0] * len(context_words[i]))
-                word_label_lines = word_sep.join(
-                    [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
-                )
-                res["fn_labeled_original_prompt"] = word_label_lines
-            return res
-
-        if target_token > 0:
-            rate = min(target_token / n_original_token, 1.0)
-            print(
-                f"override compression rate to {rate} \
-                  because you specified target_token = {target_token}."
-            )
-
-        if use_token_level_filter:
-            compressed_context, word_list, word_label_list = self.__compress(
-                context_chunked,
-                reduce_rate=max(0, 1 - rate),
-                token_to_word=token_to_word,
-                force_tokens=force_tokens,
-                token_map=token_map, 
-                force_reserve_digit=force_reserve_digit, 
-                drop_consecutive=drop_consecutive,
-            )
-        else:
-            compressed_context, word_list, word_label_list = self.__compress(
-                context_chunked, 
-                reduce_rate=0, 
-                token_to_word=token_to_word,
-                force_tokens=force_tokens,
-                token_map=token_map, 
-                force_reserve_digit=force_reserve_digit, 
-                drop_consecutive=drop_consecutive,
-            )
-            print(
-                "return the original text because you specify use_token_level_filter=False"
-            )
-
-        n_compressed_token = 0
-        for c in compressed_context:
-            n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
-        saving = (n_original_token - n_compressed_token) * 0.06 / 1000
-        ratio = 1 if n_compressed_token == 0 else n_original_token / n_compressed_token
-        res = {
-            "compressed_prompt": "\n\n".join(compressed_context),
-            "compressed_prompt_list": compressed_context,
-            "origin_tokens": n_original_token,
-            "compressed_tokens": n_compressed_token,
-            "ratio": f"{ratio:.1f}x",
-            "rate": f"{1 / ratio * 100:.1f}%",
-            "saving": f", Saving ${saving:.1f} in GPT-4.",
-        }
-        if return_word_label:
-            words = []
-            labels = []
-            for w_list, l_list in zip(word_list, word_label_list):
-                words.extend(w_list)
-                labels.extend(l_list)
-
-            # new_words = []
-            # new_labels = []
-            # for i in range(len(words)):
-            #     word, label = words[i], labels[i]
-            #     if word in string.punctuation:
-            #         if labels[i-1] == 1 and label == 1 and i > 0:
-            #             new_words[-1] += word
-            #     else:
-            #         new_words.append(word)
-            #         new_labels.append(label)
-            # word_label_lines = word_sep.join([f'{word}{label_sep}{label}' for word, label in zip(new_words, new_labels)])
-
-            word_label_lines = word_sep.join(
-                [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
-            )
-            res["fn_labeled_original_prompt"] = word_label_lines
-        return res
-
-    def get_token_length(self, text: str, add_special_tokens: bool = True, use_oai_tokenizer: bool = False):
-        if use_oai_tokenizer:
-            return len(self.oai_tokenizer.encode(text))
-        else:
-            return len(
-                self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
-            )
-
-    def get_prefix_length(self, prefix: str, text: str):
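-        # Determines how many tokens of `prefix + text` belong to the prefix by
-        # decoding progressively longer id prefixes until the decoded string equals
-        # `prefix`; needed because tokenizers may merge tokens across the boundary.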
-        possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1)
-        full_input_ids = self.tokenizer(
-            prefix + text[:100], add_special_tokens=False
-        ).input_ids
-        for i in range(possible_prefix_token, len(full_input_ids)):
-            cur_prefix = self.tokenizer.decode(full_input_ids[:i])
-            if cur_prefix == prefix:
-                break
-        assert self.tokenizer.decode(full_input_ids[i:]) == text[:100]
-        return i
-
-    def get_condition_ppl(
-        self,
-        text: str,
-        question: str,
-        condition_in_question: str = "none",
-        granularity: str = "sentence",
-    ):
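-        # "before": loss of `text` conditioned on the prepended question;
-        # "after": loss of `question` conditioned on the preceding `text`
-        # (used by LongLLMLingua-style ranking of contexts against the question).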
-        if condition_in_question == "none":
-            return self.get_ppl(text, granularity=granularity)
-        elif condition_in_question == "before":
-            return self.get_ppl(
-                question + text,
-                granularity=granularity,
-                condition_mode="after",
-                condition_pos_id=self.get_token_length(question) - 1,
-            )
-        elif condition_in_question == "after":
-            return self.get_ppl(
-                text + question,
-                granularity=granularity,
-                condition_mode="after",
-                condition_pos_id=self.get_token_length(text) - 1,
-            )
-
-    def get_dynamic_compression_ratio(
-        self,
-        context: list,
-        target_token: float,
-        iterative_size: int,
-        dynamic_ratio: list,
-        start: int,
-        seg_info: List[List[tuple]] = None,
-    ):
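-        # Distributes the global keep ratio (tau = target_token / total_length)
-        # across chunks of `iterative_size` tokens, shifting each context's ratio
-        # by its entry in `dynamic_ratio` and clamping the result to [0, 1].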
-        def get_ratio(base: float, delta: float):
-            return max(min(1, base + delta), 0)
-
-        context_length = [self.get_token_length(ii, False) + 2 for ii in context]
-        if start:
-            context_length = context_length[1:]
-        tau = target_token / (sum(context_length) + 1)
-        res, idx, last, last_target = [], 0, 1, []
-        while idx < len(context_length):
-            if last + context_length[idx] >= iterative_size:
-                last_target.append(
-                    (iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
-                )
-                res.append(last_target)
-                last = last + context_length[idx] - iterative_size
-                if last > iterative_size:
-                    k = last // iterative_size
-                    res.extend(
-                        [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
-                    )
-                    last -= k * iterative_size
-
-                last_target = (
-                    [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
-                )
-            else:
-                last += context_length[idx]
-                last_target.append(
-                    (context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
-                )
-            idx += 1
-        if last_target:
-            res.append(last_target)
-        return res
-
-    def get_structured_dynamic_compression_ratio(
-        self,
-        context: list,
-        iterative_size: int,
-        dynamic_ratio: list,
-        start: int,
-        seg_info: List[List[tuple]] = None,
-    ):
-        if start:
-            pure_context = context[1:]
-        else:
-            pure_context = context
-        global_dynamic_rate, global_dynamic_compress, segments = [], [], []
-        for context_idx, text in enumerate(pure_context):
-            text_seen = 0
-            for seg_idx, (seg_len, seg_rate, seg_compress) in enumerate(
-                seg_info[context_idx]
-            ):
-                seg_text = text[text_seen : text_seen + seg_len]
-                if (
-                    seg_idx == len(seg_info[context_idx]) - 1
-                    and context_idx != len(pure_context) - 1
-                ):
-                    seg_text += "\n\n"
-                segments.append(seg_text)
-                if seg_compress:
-                    global_dynamic_rate.append(seg_rate)
-                else:
-                    global_dynamic_rate.append(1.0)
-                global_dynamic_compress.append(seg_compress)
-                text_seen += seg_len
-        origin_text = "\n\n".join(pure_context)
-        assert len("".join(segments)) == len(origin_text)
-        assert len(segments) == len(global_dynamic_rate) == len(global_dynamic_compress)
-
-        text_input_ids = self.tokenizer(
-            "\n\n".join(context), add_special_tokens=False
-        ).input_ids[start:]
-        assert self.tokenizer.decode(text_input_ids) == origin_text
-        dynamic_compression_ratio = self.token_segment(
-            text_input_ids,
-            iterative_size,
-            segments,
-            global_dynamic_rate,
-            global_dynamic_compress,
-        )
-        return dynamic_compression_ratio
-
-    def token_segment(
-        self,
-        text_input_ids: List[int],
-        iterative_size: int,
-        segments: List[str],
-        global_dynamic_rate: List[float],
-        global_dynamic_compress: List[bool],
-    ):
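-        # Aligns the character-level segments (and their per-segment rates) with the
-        # token stream: each token is decoded inside a small sliding window
-        # (decode_window) to recover its surface text and assigned the rate of the
-        # segment(s) it spans; runs of equal rate are then grouped per iterative chunk.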
-        decode_window = 3
-        seg_idx, seg_seen, token_seen_num, last_rate = 0, 0, 0, -1
-        dynamic_compression_rate, local_compresssion_rate = [], []
-        for i in range(len(text_input_ids)):
-            if i < decode_window:
-                id_pre, id_cur = text_input_ids[:i], text_input_ids[: i + 1]
-            else:
-                id_pre, id_cur = (
-                    text_input_ids[i - decode_window + 1 : i],
-                    text_input_ids[i - decode_window + 1 : i + 1],
-                )
-            cur_word = self.tokenizer.decode(id_cur)[
-                len(self.tokenizer.decode(id_pre)) :
-            ]
-            cur_word_len = len(cur_word)
-            if cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen:
-                possible_rate, possible_compress = [], []
-                while (
-                    cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen
-                ):
-                    possible_rate.append(global_dynamic_rate[seg_idx])
-                    possible_compress.append(global_dynamic_compress[seg_idx])
-                    cur_word_len -= len(segments[seg_idx]) - seg_seen
-                    seg_idx += 1
-                    seg_seen = 0
-                if cur_word_len:
-                    possible_rate.append(global_dynamic_rate[seg_idx])
-                    possible_compress.append(global_dynamic_compress[seg_idx])
-                new_rate = 1.0 if False in possible_compress else min(possible_rate)
-            else:
-                new_rate = global_dynamic_rate[seg_idx]
-            if new_rate != last_rate and i - token_seen_num:
-                local_compression_rate.append((i - token_seen_num, last_rate))
-                token_seen_num = i
-            last_rate = new_rate
-            seg_seen += cur_word_len
-            if (i + 1) % iterative_size == 0:
-                if token_seen_num != i + 1:
-                    local_compression_rate.append((i + 1 - token_seen_num, last_rate))
-                    token_seen_num = i + 1
-                dynamic_compression_rate.append(local_compression_rate[:])
-                local_compression_rate = []
-        if token_seen_num != len(text_input_ids):
-            local_compression_rate.append(
-                (len(text_input_ids) - token_seen_num, last_rate)
-            )
-        if local_compression_rate:
-            dynamic_compression_rate.append(local_compression_rate[:])
-        return dynamic_compression_rate
-
-    def control_context_budget(
-        self,
-        context: List[str],
-        context_tokens_length: List[int],
-        target_token: float,
-        force_context_ids: List[int] = None,
-        force_context_number: int = None,
-        question: str = "",
-        condition_in_question: str = "none",
-        reorder_context: str = "original",
-        dynamic_context_compression_ratio: float = 0.0,
-        rank_method: str = "longllmlingua",
-        context_budget: str = "+100",
-        context_segs: List[List[str]] = None,
-        context_segs_rate: List[List[float]] = None,
-        context_segs_compress: List[List[bool]] = None,
-    ):
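-        """
-        Coarse-grained (context-level) budget control: rank contexts with
-        rank_method, always keep forced contexts and those containing
-        non-compressible segments, and add contexts until the adjusted
-        target_token budget is exhausted. Returns the kept contexts, their
-        dynamic compression ratios, and the kept indices.
-        """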
-        demonstrations_sort = self.get_rank_results(
-            context,
-            question,
-            rank_method,
-            condition_in_question,
-            context_tokens_length,
-        )
-
-        if target_token < 0:
-            target_token = 100
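-        # context_budget is an arithmetic suffix such as "+100", applied to
-        # target_token via eval; it must come from a trusted caller.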
-        target_token = eval("target_token" + context_budget)
-        res = []
-        # Copy so that later appends do not mutate the caller's list.
-        used = list(force_context_ids) if force_context_ids is not None else []
-        if context_segs is not None:
-            for idx, _ in enumerate(context):
-                if False in context_segs_compress[idx]:
-                    used.append(idx)
-
-        self.context_idxs.append([x for x, _ in demonstrations_sort])
-        for idx, _ in demonstrations_sort:
-            if idx >= len(context_tokens_length):
-                continue
-            target_token -= context_tokens_length[idx]
-            if idx not in used:
-                used.append(idx)
-            if target_token < 0 or (
-                force_context_number is not None and len(res) >= force_context_number
-            ):
-                break
-        original_used = used
-        if reorder_context == "original":
-            used = sorted(used)
-        elif reorder_context == "two_stage":
-            l, r = used[0::2], used[1::2]
-            used = l + r[::-1]
-
-        if dynamic_context_compression_ratio > 0:
-            N = len(used)
-            dynamic_ratio = [
-                i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
-                for i in range(-(N - 1), N, 2)
-            ][::-1]
-            dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
-            dynamic_ratio = [dynamic_ratio_map[i] for i in used]
-        else:
-            dynamic_ratio = [0.0] * len(used)
-
-        res = [context[idx] for idx in used if idx < len(context)]
-        return res, dynamic_ratio, used
-
-    def control_sentence_budget(
-        self,
-        context: List[str],
-        target_token: float,
-        keep_first_sentence: int = 0,
-        keep_last_sentence: int = 0,
-        keep_sentence_number: int = 0,
-        high_priority_bonus: int = 100,
-        token_budget_ratio: float = 1.4,
-        question: str = "",
-        condition_in_question: str = "none",
-        rank_method: str = "longllmlingua",
-        context_segs: List[List[str]] = None,
-        context_segs_rate: List[List[float]] = None,
-        context_segs_compress: List[List[bool]] = None,
-    ):
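-        """
-        Fine-grained (sentence-level) budget control: split contexts into
-        sentences with nltk, rank them by conditional perplexity or
-        rank_method, and keep the best ones until target_token *
-        token_budget_ratio is spent. keep_first_sentence / keep_last_sentence /
-        keep_sentence_number add a high-priority bonus so those sentences
-        survive; sentences overlapping non-compressible segments are always
-        kept. Returns the filtered contexts and the updated segment info.
-        """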
-        def keep_sentence(dem_idx: int, sent_keep: int):
-            idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
-            for idx in idxs:
-                sentence_ppl[idx] += high_priority_bonus
-
-        def sync_sentence(segments, text):
-            seg_num = len(segments)
-            new_segments = []
-            text_seen = 0
-            seg_idx, cur_seg_seen = 0, 0
-            for i, s in enumerate(text):
-                while seg_idx < seg_num and s != segments[seg_idx][cur_seg_seen]:
-                    if cur_seg_seen < len(segments[seg_idx]) - 1:
-                        cur_seg_seen += 1
-                        continue
-                    new_segments.append(text[text_seen:i])
-                    text_seen = i
-                    seg_idx += 1
-                    cur_seg_seen = 0
-                cur_seg_seen += 1
-                if seg_idx == seg_num:
-                    break
-                if cur_seg_seen == len(segments[seg_idx]):
-                    new_segments.append(text[text_seen : i + 1])
-                    text_seen = i + 1
-                    seg_idx += 1
-                    cur_seg_seen = 0
-            if text_seen < len(text):
-                new_segments.append(text[text_seen:])
-            assert len("".join(new_segments)) == len(text)
-            return new_segments
-
-        sentences = [nltk.sent_tokenize(c) for c in context]
-        dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0
-        for idx_d, s in enumerate(sentences):
-            for _ in s:
-                dem_g[idx_d].add(idx)
-                s2de[idx] = idx_d
-                idx += 1
-
-        if context_segs is not None:
-            context_segs = [
-                sync_sentence(s, "".join(c)) for s, c in zip(context_segs, sentences)
-            ]
-            sen2seg_ratio = {}
-            idx = 0
-            for idx_d, sentences_each_context in enumerate(sentences):
-                segments_length = [len(s) for s in context_segs[idx_d]]
-                seg_idx, cur_seg_seen = 0, 0
-                for sentence in sentences_each_context:
-                    sentence_seg_ratio = []
-                    remain = len(sentence)
-                    while remain:
-                        if segments_length[seg_idx] - cur_seg_seen <= remain:
-                            new_seg_len = segments_length[seg_idx] - cur_seg_seen
-                            sentence_seg_ratio.append(
-                                (
-                                    new_seg_len,
-                                    context_segs_rate[idx_d][seg_idx],
-                                    context_segs_compress[idx_d][seg_idx],
-                                )
-                            )
-                            seg_idx += 1
-                            cur_seg_seen = 0
-                            remain -= new_seg_len
-                        else:
-                            sentence_seg_ratio.append(
-                                (
-                                    remain,
-                                    context_segs_rate[idx_d][seg_idx],
-                                    context_segs_compress[idx_d][seg_idx],
-                                )
-                            )
-                            cur_seg_seen += remain
-                            remain = 0
-                    sen2seg_ratio[idx] = sentence_seg_ratio
-                    idx += 1
-
-        context_sentences = [s for ii in sentences for s in ii]
-        sentence_tokens_length = [
-            self.get_token_length(sentence) for sentence in context_sentences
-        ]
-        N = len(context_sentences)
-        flags = list(range(len(context_sentences)))
-        if len(sentence_tokens_length) == 1:
-            # Match the normal (contexts, segments_info) return shape.
-            return context, []
-        if rank_method == "longllmlingua":
-            sentence_ppl = [
-                self.get_condition_ppl(sentence, question, condition_in_question)
-                .cpu()
-                .numpy()
-                .item()
-                for sentence in context_sentences
-            ]
-            if keep_first_sentence:
-                sentence_ppl[:keep_first_sentence] = [
-                    ii + high_priority_bonus
-                    for ii in sentence_ppl[:keep_first_sentence]
-                ]
-            if keep_last_sentence:
-                sentence_ppl[-keep_last_sentence:] = [
-                    ii + high_priority_bonus
-                    for ii in sentence_ppl[-keep_last_sentence:]
-                ]
-            if keep_sentence_number:
-                for dem_idx in range(len(sentences)):
-                    keep_sentence(dem_idx, keep_sentence_number)
-            sort_direct = -1 if condition_in_question == "none" else 1
-            sent_sort = sorted(
-                enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
-            )
-        else:
-            sent_sort = self.get_rank_results(
-                context_sentences,
-                question,
-                rank_method,
-                condition_in_question,
-                [0] * len(context_sentences),
-            )
-
-        sentence_flags = [False] * N
-        if target_token < 0:
-            target_token = 100
-        target_token *= token_budget_ratio
-        res = []
-        for idx, _ in sent_sort:
-            idx = flags[idx]
-            target_token -= sentence_tokens_length[idx]
-            sentence_flags[idx] = True
-            if target_token < 0:
-                break
-
-        if context_segs is not None:
-            for idx in range(N):
-                preserved = [sen_seg_info[2] for sen_seg_info in sen2seg_ratio[idx]]
-                if False in preserved:
-                    sentence_flags[idx] = True
-
-        idx = 0
-        res = []
-        new_segments_info = []
-        for s in sentences:
-            tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]]
-            res.append("".join(tmp))
-            if context_segs is not None:
-                segment_ratio = []
-                for ii in range(len(s)):
-                    if sentence_flags[idx + ii]:
-                        segment_ratio.extend(sen2seg_ratio[idx + ii])
-                new_segments_info.append(segment_ratio)
-            idx += len(s)
-        if context_segs is not None:
-            new_segments_info = [
-                self.concate_segment_info(segment_info)
-                for segment_info in new_segments_info
-            ]
-        return res, new_segments_info
-
-    def get_compressed_input(
-        self,
-        loss,
-        input_ids,
-        attention_mask,
-        end=200,
-        iterative_size=200,
-        threshold=0.5,
-        keep_flag=None,
-        split_token_id: int = 13,
-        start: int = 0,
-        self_loss=None,
-        self_input_ids=None,
-        self_attention_mask=None,
-    ):
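-        """
-        Drop tokens in the current window whose loss is below the threshold
-        (or whose contrastive gain is below it when self_loss is given), while
-        keeping everything outside the window and everything marked in
-        keep_flag, and collapsing consecutive split_token_id separators.
-        Returns the compressed ids/masks plus the bookkeeping needed to
-        continue the iterative loop.
-        """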
-        if self_loss is not None:
-            need_idx = torch.concat(
-                [
-                    loss[:start] > 0,
-                    self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
-                    loss[:1] > 0,
-                ]
-            )
-        else:
-            need_idx = torch.concat([loss > threshold, loss[:1] > 0])
-        need_idx[end:] = 1
-        need_idx[: end - iterative_size] = 1
-        loss = loss[need_idx[:-1]]
-        if self_loss is not None:
-            if need_idx.shape[0] < self_loss.shape[0] + start + 1:
-                need_idx = torch.cat(
-                    [
-                        need_idx,
-                        torch.ones(
-                            self_loss.shape[0] - need_idx.shape[0] + start + 1,
-                            dtype=torch.bool,
-                        ).to(need_idx.device),
-                    ]
-                )
-            self_loss = self_loss[need_idx[start:-1]]
-
-        if need_idx.shape[0] < input_ids.shape[1]:
-            need_idx = torch.cat(
-                [
-                    need_idx,
-                    torch.ones(
-                        input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
-                    ).to(need_idx.device),
-                ]
-            )
-        elif need_idx.shape[0] > input_ids.shape[1]:
-            need_idx = need_idx[: input_ids.shape[1]]
-
-        if keep_flag is not None:
-            need_idx[keep_flag == 1] = 1
-        last = -1
-        if keep_flag is not None:
-            for ii in range(max(0, end - iterative_size), end):
-                if need_idx[ii] != 1:
-                    continue
-                now = input_ids[0][ii].detach().cpu().item()
-                if (
-                    now == split_token_id
-                    and last == split_token_id
-                    and keep_flag[ii].detach().cpu().item() == 0
-                ):
-                    need_idx[ii] = 0
-                else:
-                    last = now
-        compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
-        compressed_attention_mask = attention_mask[attention_mask == 1][
-            need_idx
-        ].unsqueeze(0)
-
-        if self_loss is not None:
-            self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
-                need_idx[start:]
-            ].unsqueeze(0)
-            self_compressed_attention_mask = self_attention_mask[
-                self_attention_mask == 1
-            ][need_idx[start:]].unsqueeze(0)
-        else:
-            self_compressed_input_ids, self_compressed_attention_mask = None, None
-        if keep_flag is not None:
-            if len(keep_flag) > len(need_idx):
-                keep_flag = torch.cat(
-                    [
-                        keep_flag[:start],
-                        keep_flag[start : len(need_idx) + start][need_idx],
-                        keep_flag[start + len(need_idx) :],
-                    ]
-                )
-            else:
-                keep_flag = keep_flag[need_idx]
-        end -= (need_idx[:end] == 0).sum()
-        return (
-            compressed_input_ids,
-            compressed_attention_mask,
-            keep_flag,
-            end,
-            loss,
-            self_loss,
-            self_compressed_input_ids,
-            self_compressed_attention_mask,
-        )
-
-    def get_estimate_threshold_base_distribution(
-        self, ppl, ratio: float, condition_flag: bool = False
-    ):
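-        """
-        Estimate the loss threshold that keeps roughly `ratio` of the tokens:
-        sort the per-token values (ignoring the 10000 padding sentinel) and
-        read the value at the corresponding quantile.
-        """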
-        if ratio == 1.0:
-            return float("-inf")
-        ppl = ppl[ppl != 10000]
-        target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
-        return (
-            ppl.sort(descending=not condition_flag)
-            .values[target_token]
-            .detach()
-            .cpu()
-            .item()
-        )
-
-    def iterative_compress_prompt(
-        self,
-        context: List[str],
-        target_token: float,
-        iterative_size: int = 200,
-        keep_split: bool = False,
-        split_token_id: int = 13,
-        start: int = 0,
-        dynamic_ratio: list = None,
-        condition_compare: bool = False,
-        segments_info: List[List[tuple]] = None,
-    ):
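-        """
-        Core token-level compression loop. The prompt is scored window by
-        window with get_ppl; for each (token_count, rate) run a threshold is
-        estimated and get_compressed_input drops low-information tokens. When
-        the window grows past max_position_embeddings, the KV cache and
-        already-finished ids are popped so the loop can stream over very long
-        prompts.
-        """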
-        if segments_info is None or segments_info == []:
-            iterative_ratios = self.get_dynamic_compression_ratio(
-                context, target_token, iterative_size, dynamic_ratio, start
-            )
-        else:
-            iterative_ratios = self.get_structured_dynamic_compression_ratio(
-                context, iterative_size, dynamic_ratio, start, segments_info
-            )
-        context = "\n\n".join(context)
-        tokenized_text = self.tokenizer(
-            context, return_tensors="pt", add_special_tokens=False
-        )
-        input_ids = tokenized_text["input_ids"].to(self.device)
-        attention_mask = tokenized_text["attention_mask"].to(self.device)
-
-        N = (attention_mask == 1).sum()
-        compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
-        if condition_compare:
-            self_input_ids, self_attention_mask = (
-                input_ids[:, start:],
-                attention_mask[:, start:],
-            )
-            self_compressed_input_ids, self_compressed_attention_mask = (
-                self_input_ids,
-                self_attention_mask,
-            )
-
-        end = min(iterative_size + start, compressed_input_ids.shape[1])
-        threshold, keep_flag = None, None
-        if keep_split:
-            input_ids_numpy = input_ids.cpu().detach().numpy()[0]
-            N = len(input_ids_numpy)
-            keep_flag = [
-                int(
-                    (
-                        ii > 0
-                        and input_ids_numpy[ii] == split_token_id
-                        and input_ids_numpy[ii - 1] == split_token_id
-                    )
-                    or (
-                        ii < N - 1
-                        and input_ids_numpy[ii] == split_token_id
-                        and input_ids_numpy[ii + 1] == split_token_id
-                    )
-                )
-                for ii in range(N)
-            ]
-            keep_flag = torch.tensor(keep_flag).to(self.device)
-        past_key_values, past_loss, ready_end = None, None, 0
-        self_past_key_values, self_past_loss, self_ready_end = None, None, 0
-        pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
-        idx = 0
-        while end <= compressed_input_ids.shape[1]:
-            if end > self.max_position_embeddings and past_key_values is not None:
-                # KV-Cache Compression
-                e, s = end - self.max_position_embeddings, min(
-                    self.cache_bos_num + start, self.max_position_embeddings
-                )
-                if pop_compressed_input_ids is None:
-                    pop_compressed_input_ids = compressed_input_ids[:, :e]
-                else:
-                    pop_compressed_input_ids = torch.cat(
-                        [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
-                    )
-                compressed_input_ids = compressed_input_ids[:, e:]
-                compressed_attention_mask = compressed_attention_mask[:, e:]
-                past_key_values = [
-                    [
-                        torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
-                        torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
-                    ]
-                    for k, v in past_key_values
-                ]
-                if keep_flag is not None:
-                    keep_flag = keep_flag[e:]
-                end, ready_end = end - e, ready_end - e
-                if condition_compare:
-                    s = min(s, self_past_key_values[0][0].shape[2] - e)
-                    self_ready_end -= e
-                    if pop_self_compressed_input_ids is None:
-                        pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
-                    else:
-                        pop_self_compressed_input_ids = torch.cat(
-                            [
-                                pop_self_compressed_input_ids,
-                                self_compressed_input_ids[:, :e],
-                            ],
-                            dim=-1,
-                        )
-                    self_compressed_input_ids = self_compressed_input_ids[:, e:]
-                    self_compressed_attention_mask = self_compressed_attention_mask[
-                        :, e:
-                    ]
-                    self_past_key_values = [
-                        [
-                            torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
-                            torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
-                        ]
-                        for k, v in self_past_key_values
-                    ]
-
-            loss, past_key_values = self.get_ppl(
-                "",
-                "token",
-                compressed_input_ids,
-                compressed_attention_mask,
-                past_key_values=past_key_values,
-                return_kv=True,
-                end=end if idx else None,
-            )
-            if loss.shape[0] == 0:
-                break
-            if past_loss is not None:
-                if end - 1 > len(past_loss):
-                    past_loss = torch.cat(
-                        [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
-                    )
-                past_loss[ready_end : end - 1] = loss
-                loss = past_loss
-            else:
-                past_loss = loss
-            if idx:
-                past_key_values = [
-                    [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
-                    for k, v in past_key_values
-                ]
-            else:
-                past_key_values = None
-
-            if condition_compare:
-                self_loss, self_past_key_values = self.get_ppl(
-                    "",
-                    "token",
-                    self_compressed_input_ids,
-                    self_compressed_attention_mask,
-                    past_key_values=self_past_key_values,
-                    return_kv=True,
-                    end=end - start if idx else None,
-                )
-                if self_past_loss is not None:
-                    if end - start - 1 > len(self_past_loss):
-                        self_past_loss = torch.cat(
-                            [
-                                self_past_loss,
-                                torch.zeros_like(self_loss)[
-                                    : end - 1 - start - len(self_past_loss)
-                                ],
-                            ]
-                        )
-                    self_past_loss[self_ready_end : end - start - 1] = self_loss
-                    self_loss = self_past_loss
-                else:
-                    self_past_loss = self_loss
-                if idx:
-                    self_past_key_values = [
-                        [
-                            k[:, :, : end - iterative_size - start],
-                            v[:, :, : end - iterative_size - start],
-                        ]
-                        for k, v in self_past_key_values
-                    ]
-                else:
-                    self_past_key_values = None
-
-                self_ready_end = (
-                    end - start - iterative_size if not (start and idx == 0) else 0
-                )
-            ready_end = end - iterative_size if not (start and idx == 0) else 0
-
-            for delta_end, ratio in iterative_ratios[idx]:
-                loss = past_loss
-                if condition_compare:
-                    self_loss = self_past_loss
-                    threshold = self.get_estimate_threshold_base_distribution(
-                        self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
-                    )
-                else:
-                    threshold = self.get_estimate_threshold_base_distribution(
-                        loss, ratio, False
-                    )
-
-                (
-                    compressed_input_ids,
-                    compressed_attention_mask,
-                    keep_flag,
-                    end,
-                    past_loss,
-                    self_past_loss,
-                    self_compressed_input_ids,
-                    self_compressed_attention_mask,
-                ) = self.get_compressed_input(
-                    loss,
-                    compressed_input_ids,
-                    compressed_attention_mask,
-                    end - iterative_size + delta_end,
-                    iterative_size=delta_end,
-                    threshold=threshold,
-                    keep_flag=keep_flag,
-                    split_token_id=split_token_id,
-                    start=start,
-                    self_loss=self_loss if condition_compare else None,
-                    self_input_ids=(
-                        self_compressed_input_ids if condition_compare else None
-                    ),
-                    self_attention_mask=(
-                        self_compressed_attention_mask if condition_compare else None
-                    ),
-                )
-                end += iterative_size
-            idx += 1
-        if pop_compressed_input_ids is not None:
-            compressed_input_ids = torch.cat(
-                [pop_compressed_input_ids, compressed_input_ids], dim=-1
-            )
-        return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
-
-    def recover(
-        self,
-        original_prompt: str,
-        compressed_prompt: str,
-        response: str,
-    ):
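-        """
-        Map spans of the model response that were copied verbatim from the
-        compressed prompt back to the matching spans of the original prompt,
-        using a token-id subsequence match.
-        """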
-        def match_from_compressed(response_word):
-            response_input_ids = self.tokenizer(
-                response_word, add_special_tokens=False
-            )["input_ids"]
-            response_set, response_c = set(response_input_ids), defaultdict(list)
-            for idx in range(M):
-                if original_input_ids[idx] in response_set:
-                    response_c[original_input_ids[idx]].append(idx)
-            res, res_min, res_c = None, float("inf"), 1
-            n = len(response_input_ids)
-            for l in response_c[response_input_ids[0]]:
-                x, y, c = 0, l, 1
-                for x in range(1, n):
-                    idx = bisect.bisect_right(response_c[response_input_ids[x]], y)
-                    if (
-                        idx >= len(response_c[response_input_ids[x]])
-                        or response_c[response_input_ids[x]][idx] - y > 10
-                    ):
-                        continue
-                    c += 1
-                    y = response_c[response_input_ids[x]][idx]
-                if c > res_c:
-                    res_c = c
-                    res_min = y - l + 1
-                    res = (l, y + 1)
-                elif c == res_c and y - l + 1 < res_min:
-                    res_min = y - l + 1
-                    res = (l, y + 1)
-
-            if res is None:
-                return response_word
-            return self.tokenizer.decode(original_input_ids[res[0] : res[1]])
-
-        response_words = response.split(" ")
-
-        original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
-            "input_ids"
-        ]
-        N, M = len(response_words), len(original_input_ids)
-        recovered_response_words = []
-        l = 0
-        while l < N:
-            if response_words[l] not in compressed_prompt:
-                recovered_response_words.append(response_words[l])
-                l += 1
-                continue
-            r = l
-            while (
-                r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
-            ):
-                r += 1
-
-            match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
-            recovered_response_words.append(match_words)
-            l = r + 1
-        return " ".join(recovered_response_words)
-
-    def get_rank_results(
-        self,
-        context: list,
-        question: str,
-        rank_method: str,
-        condition_in_question: str,
-        context_tokens_length: list,
-    ):
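-        """
-        Rank contexts against the question with the requested retrieval method
-        (bm25, gzip, embedding models, rerankers, or LongLLMLingua conditional
-        perplexity). Returns (context_index, score) pairs, best match first.
-        """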
-        def get_distance_bm25(corpus, query):
-            from rank_bm25 import BM25Okapi
-
-            tokenized_corpus = [doc.split(" ") for doc in corpus]
-            bm25 = BM25Okapi(tokenized_corpus)
-            tokenized_query = query.split(" ")
-            doc_scores = bm25.get_scores(tokenized_query)
-            idx = [(ii, 0) for ii in (-doc_scores).argsort()]
-            return idx
-
-        def get_distance_gzip(corpus, query):
-            def get_score(x, y):
-                cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
-                cxy = len(gzip.compress(f"{x} {y}".encode()))
-                return (cxy - min(cx, cy)) / max(cx, cy)
-
-            import gzip
-
-            doc_scores = [get_score(doc, query) for doc in corpus]
-            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
-            return idx
-
-        def get_distance_sentbert(corpus, query):
-            from sentence_transformers import SentenceTransformer, util
-
-            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
-                self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
-                self.retrieval_model_name = rank_method
-            doc_embeds = self.retrieval_model.encode(corpus)
-            query = self.retrieval_model.encode(query)
-            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
-            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
-            return idx
-
-        def get_distance_openai(corpus, query):
-            import openai
-            from sentence_transformers import util
-
-            openai.api_key = self.open_api_config.get("api_key", "")
-            openai.api_base = self.open_api_config.get(
-                "api_base", "https://api.openai.com/v1"
-            )
-            openai.api_type = self.open_api_config.get("api_type", "open_ai")
-            openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
-            engine = self.open_api_config.get("engine", "text-embedding-ada-002")
-
-            def get_embed(text):
-                return openai.Embedding.create(
-                    input=[text.replace("\n", " ")], engine=engine
-                )["data"][0]["embedding"]
-
-            doc_embeds = [get_embed(i) for i in corpus]
-            query = get_embed(query)
-            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
-            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
-            return idx
-
-        def get_distance_sentbert_bge(corpus, query):
-            from sentence_transformers import SentenceTransformer, util
-
-            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
-                self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
-                self.retrieval_model_name = rank_method
-            doc_embeds = self.retrieval_model.encode(
-                [i for i in corpus], normalize_embeddings=True
-            )
-            query = self.retrieval_model.encode(query, normalize_embeddings=True)
-            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
-            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
-            return idx
-
-        def get_distance_bge_ranker(corpus, query):
-            from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
-            pairs = [[i, query] for i in corpus]
-            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
-                tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
-                model = (
-                    AutoModelForSequenceClassification.from_pretrained(
-                        "BAAI/bge-reranker-large"
-                    )
-                    .eval()
-                    .to(self.device)
-                )
-                self.retrieval_model = [tokenizer, model]
-                self.retrieval_model_name = rank_method
-            with torch.no_grad():
-                inputs = self.retrieval_model[0](
-                    pairs,
-                    padding=True,
-                    truncation=True,
-                    return_tensors="pt",
-                    max_length=512,
-                ).to(self.device)
-                scores = (
-                    self.retrieval_model[1](**inputs, return_dict=True)
-                    .logits.view(
-                        -1,
-                    )
-                    .float()
-                )
-            idx = [(ii, 0) for ii in np.argsort(-scores.cpu())]
-            return idx
-
-        def get_distance_bge_llmembedder(corpus, query):
-            from transformers import AutoModel, AutoTokenizer
-
-            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
-                tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
-                model = (
-                    AutoModel.from_pretrained("BAAI/llm-embedder")
-                    .eval()
-                    .to(self.device)
-                )
-                self.retrieval_model = [tokenizer, model]
-                self.retrieval_model_name = rank_method
-
-            instruction_qa_query = (
-                "Represent this query for retrieving relevant documents: "
-            )
-            instruction_qa_key = "Represent this document for retrieval: "
-            queries = [instruction_qa_query + query for _ in corpus]
-            keys = [instruction_qa_key + key for key in corpus]
-            with torch.no_grad():
-                query_inputs = self.retrieval_model[0](
-                    queries,
-                    padding=True,
-                    truncation=True,
-                    return_tensors="pt",
-                    max_length=512,
-                ).to(self.device)
-                key_inputs = self.retrieval_model[0](
-                    keys,
-                    padding=True,
-                    truncation=True,
-                    return_tensors="pt",
-                    max_length=512,
-                ).to(self.device)
-                query_outputs = self.retrieval_model[1](**query_inputs)
-                key_outputs = self.retrieval_model[1](**key_inputs)
-                # CLS pooling
-                query_embeddings = query_outputs.last_hidden_state[:, 0]
-                key_embeddings = key_outputs.last_hidden_state[:, 0]
-                # Normalize
-                query_embeddings = torch.nn.functional.normalize(
-                    query_embeddings, p=2, dim=1
-                )
-                key_embeddings = torch.nn.functional.normalize(
-                    key_embeddings, p=2, dim=1
-                )
-                similarity = query_embeddings @ key_embeddings.T
-            idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())]
-            return idx
-
-        def get_distance_jinza(corpus, query):
-            from numpy.linalg import norm
-
-            from transformers import AutoModel
-
-            def cos_sim(a, b):
-                return (a @ b.T) / (norm(a) * norm(b))
-
-            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
-                model = (
-                    AutoModel.from_pretrained(
-                        "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
-                    )
-                    .eval()
-                    .to(self.device)
-                )
-                self.retrieval_model = model
-                self.retrieval_model_name = rank_method
-
-            doc_embeds = self.retrieval_model.encode(corpus)
-            query = self.retrieval_model.encode(query)
-            doc_scores = cos_sim(doc_embeds, query)
-            idx = [(ii, 0) for ii in np.argsort(-doc_scores)]
-            return idx
-
-        def get_distance_voyageai(corpus, query):
-            import voyageai
-            from sentence_transformers import util
-
-            voyageai.api_key = self.open_api_config.get("voyageai_api_key", "")
-
-            def get_embed(text):
-                return voyageai.get_embedding(text, model="voyage-01")
-
-            doc_embeds = [get_embed(i) for i in corpus]
-            query = get_embed(query)
-            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
-            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
-            return idx
-
-        def get_distance_cohere(corpus, query):
-            import cohere
-
-            api_key = self.open_api_config.get("cohere_api_key", "")
-            co = cohere.Client(api_key)
-            results = co.rerank(
-                model="rerank-english-v2.0", query=query, documents=corpus, top_n=20
-            )
-            c_map = {jj: ii for ii, jj in enumerate(corpus)}
-            doc_rank = [c_map[ii.document["text"]] for ii in results]
-            idx = [(ii, 0) for ii in doc_rank]
-            return idx
-
-        def get_distance_longllmlingua(corpus, query):
-            context_ppl = [
-                self.get_condition_ppl(
-                    d,
-                    query
-                    + " We can get the answer to this question in the given documents.",
-                    condition_in_question,
-                )
-                for d in corpus
-            ]
-            sort_direct = -1 if condition_in_question == "none" else 1
-            ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])
-            return ys
-
-        method = None
-        if rank_method == "bm25":
-            method = get_distance_bm25
-        elif rank_method == "gzip":
-            method = get_distance_gzip
-        elif rank_method == "sentbert":
-            method = get_distance_sentbert
-        elif rank_method == "openai":
-            method = get_distance_openai
-        elif rank_method in ["longllmlingua", "llmlingua"]:
-            method = get_distance_longllmlingua
-        elif rank_method == "bge":
-            method = get_distance_sentbert_bge
-        elif rank_method == "bge_reranker":
-            method = get_distance_bge_ranker
-        elif rank_method == "bge_llmembedder":
-            method = get_distance_bge_llmembedder
-        elif rank_method == "jinza":
-            method = get_distance_jinza
-        elif rank_method == "voyageai":
-            method = get_distance_voyageai
-        elif rank_method == "cohere":
-            method = get_distance_cohere
-        if method is None:
-            raise ValueError(f"Unknown rank_method: {rank_method}")
-        return method(context, question)
-
-    def segment_structured_context(
-        self,
-        context: List[str],
-        global_rate: float,
-    ):
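-        """
-        Parse <llmlingua, rate=x, compress=y>...</llmlingua> annotations in
-        each context and return the stripped text plus per-segment texts,
-        rates, and compress flags. Missing attributes default to compress=True
-        and to global_rate (or 1.0 for non-compressible segments).
-
-        Illustrative example:
-            >>> ctx = ["<llmlingua, compress=False>Keep this.</llmlingua>"
-            ...        "<llmlingua, rate=0.4>Maybe drop this.</llmlingua>"]
-            >>> self.segment_structured_context(ctx, 0.3)
-            (['Keep this.Maybe drop this.'], [['Keep this.', 'Maybe drop this.']],
-             [[1.0, 0.4]], [[False, True]])
-        """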
-        new_context, context_segs, context_segs_rate, context_segs_compress = (
-            [],
-            [],
-            [],
-            [],
-        )
-        for text in context:
-            if not text.startswith("<llmlingua"):
-                text = "<llmlingua>" + text
-            if not text.endswith("</llmlingua>"):
-                text = text + "</llmlingua>"
-
-            # Regular expression to match <llmlingua, rate=x, compress=y>content</llmlingua>, allowing rate and compress in any order
-            pattern = r"<llmlingua\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*>([^<]+)</llmlingua>"
-            matches = re.findall(pattern, text)
-
-            # Extracting segment contents
-            segments = [match[4] for match in matches]
-
-            # Extracting rate and compress, considering their possible positions
-            segs_rate = [
-                float(match[0]) if match[0] else (float(match[2]) if match[2] else None)
-                for match in matches
-            ]
-            segs_compress = [
-                (
-                    match[1] == "True"
-                    if match[1]
-                    else (match[3] == "True" if match[3] else None)
-                )
-                for match in matches
-            ]
-
-            segs_compress = [
-                compress if compress is not None else True for compress in segs_compress
-            ]
-            segs_rate = [
-                rate if rate else (global_rate if compress else 1.0)
-                for rate, compress in zip(segs_rate, segs_compress)
-            ]
-            assert (
-                len(segments) == len(segs_rate) == len(segs_compress)
-            ), "The number of segments, rates, and compress flags should be the same."
-            assert all(
-                seg_rate <= 1.0 for seg_rate in segs_rate
-            ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]."
-
-            new_context.append("".join(segments))
-            context_segs.append(segments)
-            context_segs_rate.append(segs_rate)
-            context_segs_compress.append(segs_compress)
-
-        return new_context, context_segs, context_segs_rate, context_segs_compress
-
-    def concate_segment_info(
-        self,
-        segment_info: List[tuple],
-    ):
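-        """
-        Merge consecutive segments that share the same rate and compress flag
-        into a single (length, rate, compress) tuple.
-
-        Illustrative example:
-            >>> self.concate_segment_info([(3, 0.5, True), (4, 0.5, True), (2, 1.0, False)])
-            [(7, 0.5, True), (2, 1.0, False)]
-        """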
-        new_segment_info = []
-        for seg_len, seg_ratio, seg_compress in segment_info:
-            if (
-                new_segment_info
-                and new_segment_info[-1][1] == seg_ratio
-                and new_segment_info[-1][2] == seg_compress
-            ):
-                new_segment_info[-1] = (
-                    new_segment_info[-1][0] + seg_len,
-                    seg_ratio,
-                    seg_compress,
-                )
-            else:
-                new_segment_info.append((seg_len, seg_ratio, seg_compress))
-        return new_segment_info
-
-    def __get_context_prob(
-        self,
-        context_list: list,
-        token_to_word: str = "mean",
-        force_tokens: List[str] = [],
-        token_map: dict = {},
-        force_reserve_digit: bool = False,
-    ):
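-        """
-        Score every chunk with the token-classification model and return, per
-        context, the mean word-level keep probability (computed without
-        force-token boosting) together with the recovered words.
-        """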
-        chunk_list = []
-        for chunks in context_list:
-            for c in chunks:
-                chunk_list.append(c)
-
-        dataset = TokenClfDataset(
-            chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
-        )
-        dataloader = DataLoader(
-            dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
-        )
-
-        chunk_probs = []
-        chunk_words = []
-        with torch.no_grad():
-            for batch in dataloader:
-                ids = batch["ids"].to(self.device, dtype=torch.long)
-                mask = batch["mask"].to(self.device, dtype=torch.long) == 1
-
-                outputs = self.model(input_ids=ids, attention_mask=mask)
-                loss, logits = outputs.loss, outputs.logits
-                probs = F.softmax(logits, dim=-1)
-
-                for j in range(ids.shape[0]):
-                    _probs = probs[j, :, 1]
-                    _ids = ids[j]
-                    _mask = mask[j]
-
-                    active_probs = torch.masked_select(_probs, _mask)
-                    active_ids = torch.masked_select(_ids, _mask)
-
-                    tokens = self.tokenizer.convert_ids_to_tokens(
-                        active_ids.squeeze().tolist()
-                    )
-                    token_probs = [prob for prob in active_probs.cpu().numpy()]
-
-                    (
-                        words,
-                        valid_token_probs,
-                        valid_token_probs_no_force,
-                    ) = self.__merge_token_to_word(
-                        tokens,
-                        token_probs,
-                        force_tokens=force_tokens,
-                        token_map=token_map,
-                        force_reserve_digit=force_reserve_digit,
-                    )
-                    word_probs_no_force = self.__token_prob_to_word_prob(
-                        valid_token_probs_no_force, convert_mode=token_to_word
-                    )
-
-                    if "xlm-roberta-large" in self.model_name:
-                        for i in range(len(words)):
-                            words[i] = words[i].lstrip("▁")
-                    chunk_words.append(words)
-                    chunk_probs.append(word_probs_no_force)
-
-        prev_idx = 0
-        context_probs = []
-        context_words = []
-        for chunk_list in context_list:
-            n_chunk = len(chunk_list)
-            context_probs.append([])
-            context_words.append([])
-            for i in range(n_chunk):
-                context_probs[-1].extend(chunk_probs[prev_idx + i])
-                context_words[-1].extend(chunk_words[prev_idx + i])
-            prev_idx = prev_idx + n_chunk
-        context_probs = [sum(probs) / len(probs) for probs in context_probs]
-        return context_probs, context_words
-
-    def __chunk_context(self, origin_text, chunk_end_tokens):
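-        """
-        Split the text into chunks of roughly max_seq_len tokens, preferring
-        to cut right after one of chunk_end_tokens so sentences stay intact.
-        """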
-        origin_list = []
-        origin_tokens = self.tokenizer.tokenize(origin_text)
-        n = len(origin_tokens)
-        st = 0
-        while st < n:
-            if st + self.max_seq_len > n - 1:
-                chunk = self.tokenizer.convert_tokens_to_string(origin_tokens[st:n])
-                origin_list.append(chunk)
-                break
-            else:
-                ed = st + self.max_seq_len
-                for j in range(0, ed - st):
-                    if origin_tokens[ed - j] in chunk_end_tokens:
-                        ed = ed - j
-                        break
-                chunk = self.tokenizer.convert_tokens_to_string(
-                    origin_tokens[st : ed + 1]
-                )
-                origin_list.append(chunk)
-                st = ed + 1
-        return origin_list
-
-    def __merge_token_to_word(
-        self, tokens, token_probs, force_tokens, token_map, force_reserve_digit
-    ):
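-        """
-        Merge sub-word tokens back into words, carrying per-token
-        probabilities. Word-initial tokens found in force_tokens or token_map
-        are boosted to probability 1.0, and digits are boosted when
-        force_reserve_digit is set; the *_no_force list keeps the raw
-        probabilities for ranking.
-        """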
-        words = []
-        word_probs = []
-        word_probs_no_force = []
-
-        for token, prob in zip(tokens, token_probs):
-            if token in self.special_tokens:
-                continue
-            # add a new word
-            elif is_begin_of_new_word(token, self.model_name, force_tokens, token_map):
-                pure_token = get_pure_token(token, self.model_name)
-                prob_no_force = prob
-                if pure_token in force_tokens or pure_token in set(token_map.values()):
-                    prob = 1.0
-                token = replace_added_token(token, token_map)
-                words.append(token)
-                word_probs.append(
-                    [
-                        1.0
-                        if force_reserve_digit
-                        and bool(re.search(r"\d", token))
-                        else prob
-                    ]
-                )
-                word_probs_no_force.append([prob_no_force])
-            # concatenate with previous token
-            else:
-                pure_token = get_pure_token(token, self.model_name)
-                words[-1] += pure_token
-                word_probs[-1].append(
-                    1.0
-                    if force_reserve_digit
-                    and bool(re.search(r"\d", token))
-                    else prob
-                )
-                # Use this continuation token's own (unforced) probability
-                # rather than the stale value from the word-initial token.
-                word_probs_no_force[-1].append(prob)
-
-        return words, word_probs, word_probs_no_force
-
-    def __token_prob_to_word_prob(self, token_probs, convert_mode="mean"):
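-        """
-        Collapse each word's token probabilities into a single value: the mean
-        over its tokens, or the first token's probability.
-
-        Illustrative example:
-            >>> self.__token_prob_to_word_prob([[0.25, 0.75], [1.0]], "mean")
-            [0.5, 1.0]
-        """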
-        if convert_mode == "mean":
-            word_probs = [sum(p) / len(p) for p in token_probs]
-        elif convert_mode == "first":
-            word_probs = [p[0] for p in token_probs]
-        else:
-            raise NotImplementedError()
-
-        return word_probs
-
-    def __compress(
-        self,
-        context_list: list,
-        reduce_rate: float = 0.5,
-        token_to_word: str = "mean",
-        force_tokens: List[str] = [],
-        token_map: dict = {},
-        force_reserve_digit: bool = False,
-        drop_consecutive: bool = False,
-    ):
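-        """
-        LLMLingua-2 compression: score every word with the
-        token-classification model, derive a keep threshold from the
-        reduce_rate percentile (measured in GPT-4-tokenizer tokens), and drop
-        words below it. force_tokens are always kept, optionally de-duplicated
-        when drop_consecutive is set; reduce_rate <= 0 keeps everything
-        (labels all 1). Returns the compressed contexts, the original words,
-        and 0/1 keep labels.
-        """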
-        def split_string_to_words(input_string):
-            pattern = r'\b\w+\b|[<>=/!@#$%^&*()?":{}|\\`~;_+-]'
-            result = re.findall(pattern, input_string)
-            return result
-        if reduce_rate <= 0:
-            words, word_labels = [], []
-            for i in range(len(context_list)):
-                chunk_list = context_list[i]
-                chunk_words = []
-                chunk_word_labels = []
-                for j in range(len(chunk_list)):
-                    # replace to original token
-                    for ori_token, new_token in token_map.items():
-                        chunk_list[j] = chunk_list[j].replace(new_token, ori_token)
-                    ws = split_string_to_words(chunk_list[j])
-                    chunk_words.extend(ws)
-                    chunk_word_labels.extend([1 for _ in range(len(ws))])
-                context_list[i] = "".join(chunk_list)
-                words.append(chunk_words)
-                word_labels.append(chunk_word_labels)
-            return context_list, words, word_labels
-
-        chunk_list = []
-        for chunks in context_list:
-            for c in chunks:
-                chunk_list.append(c)
-
-        dataset = TokenClfDataset(
-            chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
-        )
-        dataloader = DataLoader(
-            dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
-        )
-
-        compressed_chunk_list = []
-        word_list = []
-        word_label_list = []
-        with torch.no_grad():
-            for batch in dataloader:
-                ids = batch["ids"].to(self.device, dtype=torch.long)
-                mask = batch["mask"].to(self.device, dtype=torch.long) == 1
-
-                outputs = self.model(input_ids=ids, attention_mask=mask)
-                loss, logits = outputs.loss, outputs.logits
-                probs = F.softmax(logits, dim=-1)
-
-                for j in range(ids.shape[0]):
-                    chunk_probs = probs[j, :, 1]
-                    chunk_ids = ids[j]
-                    chunk_mask = mask[j]
-
-                    active_probs = torch.masked_select(chunk_probs, chunk_mask)
-                    active_ids = torch.masked_select(chunk_ids, chunk_mask)
-
-                    tokens = self.tokenizer.convert_ids_to_tokens(
-                        active_ids.squeeze().tolist()
-                    )
-                    token_probs = [prob for prob in active_probs.cpu().numpy()]
-
-                    words, valid_token_probs, _ = self.__merge_token_to_word(
-                        tokens=tokens, 
-                        token_probs=token_probs, 
-                        force_tokens=force_tokens, 
-                        token_map=token_map,
-                        force_reserve_digit=force_reserve_digit,
-                    )
-                    word_probs = self.__token_prob_to_word_prob(
-                        valid_token_probs, convert_mode=token_to_word
-                    )
-
-                    if drop_consecutive:
-                        threshold = np.percentile(word_probs, int(100 * reduce_rate))
-                        is_token_between = False
-                        prev = None
-                        for i, (word, word_prob) in enumerate(zip(words, word_probs)):
-                            if word in force_tokens:
-                                if is_token_between:
-                                    is_token_between = False
-                                elif not is_token_between and word == prev:
-                                    word_probs[i] = 0.0
-                                prev = word
-                            else:
-                                is_token_between |= word_prob > threshold
-
-                    # calculate compression ratio w.r.t. gpt-4 tokenizer
-                    new_token_probs = []
-                    for word, word_prob in zip(words, word_probs):
-                        num_token = len(self.oai_tokenizer.encode(word))
-                        new_token_probs.extend([word_prob for _ in range(num_token)])
-                    threshold = np.percentile(
-                        new_token_probs, int(100 * reduce_rate + 1)
-                    )
-
-                    keep_words = []
-                    word_labels = []
-                    assert len(words) == len(word_probs)
-                    for word, word_prob in zip(words, word_probs):
-                        if word_prob > threshold:
-                            if (
-                                drop_consecutive
-                                and word in force_tokens
-                                and len(keep_words) > 0
-                                and keep_words[-1] == word
-                            ):
-                                word_labels.append(0)
-                            else:
-                                keep_words.append(word)
-                                word_labels.append(1)
-                        else:
-                            word_labels.append(0)
-                    keep_str = self.tokenizer.convert_tokens_to_string(keep_words)
-                    if "xlm-roberta-large" in self.model_name:
-                        for i in range(len(words)):
-                            words[i] = words[i].lstrip("▁")
-
-                    compressed_chunk_list.append(keep_str)
-                    word_list.append(words[:])
-                    word_label_list.append(word_labels[:])
-
-        compressed_context_list = []
-        original_word_list = []
-        original_word_label_list = []
-        prev_idx = 0
-        for chunk_list in context_list:
-            n_chunk = len(chunk_list)
-            compressed_context_list.append(
-                "".join(compressed_chunk_list[prev_idx : prev_idx + n_chunk])
-            )
-            original_word_list.append([])
-            original_word_label_list.append([])
-            for i in range(n_chunk):
-                original_word_list[-1].extend(word_list[prev_idx + i])
-                original_word_label_list[-1].extend(word_label_list[prev_idx + i])
-            prev_idx = prev_idx + n_chunk
-
-        return compressed_context_list, original_word_list, original_word_label_list