"""Module contains the class to create a number prompt."""
import re
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union, cast

from prompt_toolkit.application.application import Application
from prompt_toolkit.buffer import Buffer
from prompt_toolkit.filters.base import Condition
from prompt_toolkit.filters.cli import IsDone
from prompt_toolkit.keys import Keys
from prompt_toolkit.layout.containers import (
    ConditionalContainer,
    HorizontalAlign,
    HSplit,
    VSplit,
    Window,
)
from prompt_toolkit.layout.controls import (
    BufferControl,
    DummyControl,
    FormattedTextControl,
)
from prompt_toolkit.layout.dimension import Dimension, LayoutDimension
from prompt_toolkit.layout.layout import Layout
from prompt_toolkit.lexers.base import SimpleLexer
from prompt_toolkit.validation import ValidationError

from InquirerPy.base.complex import BaseComplexPrompt, FakeDocument
from InquirerPy.containers.instruction import InstructionWindow
from InquirerPy.containers.validation import ValidationWindow
from InquirerPy.enum import INQUIRERPY_QMARK_SEQUENCE
from InquirerPy.exceptions import InvalidArgument
from InquirerPy.utils import (
    InquirerPyDefault,
    InquirerPyKeybindings,
    InquirerPyMessage,
    InquirerPySessionResult,
    InquirerPyStyle,
    InquirerPyValidate,
)

if TYPE_CHECKING:
    # Only needed for annotations; avoids a runtime import.
    from prompt_toolkit.key_binding.key_processor import KeyPressEvent

__all__ = ["NumberPrompt"]


class NumberPrompt(BaseComplexPrompt):
    """Create an input prompt that only accepts numbers as input.

    A wrapper class around :class:`~prompt_toolkit.application.Application`.

    Args:
        message: The question to ask the user.
            Refer to :ref:`pages/dynamic:message` documentation for more details.
        style: An :class:`InquirerPyStyle` instance.
            Refer to :ref:`Style <pages/style:Alternate Syntax>` documentation for more details.
        vi_mode: Use vim keybinding for the prompt.
            Refer to :ref:`pages/kb:Keybindings` documentation for more details.
        default: Set the default value of the prompt.
            You can enter either the floating value or integer value as the default.
            Refer to :ref:`pages/dynamic:default` documentation for more details.
        float_allowed: Allow decimal input. This will change the prompt to have 2 input buffers, one for the
            whole part and one for the fractional part (called the "integral" buffer in this class).
        min_allowed: Set the minimum value of the prompt. When the input value goes below this value, it
            will automatically reset to this value.
        max_allowed: Set the maximum value of the prompt. When the input value goes above this value, it
            will automatically reset to this value.
        qmark: Question mark symbol. Custom symbol that will be displayed in front of the question before it's answered.
        amark: Answer mark symbol. Custom symbol that will be displayed in front of the question after it's answered.
        decimal_symbol: Decimal point symbol. Custom symbol to display as the decimal point.
        replace_mode: Start each input buffer in replace mode if its initial text is 0.
            When typing, it will replace the 0 with the new value. The replace mode will be disabled once the value
            is changed.
        instruction: Short instruction to display next to the question.
        long_instruction: Long instructions to display at the bottom of the prompt.
        validate: Add validation to user input.
            Refer to :ref:`pages/validator:Validator` documentation for more details.
        invalid_message: Error message to display when user input is invalid.
            Refer to :ref:`pages/validator:Validator` documentation for more details.
        transformer: A function which performs additional transformation on the value that gets printed to the terminal.
            Different than `filter` parameter, this is only visual effect and won't affect the actual value returned by :meth:`~InquirerPy.base.simple.BaseSimplePrompt.execute`.
            Refer to :ref:`pages/dynamic:transformer` documentation for more details.
        filter: A function which performs additional transformation on the result.
            This affects the actual value returned by :meth:`~InquirerPy.base.simple.BaseSimplePrompt.execute`.
            Refer to :ref:`pages/dynamic:filter` documentation for more details.
        keybindings: Customise the builtin keybindings.
            Refer to :ref:`pages/kb:Keybindings` for more details.
        wrap_lines: Soft wrap question lines when question exceeds the terminal width.
        raise_keyboard_interrupt: Raise the :class:`KeyboardInterrupt` exception when `ctrl-c` is pressed. If false, the result
            will be `None` and the question is skipped.
        mandatory: Indicate if the prompt is mandatory. If True, then the question cannot be skipped.
        mandatory_message: Error message to show when user attempts to skip mandatory prompt.
        session_result: Used internally for :ref:`index:Classic Syntax (PyInquirer)`.

    Raises:
        InvalidArgument: When `float_allowed` is False and `default` does not resolve to an ``int``.

    Examples:
        >>> from InquirerPy import inquirer
        >>> result = inquirer.number(message="Enter number:").execute()
        >>> print(result)
        0
    """

    def __init__(
        self,
        message: InquirerPyMessage,
        style: Optional[InquirerPyStyle] = None,
        vi_mode: bool = False,
        default: InquirerPyDefault = 0,
        float_allowed: bool = False,
        max_allowed: Optional[Union[int, float]] = None,
        min_allowed: Optional[Union[int, float]] = None,
        decimal_symbol: str = ". ",
        replace_mode: bool = False,
        qmark: str = INQUIRERPY_QMARK_SEQUENCE,
        amark: str = "?",
        instruction: str = "",
        long_instruction: str = "",
        validate: Optional[InquirerPyValidate] = None,
        invalid_message: str = "Invalid input",
        transformer: Optional[Callable[[str], Any]] = None,
        filter: Optional[Callable[[str], Any]] = None,
        keybindings: Optional[InquirerPyKeybindings] = None,
        wrap_lines: bool = True,
        raise_keyboard_interrupt: bool = True,
        mandatory: bool = True,
        mandatory_message: str = "Mandatory prompt",
        session_result: Optional[InquirerPySessionResult] = None,
    ) -> None:
        super().__init__(
            message=message,
            style=style,
            vi_mode=vi_mode,
            qmark=qmark,
            amark=amark,
            transformer=transformer,
            filter=filter,
            invalid_message=invalid_message,
            validate=validate,
            instruction=instruction,
            long_instruction=long_instruction,
            wrap_lines=wrap_lines,
            raise_keyboard_interrupt=raise_keyboard_interrupt,
            mandatory=mandatory,
            mandatory_message=mandatory_message,
            session_result=session_result,
        )

        self._float = float_allowed
        self._is_float = Condition(lambda: self._float)
        self._max = max_allowed
        self._min = min_allowed
        self._value_error_message = "Remove any non-integer value"
        self._decimal_symbol = decimal_symbol
        # Replace-mode flag for each of the two buffers (see `replace_mode`).
        self._whole_replace = False
        self._integral_replace = False
        self._replace_mode = replace_mode

        # No longer used internally (splitting goes through `_split_decimal`,
        # zero padding through `str.zfill`); kept for code that reads them.
        self._leading_zero_pattern = re.compile(r"^(0*)[0-9]+.*")
        self._sn_pattern = re.compile(r"^.*E-.*")

        # `None` means "no default": the buffers start empty (`_on_rendered`
        # returns early) while 0 is still kept as the fallback value.
        self._no_default = False
        if default is None:
            default = 0
            self._no_default = True
        if callable(default):
            default = cast(Callable, default)(session_result)
        if self._float:
            # Going through `float` turns ints, floats and Decimals alike into a
            # Decimal with an explicit fractional part (e.g. 0 -> Decimal("0.0")).
            default = Decimal(str(float(cast(int, default))))
        elif not isinstance(default, int):
            raise InvalidArgument(
                f"{type(self).__name__} argument 'default' should return type of int"
            )
        self._default = default

        if keybindings is None:
            keybindings = {}
        self.kb_maps = {
            "down": [
                {"key": "down"},
                {"key": "c-n", "filter": ~self._is_vim_edit},
                {"key": "j", "filter": self._is_vim_edit},
            ],
            "up": [
                {"key": "up"},
                {"key": "c-p", "filter": ~self._is_vim_edit},
                {"key": "k", "filter": self._is_vim_edit},
            ],
            "left": [
                {"key": "left"},
                {"key": "c-b", "filter": ~self._is_vim_edit},
                {"key": "h", "filter": self._is_vim_edit},
            ],
            "right": [
                {"key": "right"},
                {"key": "c-f", "filter": ~self._is_vim_edit},
                {"key": "l", "filter": self._is_vim_edit},
            ],
            "dot": [{"key": "."}],
            "focus": [{"key": Keys.Tab}, {"key": "s-tab"}],
            "input": [{"key": str(i)} for i in range(10)],
            "negative_toggle": [{"key": "-"}],
            **keybindings,
        }
        self.kb_func_lookup = {
            "down": [{"func": self._handle_down}],
            "up": [{"func": self._handle_up}],
            "left": [{"func": self._handle_left}],
            "right": [{"func": self._handle_right}],
            "focus": [{"func": self._handle_focus}],
            "input": [{"func": self._handle_input}],
            "negative_toggle": [{"func": self._handle_negative_toggle}],
            "dot": [{"func": self._handle_dot}],
        }

        # Catch-all no-op handler so that keys without a dedicated binding
        # above (e.g. letters) are ignored instead of typed into the buffers.
        @self.register_kb(Keys.Any)
        def _(_):
            pass

        self._whole_width = 1
        self._whole_buffer = Buffer(
            on_text_changed=self._on_whole_text_change,
            on_cursor_position_changed=self._on_cursor_position_change,
        )

        self._integral_width = 1
        self._integral_buffer = Buffer(
            on_text_changed=self._on_integral_text_change,
            on_cursor_position_changed=self._on_cursor_position_change,
        )

        # Buffer windows are sized to their text length + 1 (see the
        # `_on_*_text_change` callbacks) so they hug the typed digits.
        self._whole_window = Window(
            height=LayoutDimension.exact(1) if not self._wrap_lines else None,
            content=BufferControl(
                buffer=self._whole_buffer,
                lexer=SimpleLexer("class:input"),
            ),
            width=lambda: Dimension(
                min=self._whole_width,
                max=self._whole_width,
                preferred=self._whole_width,
            ),
            dont_extend_width=True,
        )

        self._integral_window = Window(
            height=LayoutDimension.exact(1) if not self._wrap_lines else None,
            content=BufferControl(
                buffer=self._integral_buffer,
                lexer=SimpleLexer("class:input"),
            ),
            width=lambda: Dimension(
                min=self._integral_width,
                max=self._integral_width,
                preferred=self._integral_width,
            ),
        )

        self._layout = Layout(
            HSplit(
                [
                    VSplit(
                        [
                            Window(
                                height=LayoutDimension.exact(1)
                                if not self._wrap_lines
                                else None,
                                content=FormattedTextControl(self._get_prompt_message),
                                wrap_lines=self._wrap_lines,
                                dont_extend_height=True,
                                dont_extend_width=True,
                            ),
                            ConditionalContainer(self._whole_window, filter=~IsDone()),
                            ConditionalContainer(
                                Window(
                                    height=LayoutDimension.exact(1)
                                    if not self._wrap_lines
                                    else None,
                                    content=FormattedTextControl(
                                        [("", self._decimal_symbol)]
                                    ),
                                    wrap_lines=self._wrap_lines,
                                    dont_extend_height=True,
                                    dont_extend_width=True,
                                ),
                                filter=self._is_float & ~IsDone(),
                            ),
                            ConditionalContainer(
                                self._integral_window, filter=self._is_float & ~IsDone()
                            ),
                        ],
                        align=HorizontalAlign.LEFT,
                    ),
                    ConditionalContainer(
                        Window(content=DummyControl()),
                        filter=~IsDone() & self._is_displaying_long_instruction,
                    ),
                    ValidationWindow(
                        invalid_message=self._get_error_message,
                        filter=self._is_invalid & ~IsDone(),
                        wrap_lines=self._wrap_lines,
                    ),
                    InstructionWindow(
                        message=self._long_instruction,
                        filter=~IsDone() & self._is_displaying_long_instruction,
                        wrap_lines=self._wrap_lines,
                    ),
                ]
            ),
        )

        self.focus = self._whole_window

        self._application = Application(
            layout=self._layout,
            style=self._style,
            key_bindings=self._kb,
            after_render=self._after_render,
            editing_mode=self._editing_mode,
        )

    @staticmethod
    def _split_decimal(value: Union[int, float, Decimal]) -> Tuple[str, str]:
        """Split a number into whole buffer text and integral buffer text.

        `str()` of a :class:`~decimal.Decimal` may use scientific notation
        (``Decimal("0.0000001")`` prints as ``1E-7``, large values as ``1E+16``)
        or have no decimal point at all (``Decimal("0")``, which is what an
        ``int`` `min_allowed`/`max_allowed` becomes in float mode). Formatting
        with ``"f"`` always produces plain positional notation, so every case
        splits cleanly and the sign stays on the whole part.

        Args:
            value: Number to split.

        Returns:
            A tuple of whole buffer text and integral buffer text. The integral
            text is ``"0"`` when the number has no fractional digits.
        """
        plain_text = format(Decimal(str(value)), "f")
        whole_text, _, integral_text = plain_text.partition(".")
        return whole_text, integral_text or "0"

    def _fix_sn(self, value: str) -> Tuple[str, str]:
        """Fix scientific notation format.

        Converts text such as ``"1.5E-7"`` into buffer texts ``("0", "00000015")``.
        The sign is kept on the whole part, so ``"-1E-7"`` becomes
        ``("-0", "0000001")``.

        Args:
            value: Value to fix.

        Returns:
            A tuple of whole buffer text and integral buffer text.
        """
        return self._split_decimal(Decimal(value))

    def _on_rendered(self, _) -> None:
        """Additional processing to adjust buffer content after render.

        Fills the buffers from the default value (unless no default was given)
        and, when `replace_mode` is on, arms replace mode on every buffer whose
        text is exactly ``"0"``.
        """
        if self._no_default:
            return
        if not self._float:
            self._whole_buffer.text = str(self._default)
            self._integral_buffer.text = "0"
        else:
            whole_buffer_text, integral_buffer_text = self._split_decimal(
                self._default
            )
            self._integral_buffer.text = integral_buffer_text
            self._whole_buffer.text = whole_buffer_text
        self._whole_buffer.cursor_position = len(self._whole_buffer.text)
        self._integral_buffer.cursor_position = len(self._integral_buffer.text)
        if self._replace_mode:
            # check to start replace mode if applicable
            if self._whole_buffer.text == "0":
                self._whole_replace = True
                self._whole_buffer.cursor_position = 0
            if self._integral_buffer.text == "0":
                self._integral_replace = True
                self._integral_buffer.cursor_position = 0

    def _handle_number(self, increment: bool) -> None:
        """Handle number increment and decrement.

        The focused buffer's text is treated as an integer and moved by one.
        For the integral buffer the result is zero-padded back to its original
        width so leading zeros keep their place value: ``0.09`` increments to
        ``0.10`` and ``0.0010`` decrements to ``0.0009``. The integral buffer
        never decrements below zero. An empty buffer becomes ``"0"``.

        Args:
            increment: Indicate if the operation should increment or decrement.
        """
        if self.buffer_replace:
            self.buffer_replace = False
            self.focus_buffer.cursor_position += 1

        try:
            buffer = self.focus_buffer
            current_text = buffer.text
            if not current_text:
                next_text = "0"
            else:
                current_number = int(current_text)
                if increment:
                    next_number = current_number + 1
                elif buffer == self._integral_buffer and current_number == 0:
                    return
                else:
                    next_number = current_number - 1
                next_text = str(next_number)
                if buffer == self._integral_buffer:
                    next_text = next_text.zfill(len(current_text))

            # Keep the cursor at the same distance from the end of the text.
            desired_position = (
                buffer.cursor_position + len(next_text) - len(current_text)
            )
            buffer.cursor_position = desired_position
            buffer.text = next_text
            # Setting the text may clamp the cursor; restore it if so.
            if buffer.cursor_position != desired_position:
                buffer.cursor_position = desired_position
        except ValueError:
            self._set_error(message=self._value_error_message)

    def _handle_down(self, _) -> None:
        """Handle down key press."""
        self._handle_number(increment=False)

    def _handle_up(self, _) -> None:
        """Handle up key press."""
        self._handle_number(increment=True)

    def _handle_left(self, _) -> None:
        """Handle left key press.

        Move to the left by one cursor position and focus the whole window
        if applicable.
        """
        self.buffer_replace = False
        if (
            self.focus == self._integral_window
            and self.focus_buffer.cursor_position == 0
        ):
            self.focus = self._whole_window
        else:
            self.focus_buffer.cursor_position -= 1

    def _handle_right(self, _) -> None:
        """Handle right key press.

        Move to the right by one cursor position and focus the integral window
        if applicable.
        """
        self.buffer_replace = False
        if (
            self.focus == self._whole_window
            and self.focus_buffer.cursor_position == len(self.focus_buffer.text)
            and self._float
        ):
            self.focus = self._integral_window
        else:
            self.focus_buffer.cursor_position += 1

    def _handle_enter(self, event: "KeyPressEvent") -> None:
        """Handle enter event and answer/close the prompt.

        The result is ``""`` when nothing was entered, otherwise ``str(self.value)``.
        It is run through the validator first; on failure the error is shown
        and the prompt stays open.
        """
        if not self._float and not self._whole_buffer.text:
            result = ""
        elif (
            self._float
            and not self._whole_buffer.text
            and not self._integral_buffer.text
        ):
            result = ""
        else:
            result = str(self.value)

        try:
            fake_document = FakeDocument(result)
            self._validator.validate(fake_document)  # type: ignore
        except ValidationError as e:
            self._set_error(str(e))
        else:
            self.status["answered"] = True
            self.status["result"] = result
            event.app.exit(result=result)

    def _handle_dot(self, _) -> None:
        """Focus the integral window if `float_allowed`."""
        self._handle_focus(_, self._integral_window)

    def _handle_focus(self, _, window: Optional[Window] = None) -> None:
        """Focus `window` if given, otherwise toggle between the two windows.

        Does nothing unless `float_allowed`.
        """
        if not self._float:
            return
        if window is not None:
            self.focus = window
            return
        if self.focus == self._whole_window:
            self.focus = self._integral_window
        else:
            self.focus = self._whole_window

    def _handle_input(self, event: "KeyPressEvent") -> None:
        """Handle user input of numbers.

        Buffer will start as replace mode if the value is zero, once
        cursor is moved or content is changed, disable replace mode.
        """
        if self.buffer_replace:
            self.buffer_replace = False
            self.focus_buffer.text = event.key_sequence[0].data
            self.focus_buffer.cursor_position += 1
        else:
            self.focus_buffer.insert_text(event.key_sequence[0].data)

    def _handle_negative_toggle(self, _) -> None:
        """Toggle negativity of the prompt value.

        Force the `-` sign at the start of the whole buffer, keeping the
        cursor on the same digit.
        """
        if self._whole_buffer.text == "-":
            self._whole_buffer.text = "0"
            return
        if self._whole_buffer.text.startswith("-"):
            move_cursor = self._whole_buffer.cursor_position < len(
                self._whole_buffer.text
            )
            self._whole_buffer.text = self._whole_buffer.text[1:]
            if move_cursor:
                self._whole_buffer.cursor_position -= 1
        else:
            move_cursor = self._whole_buffer.cursor_position != 0
            self._whole_buffer.text = f"-{self._whole_buffer.text}"
            if move_cursor:
                self._whole_buffer.cursor_position += 1

    def _on_whole_text_change(self, buffer: Buffer) -> None:
        """Resize the whole window to its text and run the shared handling."""
        self._whole_width = len(buffer.text) + 1
        self._on_text_change(buffer)

    def _on_integral_text_change(self, buffer: Buffer) -> None:
        """Resize the integral window to its text and run the shared handling."""
        self._integral_width = len(buffer.text) + 1
        self._on_text_change(buffer)

    def _on_text_change(self, buffer: Buffer) -> None:
        """Disable replace mode and fix cursor position on text changes.

        Args:
            buffer: The buffer whose text changed. This is not necessarily the
                focused one: the `value` setter rewrites both buffers.
        """
        if buffer == self._whole_buffer:
            self._whole_replace = False
        else:
            self._integral_replace = False
        if buffer.text and buffer.text != "-":
            # Round-trip through the property to normalise the text and apply
            # the min/max clamping.
            self.value = self.value
        if buffer.text.startswith("-") and buffer.cursor_position == 0:
            buffer.cursor_position = 1

    def _on_cursor_position_change(self, buffer: Buffer) -> None:
        """Keep the cursor of `buffer` after its leading ``-`` sign, if any."""
        if buffer.text.startswith("-") and buffer.cursor_position == 0:
            buffer.cursor_position = 1

    @property
    def buffer_replace(self) -> bool:
        """bool: Replace mode of the currently focused buffer."""
        if self.focus_buffer == self._whole_buffer:
            return self._whole_replace
        else:
            return self._integral_replace

    @buffer_replace.setter
    def buffer_replace(self, value) -> None:
        if self.focus_buffer == self._whole_buffer:
            self._whole_replace = value
        else:
            self._integral_replace = value

    @property
    def focus_buffer(self) -> Buffer:
        """Buffer: Current editable buffer."""
        if self.focus == self._whole_window:
            return self._whole_buffer
        else:
            return self._integral_buffer

    @property
    def focus(self) -> Window:
        """Window: Current focused window."""
        return self._focus

    @focus.setter
    def focus(self, value: Window) -> None:
        self._focus = value
        self._layout.focus(self._focus)

    @property
    def value(self) -> Union[int, float, Decimal]:
        """Union[int, float, Decimal]: The actual value of the prompt.

        Combines the input buffers into an ``int`` (integer mode) or a
        :class:`~decimal.Decimal` (`float_allowed`). If the text cannot be
        parsed, an error is shown and the default value is returned instead.
        """
        try:
            if not self._float:
                return int(self._whole_buffer.text)
            else:
                return Decimal(
                    f"{self._whole_buffer.text}.{self._integral_buffer.text if self._integral_buffer.text else 0}"
                )
        except ValueError:
            self._set_error(self._value_error_message)
            return self._default

    @value.setter
    def value(self, value: Union[int, float, Decimal]) -> None:
        # Clamp into [min_allowed, max_allowed]; in float mode the bounds are
        # compared as Decimals.
        if self._min is not None:
            value = max(
                value, self._min if not self._float else Decimal(str(self._min))
            )
        if self._max is not None:
            value = min(
                value, self._max if not self._float else Decimal(str(self._max))
            )
        if not self._float:
            self._whole_buffer.text = str(value)
            return

        whole_buffer_text, integral_buffer_text = self._split_decimal(value)
        # Only rewrite buffers that already hold text so a partially typed
        # value (e.g. only the integral part so far) is not back-filled.
        if self._whole_buffer.text:
            self._whole_buffer.text = whole_buffer_text
        if self._integral_buffer.text:
            self._integral_buffer.text = integral_buffer_text