From d94a1990fe02abc9a53c4f4f3ad740903678918d Mon Sep 17 00:00:00 2001 From: Clif Houck Date: Tue, 5 May 2026 17:21:57 -0500 Subject: [PATCH 1/3] Ensure kernel_append_params are valid kernel parameters This is done by defining a kernel command line grammar and attempting to parse kernel_append_params. A successful parse indicates that the input contained in kernel_append_params consists of valid kernel parameters. An unsuccessful parse will raise and the input will be rejected. This parsing can be disabled through a new conductor configuration option: disable_kernel_parameter_parsing, which is False by default. Basic kernel parameter sanitization (i.e. filtering newlines) is always applied to kernel_append_params since newlines are never valid for inclusion. Future patches should extend kernel parameter parsing to all areas of Ironic's code base in order to guarantee that valid kernel parameters are passed along. NOTE: This patch is back-ported from stable/2026.{1,2} and slightly weakens the kernel command line grammar by not including init arguments. Lark's stand-alone LALR(1) parser can't handle the ambiguity introduced. This commit addresses CVE-2026-46447. 
Closes-Bug: 2150624 Change-Id: I31ee960f6f055e39dd248f54ac853e21838632df Signed-off-by: Clif Houck --- .pre-commit-config.yaml | 9 +- LICENSE | 2 + .../kernel_parameter_parser/__init__.py | 0 .../common/kernel_parameter_parser/grammar.g | 44 + .../kernel_parameter_parser.py | 3572 +++++++++++++++++ ironic/common/kernel_parameters.py | 159 + ironic/common/pxe_utils.py | 2 + ironic/conf/conductor.py | 12 + ironic/drivers/utils.py | 22 +- .../unit/common/test_kernel_parameters.py | 230 ++ ironic/tests/unit/drivers/test_utils.py | 97 + ...kernel-append-params-8b2953a9d903d0f6.yaml | 14 + 12 files changed, 4161 insertions(+), 2 deletions(-) create mode 100644 ironic/common/kernel_parameter_parser/__init__.py create mode 100644 ironic/common/kernel_parameter_parser/grammar.g create mode 100644 ironic/common/kernel_parameter_parser/kernel_parameter_parser.py create mode 100644 ironic/common/kernel_parameters.py create mode 100644 ironic/tests/unit/common/test_kernel_parameters.py create mode 100644 releasenotes/notes/sanitize-kernel-append-params-8b2953a9d903d0f6.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fc29f02e6a..17e324b9d1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,7 +42,13 @@ repos: hooks: - id: hacking additional_dependencies: [] - exclude: '^(doc|releasenotes|tools)/.*$' + exclude: | + (?x)^( + doc| + releasenotes| + tools| + ironic/common/kernel_parameter_parser + )/.*$ - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: @@ -85,6 +91,7 @@ repos: hooks: - id: ruff args: ['--fix', '--unsafe-fixes'] + exclude: '^ironic/common/kernel_parameter_parser/.*$' - repo: local hooks: - id: check-releasenotes diff --git a/LICENSE b/LICENSE index 68c771a099..2902e77a73 100644 --- a/LICENSE +++ b/LICENSE @@ -174,3 +174,5 @@ incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
+ironic/common/kernel_parameter_parser/kernel_parameter_parser.py is governed +by the MPL v2.0 license. See that file for more information. diff --git a/ironic/common/kernel_parameter_parser/__init__.py b/ironic/common/kernel_parameter_parser/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ironic/common/kernel_parameter_parser/grammar.g b/ironic/common/kernel_parameter_parser/grammar.g new file mode 100644 index 0000000000..320fc25de1 --- /dev/null +++ b/ironic/common/kernel_parameter_parser/grammar.g @@ -0,0 +1,44 @@ +// NOTE(clif): This grammar is used by lark to generate the parser found in +// kernel_parameter_parser.py +// +// The following assumes your current working directory is the same as this +// grammar file and the lark python library is available. Check the existing +// generated parser to find the lark +// version used. +// +// Use this command to regenerate the parser: +// +// $ python -m lark.tools.standalone grammar.g > kernel_parameter_parser.py +// +// The generated file (kernel_parameter_parser.py) will note at the beginning +// of the file which version of lark was used to generate it. Additionally +// note that lark requires python >= 3.8, which means the stand-alone parser +// *should* be fine for any recent-ish version of Ironic. +// +// NOTE(clif): The generated parser has some calls to pickle which trip +// Bandit's B301 test. If you regenerate the parser you will need to mark +// those lines as # noqa B301 +// If for some reason we do end up using lark's cache or save/load() +// functionality then we'll have to revisit this decision. 
+ +?start: kernel_command_line + +kernel_command_line: parameter_list + +parameter_list: parameter?(" "+ parameter)* + +parameter: key + | key_value_pair + +key_value_pair: key"="value + +key: /[A-Za-z0-9_\-\.]+/ + +value: bare_value + | quoted_value + +quoted_value: "\"" value_with_spaces "\"" + +bare_value: /[\!\#-\\.0-9:-\@A-Za-z\[-~]+/ + +value_with_spaces: /[\!\#-\\.0-9:-\@A-Za-z\[-~ ]+/ diff --git a/ironic/common/kernel_parameter_parser/kernel_parameter_parser.py b/ironic/common/kernel_parameter_parser/kernel_parameter_parser.py new file mode 100644 index 0000000000..60aa113a48 --- /dev/null +++ b/ironic/common/kernel_parameter_parser/kernel_parameter_parser.py @@ -0,0 +1,3572 @@ +# The file was automatically generated by Lark v1.3.1 +__version__ = "1.3.1" + +# +# +# Lark Stand-alone Generator Tool +# ---------------------------------- +# Generates a stand-alone LALR(1) parser +# +# Git: https://github.com/erezsh/lark +# Author: Erez Shinan (erezshin@gmail.com) +# +# +# >>> LICENSE +# +# This tool and its generated code use a separate license from Lark, +# and are subject to the terms of the Mozilla Public License, v. 2.0. +# If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +# +# If you wish to purchase a commercial license for this tool and its +# generated code, you may contact me via email or otherwise. +# +# If MPL2 is incompatible with your free or open-source project, +# contact me and we'll work it out. 
+# +# + +from copy import deepcopy +from abc import ABC, abstractmethod +from types import ModuleType +from typing import ( + TypeVar, Generic, Type, Tuple, List, Dict, Iterator, Collection, Callable, Optional, FrozenSet, Any, + Union, Iterable, IO, TYPE_CHECKING, overload, Sequence, + Pattern as REPattern, ClassVar, Set, Mapping +) + + +class LarkError(Exception): + pass + + +class ConfigurationError(LarkError, ValueError): + pass + + +def assert_config(value, options: Collection, msg='Got %r, expected one of %s'): + if value not in options: + raise ConfigurationError(msg % (value, options)) + + +class GrammarError(LarkError): + pass + + +class ParseError(LarkError): + pass + + +class LexError(LarkError): + pass + +T = TypeVar('T') + +class UnexpectedInput(LarkError): + #-- + line: int + column: int + pos_in_stream = None + state: Any + _terminals_by_name = None + interactive_parser: 'InteractiveParser' + + def get_context(self, text: str, span: int=40) -> str: + #-- + pos = self.pos_in_stream or 0 + start = max(pos - span, 0) + end = pos + span + if not isinstance(text, bytes): + before = text[start:pos].rsplit('\n', 1)[-1] + after = text[pos:end].split('\n', 1)[0] + return before + after + '\n' + ' ' * len(before.expandtabs()) + '^\n' + else: + before = text[start:pos].rsplit(b'\n', 1)[-1] + after = text[pos:end].split(b'\n', 1)[0] + return (before + after + b'\n' + b' ' * len(before.expandtabs()) + b'^\n').decode("ascii", "backslashreplace") + + def match_examples(self, parse_fn: 'Callable[[str], Tree]', + examples: Union[Mapping[T, Iterable[str]], Iterable[Tuple[T, Iterable[str]]]], + token_type_match_fallback: bool=False, + use_accepts: bool=True + ) -> Optional[T]: + #-- + assert self.state is not None, "Not supported for this exception" + + if isinstance(examples, Mapping): + examples = examples.items() + + candidate = (None, False) + for i, (label, example) in enumerate(examples): + assert not isinstance(example, str), "Expecting a list" + + for j, 
malformed in enumerate(example): + try: + parse_fn(malformed) + except UnexpectedInput as ut: + if ut.state == self.state: + if ( + use_accepts + and isinstance(self, UnexpectedToken) + and isinstance(ut, UnexpectedToken) + and ut.accepts != self.accepts + ): + logger.debug("Different accepts with same state[%d]: %s != %s at example [%s][%s]" % + (self.state, self.accepts, ut.accepts, i, j)) + continue + if ( + isinstance(self, (UnexpectedToken, UnexpectedEOF)) + and isinstance(ut, (UnexpectedToken, UnexpectedEOF)) + ): + if ut.token == self.token: ## + + logger.debug("Exact Match at example [%s][%s]" % (i, j)) + return label + + if token_type_match_fallback: + ## + + if (ut.token.type == self.token.type) and not candidate[-1]: + logger.debug("Token Type Fallback at example [%s][%s]" % (i, j)) + candidate = label, True + + if candidate[0] is None: + logger.debug("Same State match at example [%s][%s]" % (i, j)) + candidate = label, False + + return candidate[0] + + def _format_expected(self, expected): + if self._terminals_by_name: + d = self._terminals_by_name + expected = [d[t_name].user_repr() if t_name in d else t_name for t_name in expected] + return "Expected one of: \n\t* %s\n" % '\n\t* '.join(expected) + + +class UnexpectedEOF(ParseError, UnexpectedInput): + #-- + expected: 'List[Token]' + + def __init__(self, expected, state=None, terminals_by_name=None): + super(UnexpectedEOF, self).__init__() + + self.expected = expected + self.state = state + from .lexer import Token + self.token = Token("", "") ## + + self.pos_in_stream = -1 + self.line = -1 + self.column = -1 + self._terminals_by_name = terminals_by_name + + + def __str__(self): + message = "Unexpected end-of-input. 
" + message += self._format_expected(self.expected) + return message + + +class UnexpectedCharacters(LexError, UnexpectedInput): + #-- + + allowed: Set[str] + considered_tokens: Set[Any] + + def __init__(self, seq, lex_pos, line, column, allowed=None, considered_tokens=None, state=None, token_history=None, + terminals_by_name=None, considered_rules=None): + super(UnexpectedCharacters, self).__init__() + + ## + + self.line = line + self.column = column + self.pos_in_stream = lex_pos + self.state = state + self._terminals_by_name = terminals_by_name + + self.allowed = allowed + self.considered_tokens = considered_tokens + self.considered_rules = considered_rules + self.token_history = token_history + + if isinstance(seq, bytes): + self.char = seq[lex_pos:lex_pos + 1].decode("ascii", "backslashreplace") + else: + self.char = seq[lex_pos] + self._context = self.get_context(seq) + + + def __str__(self): + message = "No terminal matches '%s' in the current parser context, at line %d col %d" % (self.char, self.line, self.column) + message += '\n\n' + self._context + if self.allowed: + message += self._format_expected(self.allowed) + if self.token_history: + message += '\nPrevious tokens: %s\n' % ', '.join(repr(t) for t in self.token_history) + return message + + +class UnexpectedToken(ParseError, UnexpectedInput): + #-- + + expected: Set[str] + considered_rules: Set[str] + + def __init__(self, token, expected, considered_rules=None, state=None, interactive_parser=None, terminals_by_name=None, token_history=None): + super(UnexpectedToken, self).__init__() + + ## + + self.line = getattr(token, 'line', '?') + self.column = getattr(token, 'column', '?') + self.pos_in_stream = getattr(token, 'start_pos', None) + self.state = state + + self.token = token + self.expected = expected ## + + self._accepts = NO_VALUE + self.considered_rules = considered_rules + self.interactive_parser = interactive_parser + self._terminals_by_name = terminals_by_name + self.token_history = 
token_history + + + @property + def accepts(self) -> Set[str]: + if self._accepts is NO_VALUE: + self._accepts = self.interactive_parser and self.interactive_parser.accepts() + return self._accepts + + def __str__(self): + message = ("Unexpected token %r at line %s, column %s.\n%s" + % (self.token, self.line, self.column, self._format_expected(self.accepts or self.expected))) + if self.token_history: + message += "Previous tokens: %r\n" % self.token_history + + return message + + + +class VisitError(LarkError): + #-- + + obj: 'Union[Tree, Token]' + orig_exc: Exception + + def __init__(self, rule, obj, orig_exc): + message = 'Error trying to process rule "%s":\n\n%s' % (rule, orig_exc) + super(VisitError, self).__init__(message) + + self.rule = rule + self.obj = obj + self.orig_exc = orig_exc + + +class MissingVariableError(LarkError): + pass + + +import sys, re +import logging +from dataclasses import dataclass +from typing import Generic, AnyStr + +logger: logging.Logger = logging.getLogger("lark") +logger.addHandler(logging.StreamHandler()) +## + +## + +logger.setLevel(logging.CRITICAL) + + +NO_VALUE = object() + +T = TypeVar("T") + + +def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict: + d: Dict[Any, Any] = {} + for item in seq: + k = key(item) if (key is not None) else item + v = value(item) if (value is not None) else item + try: + d[k].append(v) + except KeyError: + d[k] = [v] + return d + + +def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any: + if isinstance(data, dict): + if '__type__' in data: ## + + class_ = namespace[data['__type__']] + return class_.deserialize(data, memo) + elif '@' in data: + return memo[data['@']] + return {key:_deserialize(value, namespace, memo) for key, value in data.items()} + elif isinstance(data, list): + return [_deserialize(value, namespace, memo) for value in data] + return data + + +_T = TypeVar("_T", bound="Serialize") + +class Serialize: + #-- + + 
def memo_serialize(self, types_to_memoize: List) -> Any: + memo = SerializeMemoizer(types_to_memoize) + return self.serialize(memo), memo.serialize() + + def serialize(self, memo = None) -> Dict[str, Any]: + if memo and memo.in_types(self): + return {'@': memo.memoized.get(self)} + + fields = getattr(self, '__serialize_fields__') + res = {f: _serialize(getattr(self, f), memo) for f in fields} + res['__type__'] = type(self).__name__ + if hasattr(self, '_serialize'): + self._serialize(res, memo) + return res + + @classmethod + def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T: + namespace = getattr(cls, '__serialize_namespace__', []) + namespace = {c.__name__:c for c in namespace} + + fields = getattr(cls, '__serialize_fields__') + + if '@' in data: + return memo[data['@']] + + inst = cls.__new__(cls) + for f in fields: + try: + setattr(inst, f, _deserialize(data[f], namespace, memo)) + except KeyError as e: + raise KeyError("Cannot find key for class", cls, e) + + if hasattr(inst, '_deserialize'): + inst._deserialize() + + return inst + + +class SerializeMemoizer(Serialize): + #-- + + __serialize_fields__ = 'memoized', + + def __init__(self, types_to_memoize: List) -> None: + self.types_to_memoize = tuple(types_to_memoize) + self.memoized = Enumerator() + + def in_types(self, value: Serialize) -> bool: + return isinstance(value, self.types_to_memoize) + + def serialize(self) -> Dict[int, Any]: ## + + return _serialize(self.memoized.reversed(), None) + + @classmethod + def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]: ## + + return _deserialize(data, namespace, memo) + + +try: + import regex + _has_regex = True +except ImportError: + _has_regex = False + +if sys.version_info >= (3, 11): + import re._parser as sre_parse + import re._constants as sre_constants +else: + import sre_parse + import sre_constants + +categ_pattern = re.compile(r'\\p{[A-Za-z_]+}') + +def 
get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]: + if _has_regex: + ## + + ## + + ## + + regexp_final = re.sub(categ_pattern, 'A', expr) + else: + if re.search(categ_pattern, expr): + raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr) + regexp_final = expr + try: + ## + + return [int(x) for x in sre_parse.parse(regexp_final).getwidth()] + except sre_constants.error: + if not _has_regex: + raise ValueError(expr) + else: + ## + + ## + + c = regex.compile(regexp_final) + ## + + ## + + MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT) + if c.match('') is None: + ## + + return 1, int(MAXWIDTH) + else: + return 0, int(MAXWIDTH) + + +@dataclass(frozen=True) +class TextSlice(Generic[AnyStr]): + #-- + text: AnyStr + start: int + end: int + + def __post_init__(self): + if not isinstance(self.text, (str, bytes)): + raise TypeError("text must be str or bytes") + + if self.start < 0: + object.__setattr__(self, 'start', self.start + len(self.text)) + assert self.start >=0 + + if self.end is None: + object.__setattr__(self, 'end', len(self.text)) + elif self.end < 0: + object.__setattr__(self, 'end', self.end + len(self.text)) + assert self.end <= len(self.text) + + @classmethod + def cast_from(cls, text: 'TextOrSlice') -> 'TextSlice[AnyStr]': + if isinstance(text, TextSlice): + return text + + return cls(text, 0, len(text)) + + def is_complete_text(self): + return self.start == 0 and self.end == len(self.text) + + def __len__(self): + return self.end - self.start + + def count(self, substr: AnyStr): + return self.text.count(substr, self.start, self.end) + + def rindex(self, substr: AnyStr): + return self.text.rindex(substr, self.start, self.end) + + +TextOrSlice = Union[AnyStr, 'TextSlice[AnyStr]'] +LarkInput = Union[AnyStr, TextSlice[AnyStr], Any] + + + +class Meta: + + empty: bool + line: int + column: int + start_pos: int + end_line: int + end_column: int + end_pos: int + orig_expansion: 
'List[TerminalDef]' + match_tree: bool + + def __init__(self): + self.empty = True + + +_Leaf_T = TypeVar("_Leaf_T") +Branch = Union[_Leaf_T, 'Tree[_Leaf_T]'] + + +class Tree(Generic[_Leaf_T]): + #-- + + data: str + children: 'List[Branch[_Leaf_T]]' + + def __init__(self, data: str, children: 'List[Branch[_Leaf_T]]', meta: Optional[Meta]=None) -> None: + self.data = data + self.children = children + self._meta = meta + + @property + def meta(self) -> Meta: + if self._meta is None: + self._meta = Meta() + return self._meta + + def __repr__(self): + return 'Tree(%r, %r)' % (self.data, self.children) + + __match_args__ = ("data", "children") + + def _pretty_label(self): + return self.data + + def _pretty(self, level, indent_str): + yield f'{indent_str*level}{self._pretty_label()}' + if len(self.children) == 1 and not isinstance(self.children[0], Tree): + yield f'\t{self.children[0]}\n' + else: + yield '\n' + for n in self.children: + if isinstance(n, Tree): + yield from n._pretty(level+1, indent_str) + else: + yield f'{indent_str*(level+1)}{n}\n' + + def pretty(self, indent_str: str=' ') -> str: + #-- + return ''.join(self._pretty(0, indent_str)) + + def __rich__(self, parent:Optional['rich.tree.Tree']=None) -> 'rich.tree.Tree': + #-- + return self._rich(parent) + + def _rich(self, parent): + if parent: + tree = parent.add(f'[bold]{self.data}[/bold]') + else: + import rich.tree + tree = rich.tree.Tree(self.data) + + for c in self.children: + if isinstance(c, Tree): + c._rich(tree) + else: + tree.add(f'[green]{c}[/green]') + + return tree + + def __eq__(self, other): + try: + return self.data == other.data and self.children == other.children + except AttributeError: + return False + + def __ne__(self, other): + return not (self == other) + + def __hash__(self) -> int: + return hash((self.data, tuple(self.children))) + + def iter_subtrees(self) -> 'Iterator[Tree[_Leaf_T]]': + #-- + queue = [self] + subtrees = dict() + for subtree in queue: + subtrees[id(subtree)] = 
subtree + queue += [c for c in reversed(subtree.children) + if isinstance(c, Tree) and id(c) not in subtrees] + + del queue + return reversed(list(subtrees.values())) + + def iter_subtrees_topdown(self): + #-- + stack = [self] + stack_append = stack.append + stack_pop = stack.pop + while stack: + node = stack_pop() + if not isinstance(node, Tree): + continue + yield node + for child in reversed(node.children): + stack_append(child) + + def find_pred(self, pred: 'Callable[[Tree[_Leaf_T]], bool]') -> 'Iterator[Tree[_Leaf_T]]': + #-- + return filter(pred, self.iter_subtrees()) + + def find_data(self, data: str) -> 'Iterator[Tree[_Leaf_T]]': + #-- + return self.find_pred(lambda t: t.data == data) + + +from functools import wraps, update_wrapper +from inspect import getmembers, getmro + +_Return_T = TypeVar('_Return_T') +_Return_V = TypeVar('_Return_V') +_Leaf_T = TypeVar('_Leaf_T') +_Leaf_U = TypeVar('_Leaf_U') +_R = TypeVar('_R') +_FUNC = Callable[..., _Return_T] +_DECORATED = Union[_FUNC, type] + +class _DiscardType: + #-- + + def __repr__(self): + return "lark.visitors.Discard" + +Discard = _DiscardType() + +## + + +class _Decoratable: + #-- + + @classmethod + def _apply_v_args(cls, visit_wrapper): + mro = getmro(cls) + assert mro[0] is cls + libmembers = {name for _cls in mro[1:] for name, _ in getmembers(_cls)} + for name, value in getmembers(cls): + + ## + + if name.startswith('_') or (name in libmembers and name not in cls.__dict__): + continue + if not callable(value): + continue + + ## + + if isinstance(cls.__dict__[name], _VArgsWrapper): + continue + + setattr(cls, name, _VArgsWrapper(cls.__dict__[name], visit_wrapper)) + return cls + + def __class_getitem__(cls, _): + return cls + + +class Transformer(_Decoratable, ABC, Generic[_Leaf_T, _Return_T]): + #-- + __visit_tokens__ = True ## + + + def __init__(self, visit_tokens: bool=True) -> None: + self.__visit_tokens__ = visit_tokens + + def _call_userfunc(self, tree, new_children=None): + ## + + children = 
new_children if new_children is not None else tree.children + try: + f = getattr(self, tree.data) + except AttributeError: + return self.__default__(tree.data, children, tree.meta) + else: + try: + wrapper = getattr(f, 'visit_wrapper', None) + if wrapper is not None: + return f.visit_wrapper(f, tree.data, children, tree.meta) + else: + return f(children) + except GrammarError: + raise + except Exception as e: + raise VisitError(tree.data, tree, e) + + def _call_userfunc_token(self, token): + try: + f = getattr(self, token.type) + except AttributeError: + return self.__default_token__(token) + else: + try: + return f(token) + except GrammarError: + raise + except Exception as e: + raise VisitError(token.type, token, e) + + def _transform_children(self, children): + for c in children: + if isinstance(c, Tree): + res = self._transform_tree(c) + elif self.__visit_tokens__ and isinstance(c, Token): + res = self._call_userfunc_token(c) + else: + res = c + + if res is not Discard: + yield res + + def _transform_tree(self, tree): + children = list(self._transform_children(tree.children)) + return self._call_userfunc(tree, children) + + def transform(self, tree: Tree[_Leaf_T]) -> _Return_T: + #-- + res = list(self._transform_children([tree])) + if not res: + return None ## + + assert len(res) == 1 + return res[0] + + def __mul__( + self: 'Transformer[_Leaf_T, Tree[_Leaf_U]]', + other: 'Union[Transformer[_Leaf_U, _Return_V], TransformerChain[_Leaf_U, _Return_V,]]' + ) -> 'TransformerChain[_Leaf_T, _Return_V]': + #-- + return TransformerChain(self, other) + + def __default__(self, data, children, meta): + #-- + return Tree(data, children, meta) + + def __default_token__(self, token): + #-- + return token + + +def merge_transformers(base_transformer=None, **transformers_to_merge): + #-- + if base_transformer is None: + base_transformer = Transformer() + for prefix, transformer in transformers_to_merge.items(): + for method_name in dir(transformer): + method = 
getattr(transformer, method_name) + if not callable(method): + continue + if method_name.startswith("_") or method_name == "transform": + continue + prefixed_method = prefix + "__" + method_name + if hasattr(base_transformer, prefixed_method): + raise AttributeError("Cannot merge: method '%s' appears more than once" % prefixed_method) + + setattr(base_transformer, prefixed_method, method) + + return base_transformer + + +class InlineTransformer(Transformer): ## + + def _call_userfunc(self, tree, new_children=None): + ## + + children = new_children if new_children is not None else tree.children + try: + f = getattr(self, tree.data) + except AttributeError: + return self.__default__(tree.data, children, tree.meta) + else: + return f(*children) + + +class TransformerChain(Generic[_Leaf_T, _Return_T]): + + transformers: 'Tuple[Union[Transformer, TransformerChain], ...]' + + def __init__(self, *transformers: 'Union[Transformer, TransformerChain]') -> None: + self.transformers = transformers + + def transform(self, tree: Tree[_Leaf_T]) -> _Return_T: + for t in self.transformers: + tree = t.transform(tree) + return cast(_Return_T, tree) + + def __mul__( + self: 'TransformerChain[_Leaf_T, Tree[_Leaf_U]]', + other: 'Union[Transformer[_Leaf_U, _Return_V], TransformerChain[_Leaf_U, _Return_V]]' + ) -> 'TransformerChain[_Leaf_T, _Return_V]': + return TransformerChain(*self.transformers + (other,)) + + +class Transformer_InPlace(Transformer[_Leaf_T, _Return_T]): + #-- + def _transform_tree(self, tree): ## + + return self._call_userfunc(tree) + + def transform(self, tree: Tree[_Leaf_T]) -> _Return_T: + for subtree in tree.iter_subtrees(): + subtree.children = list(self._transform_children(subtree.children)) + + return self._transform_tree(tree) + + +class Transformer_NonRecursive(Transformer[_Leaf_T, _Return_T]): + #-- + + def transform(self, tree: Tree[_Leaf_T]) -> _Return_T: + ## + + rev_postfix = [] + q: List[Branch[_Leaf_T]] = [tree] + while q: + t = q.pop() + 
rev_postfix.append(t) + if isinstance(t, Tree): + q += t.children + + ## + + stack: List = [] + for x in reversed(rev_postfix): + if isinstance(x, Tree): + size = len(x.children) + if size: + args = stack[-size:] + del stack[-size:] + else: + args = [] + + res = self._call_userfunc(x, args) + if res is not Discard: + stack.append(res) + + elif self.__visit_tokens__ and isinstance(x, Token): + res = self._call_userfunc_token(x) + if res is not Discard: + stack.append(res) + else: + stack.append(x) + + result, = stack ## + + ## + + ## + + ## + + return cast(_Return_T, result) + + +class Transformer_InPlaceRecursive(Transformer[_Leaf_T, _Return_T]): + #-- + def _transform_tree(self, tree): + tree.children = list(self._transform_children(tree.children)) + return self._call_userfunc(tree) + + +## + + +class VisitorBase: + def _call_userfunc(self, tree): + return getattr(self, tree.data, self.__default__)(tree) + + def __default__(self, tree): + #-- + return tree + + def __class_getitem__(cls, _): + return cls + + +class Visitor(VisitorBase, ABC, Generic[_Leaf_T]): + #-- + + def visit(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]: + #-- + for subtree in tree.iter_subtrees(): + self._call_userfunc(subtree) + return tree + + def visit_topdown(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]: + #-- + for subtree in tree.iter_subtrees_topdown(): + self._call_userfunc(subtree) + return tree + + +class Visitor_Recursive(VisitorBase, Generic[_Leaf_T]): + #-- + + def visit(self, tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]: + #-- + for child in tree.children: + if isinstance(child, Tree): + self.visit(child) + + self._call_userfunc(tree) + return tree + + def visit_topdown(self,tree: Tree[_Leaf_T]) -> Tree[_Leaf_T]: + #-- + self._call_userfunc(tree) + + for child in tree.children: + if isinstance(child, Tree): + self.visit_topdown(child) + + return tree + + +class Interpreter(_Decoratable, ABC, Generic[_Leaf_T, _Return_T]): + #-- + + def visit(self, tree: Tree[_Leaf_T]) -> _Return_T: + ## + + 
## + + ## + + return self._visit_tree(tree) + + def _visit_tree(self, tree: Tree[_Leaf_T]): + f = getattr(self, tree.data) + wrapper = getattr(f, 'visit_wrapper', None) + if wrapper is not None: + return f.visit_wrapper(f, tree.data, tree.children, tree.meta) + else: + return f(tree) + + def visit_children(self, tree: Tree[_Leaf_T]) -> List: + return [self._visit_tree(child) if isinstance(child, Tree) else child + for child in tree.children] + + def __getattr__(self, name): + return self.__default__ + + def __default__(self, tree): + return self.visit_children(tree) + + +_InterMethod = Callable[[Type[Interpreter], _Return_T], _R] + +def visit_children_decor(func: _InterMethod) -> _InterMethod: + #-- + @wraps(func) + def inner(cls, tree): + values = cls.visit_children(tree) + return func(cls, values) + return inner + +## + + +def _apply_v_args(obj, visit_wrapper): + try: + _apply = obj._apply_v_args + except AttributeError: + return _VArgsWrapper(obj, visit_wrapper) + else: + return _apply(visit_wrapper) + + +class _VArgsWrapper: + #-- + base_func: Callable + + def __init__(self, func: Callable, visit_wrapper: Callable[[Callable, str, list, Any], Any]): + if isinstance(func, _VArgsWrapper): + func = func.base_func + self.base_func = func + self.visit_wrapper = visit_wrapper + update_wrapper(self, func) + + def __call__(self, *args, **kwargs): + return self.base_func(*args, **kwargs) + + def __get__(self, instance, owner=None): + try: + ## + + ## + + g = type(self.base_func).__get__ + except AttributeError: + return self + else: + return _VArgsWrapper(g(self.base_func, instance, owner), self.visit_wrapper) + + def __set_name__(self, owner, name): + try: + f = type(self.base_func).__set_name__ + except AttributeError: + return + else: + f(self.base_func, owner, name) + + +def _vargs_inline(f, _data, children, _meta): + return f(*children) +def _vargs_meta_inline(f, _data, children, meta): + return f(meta, *children) +def _vargs_meta(f, _data, children, meta): + return 
f(meta, children) +def _vargs_tree(f, data, children, meta): + return f(Tree(data, children, meta)) + + +def v_args(inline: bool = False, meta: bool = False, tree: bool = False, wrapper: Optional[Callable] = None) -> Callable[[_DECORATED], _DECORATED]: + #-- + if tree and (meta or inline): + raise ValueError("Visitor functions cannot combine 'tree' with 'meta' or 'inline'.") + + func = None + if meta: + if inline: + func = _vargs_meta_inline + else: + func = _vargs_meta + elif inline: + func = _vargs_inline + elif tree: + func = _vargs_tree + + if wrapper is not None: + if func is not None: + raise ValueError("Cannot use 'wrapper' along with 'tree', 'meta' or 'inline'.") + func = wrapper + + def _visitor_args_dec(obj): + return _apply_v_args(obj, func) + return _visitor_args_dec + + + +TOKEN_DEFAULT_PRIORITY = 0 + + +class Symbol(Serialize): + __slots__ = ('name',) + + name: str + is_term: ClassVar[bool] = NotImplemented + + def __init__(self, name: str) -> None: + self.name = name + + def __eq__(self, other): + if not isinstance(other, Symbol): + return NotImplemented + return self.is_term == other.is_term and self.name == other.name + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash(self.name) + + def __repr__(self): + return '%s(%r)' % (type(self).__name__, self.name) + + fullrepr = property(__repr__) + + def renamed(self, f): + return type(self)(f(self.name)) + + +class Terminal(Symbol): + __serialize_fields__ = 'name', 'filter_out' + + is_term: ClassVar[bool] = True + + def __init__(self, name: str, filter_out: bool = False) -> None: + self.name = name + self.filter_out = filter_out + + @property + def fullrepr(self): + return '%s(%r, %r)' % (type(self).__name__, self.name, self.filter_out) + + def renamed(self, f): + return type(self)(f(self.name), self.filter_out) + + +class NonTerminal(Symbol): + __serialize_fields__ = 'name', + + is_term: ClassVar[bool] = False + + def serialize(self, memo=None) -> Dict[str, 
Any]: + ## + + ## + + return {'name': str(self.name), '__type__': 'NonTerminal'} + + +class RuleOptions(Serialize): + __serialize_fields__ = 'keep_all_tokens', 'expand1', 'priority', 'template_source', 'empty_indices' + + keep_all_tokens: bool + expand1: bool + priority: Optional[int] + template_source: Optional[str] + empty_indices: Tuple[bool, ...] + + def __init__(self, keep_all_tokens: bool=False, expand1: bool=False, priority: Optional[int]=None, template_source: Optional[str]=None, empty_indices: Tuple[bool, ...]=()) -> None: + self.keep_all_tokens = keep_all_tokens + self.expand1 = expand1 + self.priority = priority + self.template_source = template_source + self.empty_indices = empty_indices + + def __repr__(self): + return 'RuleOptions(%r, %r, %r, %r)' % ( + self.keep_all_tokens, + self.expand1, + self.priority, + self.template_source + ) + + +class Rule(Serialize): + #-- + __slots__ = ('origin', 'expansion', 'alias', 'options', 'order', '_hash') + + __serialize_fields__ = 'origin', 'expansion', 'order', 'alias', 'options' + __serialize_namespace__ = Terminal, NonTerminal, RuleOptions + + origin: NonTerminal + expansion: Sequence[Symbol] + order: int + alias: Optional[str] + options: RuleOptions + _hash: int + + def __init__(self, origin: NonTerminal, expansion: Sequence[Symbol], + order: int=0, alias: Optional[str]=None, options: Optional[RuleOptions]=None): + self.origin = origin + self.expansion = expansion + self.alias = alias + self.order = order + self.options = options or RuleOptions() + self._hash = hash((self.origin, tuple(self.expansion))) + + def _deserialize(self): + self._hash = hash((self.origin, tuple(self.expansion))) + + def __str__(self): + return '<%s : %s>' % (self.origin.name, ' '.join(x.name for x in self.expansion)) + + def __repr__(self): + return 'Rule(%r, %r, %r, %r)' % (self.origin, self.expansion, self.alias, self.options) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + if not isinstance(other, Rule): 
+ return False + return self.origin == other.origin and self.expansion == other.expansion + + + +from contextlib import suppress +from copy import copy + +try: ## + + has_interegular = bool(interegular) +except NameError: + has_interegular = False + +class Pattern(Serialize, ABC): + #-- + + value: str + flags: Collection[str] + raw: Optional[str] + type: ClassVar[str] + + def __init__(self, value: str, flags: Collection[str] = (), raw: Optional[str] = None) -> None: + self.value = value + self.flags = frozenset(flags) + self.raw = raw + + def __repr__(self): + return repr(self.to_regexp()) + + ## + + def __hash__(self): + return hash((type(self), self.value, self.flags)) + + def __eq__(self, other): + return type(self) == type(other) and self.value == other.value and self.flags == other.flags + + @abstractmethod + def to_regexp(self) -> str: + raise NotImplementedError() + + @property + @abstractmethod + def min_width(self) -> int: + raise NotImplementedError() + + @property + @abstractmethod + def max_width(self) -> int: + raise NotImplementedError() + + def _get_flags(self, value): + for f in self.flags: + value = ('(?%s:%s)' % (f, value)) + return value + + +class PatternStr(Pattern): + __serialize_fields__ = 'value', 'flags', 'raw' + + type: ClassVar[str] = "str" + + def to_regexp(self) -> str: + return self._get_flags(re.escape(self.value)) + + @property + def min_width(self) -> int: + return len(self.value) + + @property + def max_width(self) -> int: + return len(self.value) + + +class PatternRE(Pattern): + __serialize_fields__ = 'value', 'flags', 'raw', '_width' + + type: ClassVar[str] = "re" + + def to_regexp(self) -> str: + return self._get_flags(self.value) + + _width = None + def _get_width(self): + if self._width is None: + self._width = get_regexp_width(self.to_regexp()) + return self._width + + @property + def min_width(self) -> int: + return self._get_width()[0] + + @property + def max_width(self) -> int: + return self._get_width()[1] + + +class 
TerminalDef(Serialize): + #-- + __serialize_fields__ = 'name', 'pattern', 'priority' + __serialize_namespace__ = PatternStr, PatternRE + + name: str + pattern: Pattern + priority: int + + def __init__(self, name: str, pattern: Pattern, priority: int = TOKEN_DEFAULT_PRIORITY) -> None: + assert isinstance(pattern, Pattern), pattern + self.name = name + self.pattern = pattern + self.priority = priority + + def __repr__(self): + return '%s(%r, %r)' % (type(self).__name__, self.name, self.pattern) + + def user_repr(self) -> str: + if self.name.startswith('__'): ## + + return self.pattern.raw or self.name + else: + return self.name + +_T = TypeVar('_T', bound="Token") + +class Token(str): + #-- + __slots__ = ('type', 'start_pos', 'value', 'line', 'column', 'end_line', 'end_column', 'end_pos') + + __match_args__ = ('type', 'value') + + type: str + start_pos: Optional[int] + value: Any + line: Optional[int] + column: Optional[int] + end_line: Optional[int] + end_column: Optional[int] + end_pos: Optional[int] + + + @overload + def __new__( + cls, + type: str, + value: Any, + start_pos: Optional[int] = None, + line: Optional[int] = None, + column: Optional[int] = None, + end_line: Optional[int] = None, + end_column: Optional[int] = None, + end_pos: Optional[int] = None + ) -> 'Token': + ... + + @overload + def __new__( + cls, + type_: str, + value: Any, + start_pos: Optional[int] = None, + line: Optional[int] = None, + column: Optional[int] = None, + end_line: Optional[int] = None, + end_column: Optional[int] = None, + end_pos: Optional[int] = None + ) -> 'Token': ... 
+ + def __new__(cls, *args, **kwargs): + if "type_" in kwargs: + warnings.warn("`type_` is deprecated use `type` instead", DeprecationWarning) + + if "type" in kwargs: + raise TypeError("Error: using both 'type' and the deprecated 'type_' as arguments.") + kwargs["type"] = kwargs.pop("type_") + + return cls._future_new(*args, **kwargs) + + + @classmethod + def _future_new(cls, type, value, start_pos=None, line=None, column=None, end_line=None, end_column=None, end_pos=None): + inst = super(Token, cls).__new__(cls, value) + + inst.type = type + inst.start_pos = start_pos + inst.value = value + inst.line = line + inst.column = column + inst.end_line = end_line + inst.end_column = end_column + inst.end_pos = end_pos + return inst + + @overload + def update(self, type: Optional[str] = None, value: Optional[Any] = None) -> 'Token': + ... + + @overload + def update(self, type_: Optional[str] = None, value: Optional[Any] = None) -> 'Token': + ... + + def update(self, *args, **kwargs): + if "type_" in kwargs: + warnings.warn("`type_` is deprecated use `type` instead", DeprecationWarning) + + if "type" in kwargs: + raise TypeError("Error: using both 'type' and the deprecated 'type_' as arguments.") + kwargs["type"] = kwargs.pop("type_") + + return self._future_update(*args, **kwargs) + + def _future_update(self, type: Optional[str] = None, value: Optional[Any] = None) -> 'Token': + return Token.new_borrow_pos( + type if type is not None else self.type, + value if value is not None else self.value, + self + ) + + @classmethod + def new_borrow_pos(cls: Type[_T], type_: str, value: Any, borrow_t: 'Token') -> _T: + return cls(type_, value, borrow_t.start_pos, borrow_t.line, borrow_t.column, borrow_t.end_line, borrow_t.end_column, borrow_t.end_pos) + + def __reduce__(self): + return (self.__class__, (self.type, self.value, self.start_pos, self.line, self.column)) + + def __repr__(self): + return 'Token(%r, %r)' % (self.type, self.value) + + def __deepcopy__(self, memo): + return 
Token(self.type, self.value, self.start_pos, self.line, self.column) + + def __eq__(self, other): + if isinstance(other, Token) and self.type != other.type: + return False + + return str.__eq__(self, other) + + __hash__ = str.__hash__ + + +class LineCounter: + #-- + + __slots__ = 'char_pos', 'line', 'column', 'line_start_pos', 'newline_char' + + def __init__(self, newline_char): + self.newline_char = newline_char + self.char_pos = 0 + self.line = 1 + self.column = 1 + self.line_start_pos = 0 + + def __eq__(self, other): + if not isinstance(other, LineCounter): + return NotImplemented + + return self.char_pos == other.char_pos and self.newline_char == other.newline_char + + def feed(self, token: TextOrSlice, test_newline=True): + #-- + if test_newline: + newlines = token.count(self.newline_char) + if newlines: + self.line += newlines + self.line_start_pos = self.char_pos + token.rindex(self.newline_char) + 1 + + self.char_pos += len(token) + self.column = self.char_pos - self.line_start_pos + 1 + + +class UnlessCallback: + def __init__(self, scanner: 'Scanner'): + self.scanner = scanner + + def __call__(self, t: Token): + res = self.scanner.fullmatch(t.value) + if res is not None: + t.type = res + return t + + +class CallChain: + def __init__(self, callback1, callback2, cond): + self.callback1 = callback1 + self.callback2 = callback2 + self.cond = cond + + def __call__(self, t): + t2 = self.callback1(t) + return self.callback2(t) if self.cond(t2) else t2 + + +def _get_match(re_, regexp, s, flags): + m = re_.match(regexp, s, flags) + if m: + return m.group(0) + +def _create_unless(terminals, g_regex_flags, re_, use_bytes): + tokens_by_type = classify(terminals, lambda t: type(t.pattern)) + assert len(tokens_by_type) <= 2, tokens_by_type.keys() + embedded_strs = set() + callback = {} + for retok in tokens_by_type.get(PatternRE, []): + unless = [] + for strtok in tokens_by_type.get(PatternStr, []): + if strtok.priority != retok.priority: + continue + s = 
strtok.pattern.value + if s == _get_match(re_, retok.pattern.to_regexp(), s, g_regex_flags): + unless.append(strtok) + if strtok.pattern.flags <= retok.pattern.flags: + embedded_strs.add(strtok) + if unless: + callback[retok.name] = UnlessCallback(Scanner(unless, g_regex_flags, re_, use_bytes=use_bytes)) + + new_terminals = [t for t in terminals if t not in embedded_strs] + return new_terminals, callback + + +class Scanner: + def __init__(self, terminals, g_regex_flags, re_, use_bytes): + self.terminals = terminals + self.g_regex_flags = g_regex_flags + self.re_ = re_ + self.use_bytes = use_bytes + + self.allowed_types = {t.name for t in self.terminals} + + self._mres = self._build_mres(terminals, len(terminals)) + + def _build_mres(self, terminals, max_size): + ## + + ## + + ## + + mres = [] + while terminals: + pattern = u'|'.join(u'(?P<%s>%s)' % (t.name, t.pattern.to_regexp()) for t in terminals[:max_size]) + if self.use_bytes: + pattern = pattern.encode('latin-1') + try: + mre = self.re_.compile(pattern, self.g_regex_flags) + except AssertionError: ## + + return self._build_mres(terminals, max_size // 2) + + mres.append(mre) + terminals = terminals[max_size:] + return mres + + def match(self, text: TextSlice, pos): + for mre in self._mres: + m = mre.match(text.text, pos, text.end) + if m: + return m.group(0), m.lastgroup + + + def fullmatch(self, text: str) -> Optional[str]: + for mre in self._mres: + m = mre.fullmatch(text) + if m: + return m.lastgroup + return None + +def _regexp_has_newline(r: str): + #-- + return '\n' in r or '\\n' in r or '\\s' in r or '[^' in r or ('(?s' in r and '.' 
in r) + + +class LexerState: + #-- + + __slots__ = 'text', 'line_ctr', 'last_token' + + text: TextSlice + line_ctr: LineCounter + last_token: Optional[Token] + + def __init__(self, text: TextSlice, line_ctr: Optional[LineCounter] = None, last_token: Optional[Token]=None): + if isinstance(text, TextSlice): + if line_ctr is None: + line_ctr = LineCounter(b'\n' if isinstance(text.text, bytes) else '\n') + + if text.start > 0: + ## + + line_ctr.feed(TextSlice(text.text, 0, text.start)) + + if not (text.start <= line_ctr.char_pos <= text.end): + raise ValueError("LineCounter.char_pos is out of bounds") + + self.text = text + self.line_ctr = line_ctr + self.last_token = last_token + + + def __eq__(self, other): + if not isinstance(other, LexerState): + return NotImplemented + + return self.text == other.text and self.line_ctr == other.line_ctr and self.last_token == other.last_token + + def __copy__(self): + return type(self)(self.text, copy(self.line_ctr), self.last_token) + + +class LexerThread: + #-- + + def __init__(self, lexer: 'Lexer', lexer_state: Optional[LexerState]): + self.lexer = lexer + self.state = lexer_state + + @classmethod + def from_text(cls, lexer: 'Lexer', text_or_slice: TextOrSlice) -> 'LexerThread': + text = TextSlice.cast_from(text_or_slice) + return cls(lexer, LexerState(text)) + + @classmethod + def from_custom_input(cls, lexer: 'Lexer', text: Any) -> 'LexerThread': + return cls(lexer, LexerState(text)) + + def lex(self, parser_state): + if self.state is None: + raise TypeError("Cannot lex: No text assigned to lexer state") + return self.lexer.lex(self.state, parser_state) + + def __copy__(self): + return type(self)(self.lexer, copy(self.state)) + + _Token = Token + + +_Callback = Callable[[Token], Token] + +class Lexer(ABC): + #-- + @abstractmethod + def lex(self, lexer_state: LexerState, parser_state: Any) -> Iterator[Token]: + return NotImplemented + + def make_lexer_state(self, text: str): + #-- + return LexerState(TextSlice.cast_from(text)) 
+ + +def _check_regex_collisions(terminal_to_regexp: Dict[TerminalDef, str], comparator, strict_mode, max_collisions_to_show=8): + if not comparator: + comparator = interegular.Comparator.from_regexes(terminal_to_regexp) + + ## + + ## + + max_time = 2 if strict_mode else 0.2 + + ## + + if comparator.count_marked_pairs() >= max_collisions_to_show: + return + for group in classify(terminal_to_regexp, lambda t: t.priority).values(): + for a, b in comparator.check(group, skip_marked=True): + assert a.priority == b.priority + ## + + comparator.mark(a, b) + + ## + + message = f"Collision between Terminals {a.name} and {b.name}. " + try: + example = comparator.get_example_overlap(a, b, max_time).format_multiline() + except ValueError: + ## + + example = "No example could be found fast enough. However, the collision does still exists" + if strict_mode: + raise LexError(f"{message}\n{example}") + logger.warning("%s The lexer will choose between them arbitrarily.\n%s", message, example) + if comparator.count_marked_pairs() >= max_collisions_to_show: + logger.warning("Found 8 regex collisions, will not check for more.") + return + + +class AbstractBasicLexer(Lexer): + terminals_by_name: Dict[str, TerminalDef] + + @abstractmethod + def __init__(self, conf: 'LexerConf', comparator=None) -> None: + ... + + @abstractmethod + def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token: + ... 
+ + def lex(self, state: LexerState, parser_state: Any) -> Iterator[Token]: + with suppress(EOFError): + while True: + yield self.next_token(state, parser_state) + + +class BasicLexer(AbstractBasicLexer): + terminals: Collection[TerminalDef] + ignore_types: FrozenSet[str] + newline_types: FrozenSet[str] + user_callbacks: Dict[str, _Callback] + callback: Dict[str, _Callback] + re: ModuleType + + def __init__(self, conf: 'LexerConf', comparator=None) -> None: + terminals = list(conf.terminals) + assert all(isinstance(t, TerminalDef) for t in terminals), terminals + + self.re = conf.re_module + + if not conf.skip_validation: + ## + + terminal_to_regexp = {} + for t in terminals: + regexp = t.pattern.to_regexp() + try: + self.re.compile(regexp, conf.g_regex_flags) + except self.re.error: + raise LexError("Cannot compile token %s: %s" % (t.name, t.pattern)) + + if t.pattern.min_width == 0: + raise LexError("Lexer does not allow zero-width terminals. (%s: %s)" % (t.name, t.pattern)) + if t.pattern.type == "re": + terminal_to_regexp[t] = regexp + + if not (set(conf.ignore) <= {t.name for t in terminals}): + raise LexError("Ignore terminals are not defined: %s" % (set(conf.ignore) - {t.name for t in terminals})) + + if has_interegular: + _check_regex_collisions(terminal_to_regexp, comparator, conf.strict) + elif conf.strict: + raise LexError("interegular must be installed for strict mode. 
Use `pip install 'lark[interegular]'`.") + + ## + + self.newline_types = frozenset(t.name for t in terminals if _regexp_has_newline(t.pattern.to_regexp())) + self.ignore_types = frozenset(conf.ignore) + + terminals.sort(key=lambda x: (-x.priority, -x.pattern.max_width, -len(x.pattern.value), x.name)) + self.terminals = terminals + self.user_callbacks = conf.callbacks + self.g_regex_flags = conf.g_regex_flags + self.use_bytes = conf.use_bytes + self.terminals_by_name = conf.terminals_by_name + + self._scanner: Optional[Scanner] = None + + def _build_scanner(self) -> Scanner: + terminals, self.callback = _create_unless(self.terminals, self.g_regex_flags, self.re, self.use_bytes) + assert all(self.callback.values()) + + for type_, f in self.user_callbacks.items(): + if type_ in self.callback: + ## + + self.callback[type_] = CallChain(self.callback[type_], f, lambda t: t.type == type_) + else: + self.callback[type_] = f + + return Scanner(terminals, self.g_regex_flags, self.re, self.use_bytes) + + @property + def scanner(self) -> Scanner: + if self._scanner is None: + self._scanner = self._build_scanner() + return self._scanner + + def match(self, text, pos): + return self.scanner.match(text, pos) + + def next_token(self, lex_state: LexerState, parser_state: Any = None) -> Token: + line_ctr = lex_state.line_ctr + while line_ctr.char_pos < lex_state.text.end: + res = self.match(lex_state.text, line_ctr.char_pos) + if not res: + allowed = self.scanner.allowed_types - self.ignore_types + if not allowed: + allowed = {""} + raise UnexpectedCharacters(lex_state.text.text, line_ctr.char_pos, line_ctr.line, line_ctr.column, + allowed=allowed, token_history=lex_state.last_token and [lex_state.last_token], + state=parser_state, terminals_by_name=self.terminals_by_name) + + value, type_ = res + + ignored = type_ in self.ignore_types + t = None + if not ignored or type_ in self.callback: + t = Token(type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column) + 
line_ctr.feed(value, type_ in self.newline_types) + if t is not None: + t.end_line = line_ctr.line + t.end_column = line_ctr.column + t.end_pos = line_ctr.char_pos + if t.type in self.callback: + t = self.callback[t.type](t) + if not ignored: + if not isinstance(t, Token): + raise LexError("Callbacks must return a token (returned %r)" % t) + lex_state.last_token = t + return t + + ## + + raise EOFError(self) + + +class ContextualLexer(Lexer): + lexers: Dict[int, AbstractBasicLexer] + root_lexer: AbstractBasicLexer + + BasicLexer: Type[AbstractBasicLexer] = BasicLexer + + def __init__(self, conf: 'LexerConf', states: Dict[int, Collection[str]], always_accept: Collection[str]=()) -> None: + terminals = list(conf.terminals) + terminals_by_name = conf.terminals_by_name + + trad_conf = copy(conf) + trad_conf.terminals = terminals + + if has_interegular and not conf.skip_validation: + comparator = interegular.Comparator.from_regexes({t: t.pattern.to_regexp() for t in terminals}) + else: + comparator = None + lexer_by_tokens: Dict[FrozenSet[str], AbstractBasicLexer] = {} + self.lexers = {} + for state, accepts in states.items(): + key = frozenset(accepts) + try: + lexer = lexer_by_tokens[key] + except KeyError: + accepts = set(accepts) | set(conf.ignore) | set(always_accept) + lexer_conf = copy(trad_conf) + lexer_conf.terminals = [terminals_by_name[n] for n in accepts if n in terminals_by_name] + lexer = self.BasicLexer(lexer_conf, comparator) + lexer_by_tokens[key] = lexer + + self.lexers[state] = lexer + + assert trad_conf.terminals is terminals + trad_conf.skip_validation = True ## + + self.root_lexer = self.BasicLexer(trad_conf, comparator) + + def lex(self, lexer_state: LexerState, parser_state: 'ParserState') -> Iterator[Token]: + try: + while True: + lexer = self.lexers[parser_state.position] + yield lexer.next_token(lexer_state, parser_state) + except EOFError: + pass + except UnexpectedCharacters as e: + ## + + ## + + try: + last_token = lexer_state.last_token ## 
+ + token = self.root_lexer.next_token(lexer_state, parser_state) + raise UnexpectedToken(token, e.allowed, state=parser_state, token_history=[last_token], terminals_by_name=self.root_lexer.terminals_by_name) + except UnexpectedCharacters: + raise e ## + + + + +_ParserArgType: 'TypeAlias' = 'Literal["earley", "lalr", "cyk", "auto"]' +_LexerArgType: 'TypeAlias' = 'Union[Literal["auto", "basic", "contextual", "dynamic", "dynamic_complete"], Type[Lexer]]' +_LexerCallback = Callable[[Token], Token] +ParserCallbacks = Dict[str, Callable] + +class LexerConf(Serialize): + __serialize_fields__ = 'terminals', 'ignore', 'g_regex_flags', 'use_bytes', 'lexer_type' + __serialize_namespace__ = TerminalDef, + + terminals: Collection[TerminalDef] + re_module: ModuleType + ignore: Collection[str] + postlex: 'Optional[PostLex]' + callbacks: Dict[str, _LexerCallback] + g_regex_flags: int + skip_validation: bool + use_bytes: bool + lexer_type: Optional[_LexerArgType] + strict: bool + + def __init__(self, terminals: Collection[TerminalDef], re_module: ModuleType, ignore: Collection[str]=(), postlex: 'Optional[PostLex]'=None, + callbacks: Optional[Dict[str, _LexerCallback]]=None, g_regex_flags: int=0, skip_validation: bool=False, use_bytes: bool=False, strict: bool=False): + self.terminals = terminals + self.terminals_by_name = {t.name: t for t in self.terminals} + assert len(self.terminals) == len(self.terminals_by_name) + self.ignore = ignore + self.postlex = postlex + self.callbacks = callbacks or {} + self.g_regex_flags = g_regex_flags + self.re_module = re_module + self.skip_validation = skip_validation + self.use_bytes = use_bytes + self.strict = strict + self.lexer_type = None + + def _deserialize(self): + self.terminals_by_name = {t.name: t for t in self.terminals} + + def __deepcopy__(self, memo=None): + return type(self)( + deepcopy(self.terminals, memo), + self.re_module, + deepcopy(self.ignore, memo), + deepcopy(self.postlex, memo), + deepcopy(self.callbacks, memo), + 
deepcopy(self.g_regex_flags, memo), + deepcopy(self.skip_validation, memo), + deepcopy(self.use_bytes, memo), + ) + +class ParserConf(Serialize): + __serialize_fields__ = 'rules', 'start', 'parser_type' + + rules: List['Rule'] + callbacks: ParserCallbacks + start: List[str] + parser_type: _ParserArgType + + def __init__(self, rules: List['Rule'], callbacks: ParserCallbacks, start: List[str]): + assert isinstance(start, list) + self.rules = rules + self.callbacks = callbacks + self.start = start + + +from functools import partial, wraps +from itertools import product + + +class ExpandSingleChild: + def __init__(self, node_builder): + self.node_builder = node_builder + + def __call__(self, children): + if len(children) == 1: + return children[0] + else: + return self.node_builder(children) + + + +class PropagatePositions: + def __init__(self, node_builder, node_filter=None): + self.node_builder = node_builder + self.node_filter = node_filter + + def __call__(self, children): + res = self.node_builder(children) + + if isinstance(res, Tree): + ## + + ## + + ## + + ## + + + res_meta = res.meta + + first_meta = self._pp_get_meta(children) + if first_meta is not None: + if not hasattr(res_meta, 'line'): + ## + + res_meta.line = getattr(first_meta, 'container_line', first_meta.line) + res_meta.column = getattr(first_meta, 'container_column', first_meta.column) + res_meta.start_pos = getattr(first_meta, 'container_start_pos', first_meta.start_pos) + res_meta.empty = False + + res_meta.container_line = getattr(first_meta, 'container_line', first_meta.line) + res_meta.container_column = getattr(first_meta, 'container_column', first_meta.column) + res_meta.container_start_pos = getattr(first_meta, 'container_start_pos', first_meta.start_pos) + + last_meta = self._pp_get_meta(reversed(children)) + if last_meta is not None: + if not hasattr(res_meta, 'end_line'): + res_meta.end_line = getattr(last_meta, 'container_end_line', last_meta.end_line) + res_meta.end_column = 
getattr(last_meta, 'container_end_column', last_meta.end_column) + res_meta.end_pos = getattr(last_meta, 'container_end_pos', last_meta.end_pos) + res_meta.empty = False + + res_meta.container_end_line = getattr(last_meta, 'container_end_line', last_meta.end_line) + res_meta.container_end_column = getattr(last_meta, 'container_end_column', last_meta.end_column) + res_meta.container_end_pos = getattr(last_meta, 'container_end_pos', last_meta.end_pos) + + return res + + def _pp_get_meta(self, children): + for c in children: + if self.node_filter is not None and not self.node_filter(c): + continue + if isinstance(c, Tree): + if not c.meta.empty: + return c.meta + elif isinstance(c, Token): + return c + elif hasattr(c, '__lark_meta__'): + return c.__lark_meta__() + +def make_propagate_positions(option): + if callable(option): + return partial(PropagatePositions, node_filter=option) + elif option is True: + return PropagatePositions + elif option is False: + return None + + raise ConfigurationError('Invalid option for propagate_positions: %r' % option) + + +class ChildFilter: + def __init__(self, to_include, append_none, node_builder): + self.node_builder = node_builder + self.to_include = to_include + self.append_none = append_none + + def __call__(self, children): + filtered = [] + + for i, to_expand, add_none in self.to_include: + if add_none: + filtered += [None] * add_none + if to_expand: + filtered += children[i].children + else: + filtered.append(children[i]) + + if self.append_none: + filtered += [None] * self.append_none + + return self.node_builder(filtered) + + +class ChildFilterLALR(ChildFilter): + #-- + + def __call__(self, children): + filtered = [] + for i, to_expand, add_none in self.to_include: + if add_none: + filtered += [None] * add_none + if to_expand: + if filtered: + filtered += children[i].children + else: ## + + filtered = children[i].children + else: + filtered.append(children[i]) + + if self.append_none: + filtered += [None] * self.append_none 
+ + return self.node_builder(filtered) + + +class ChildFilterLALR_NoPlaceholders(ChildFilter): + #-- + def __init__(self, to_include, node_builder): + self.node_builder = node_builder + self.to_include = to_include + + def __call__(self, children): + filtered = [] + for i, to_expand in self.to_include: + if to_expand: + if filtered: + filtered += children[i].children + else: ## + + filtered = children[i].children + else: + filtered.append(children[i]) + return self.node_builder(filtered) + + +def _should_expand(sym): + return not sym.is_term and sym.name.startswith('_') + + +def maybe_create_child_filter(expansion, keep_all_tokens, ambiguous, _empty_indices: List[bool]): + ## + + if _empty_indices: + assert _empty_indices.count(False) == len(expansion) + s = ''.join(str(int(b)) for b in _empty_indices) + empty_indices = [len(ones) for ones in s.split('0')] + assert len(empty_indices) == len(expansion)+1, (empty_indices, len(expansion)) + else: + empty_indices = [0] * (len(expansion)+1) + + to_include = [] + nones_to_add = 0 + for i, sym in enumerate(expansion): + nones_to_add += empty_indices[i] + if keep_all_tokens or not (sym.is_term and sym.filter_out): + to_include.append((i, _should_expand(sym), nones_to_add)) + nones_to_add = 0 + + nones_to_add += empty_indices[len(expansion)] + + if _empty_indices or len(to_include) < len(expansion) or any(to_expand for i, to_expand,_ in to_include): + if _empty_indices or ambiguous: + return partial(ChildFilter if ambiguous else ChildFilterLALR, to_include, nones_to_add) + else: + ## + + return partial(ChildFilterLALR_NoPlaceholders, [(i, x) for i,x,_ in to_include]) + + +class AmbiguousExpander: + #-- + def __init__(self, to_expand, tree_class, node_builder): + self.node_builder = node_builder + self.tree_class = tree_class + self.to_expand = to_expand + + def __call__(self, children): + def _is_ambig_tree(t): + return hasattr(t, 'data') and t.data == '_ambig' + + ## + + ## + + ## + + ## + + ambiguous = [] + for i, child 
in enumerate(children): + if _is_ambig_tree(child): + if i in self.to_expand: + ambiguous.append(i) + + child.expand_kids_by_data('_ambig') + + if not ambiguous: + return self.node_builder(children) + + expand = [child.children if i in ambiguous else (child,) for i, child in enumerate(children)] + return self.tree_class('_ambig', [self.node_builder(list(f)) for f in product(*expand)]) + + +def maybe_create_ambiguous_expander(tree_class, expansion, keep_all_tokens): + to_expand = [i for i, sym in enumerate(expansion) + if keep_all_tokens or ((not (sym.is_term and sym.filter_out)) and _should_expand(sym))] + if to_expand: + return partial(AmbiguousExpander, to_expand, tree_class) + + +class AmbiguousIntermediateExpander: + #-- + + def __init__(self, tree_class, node_builder): + self.node_builder = node_builder + self.tree_class = tree_class + + def __call__(self, children): + def _is_iambig_tree(child): + return hasattr(child, 'data') and child.data == '_iambig' + + def _collapse_iambig(children): + #-- + + ## + + ## + + if children and _is_iambig_tree(children[0]): + iambig_node = children[0] + result = [] + for grandchild in iambig_node.children: + collapsed = _collapse_iambig(grandchild.children) + if collapsed: + for child in collapsed: + child.children += children[1:] + result += collapsed + else: + new_tree = self.tree_class('_inter', grandchild.children + children[1:]) + result.append(new_tree) + return result + + collapsed = _collapse_iambig(children) + if collapsed: + processed_nodes = [self.node_builder(c.children) for c in collapsed] + return self.tree_class('_ambig', processed_nodes) + + return self.node_builder(children) + + + +def inplace_transformer(func): + @wraps(func) + def f(children): + ## + + tree = Tree(func.__name__, children) + return func(tree) + return f + + +def apply_visit_wrapper(func, name, wrapper): + if wrapper is _vargs_meta or wrapper is _vargs_meta_inline: + raise NotImplementedError("Meta args not supported for internal 
transformer; use YourTransformer().transform(parser.parse()) instead") + + @wraps(func) + def f(children): + return wrapper(func, name, children, None) + return f + + +class ParseTreeBuilder: + def __init__(self, rules, tree_class, propagate_positions=False, ambiguous=False, maybe_placeholders=False): + self.tree_class = tree_class + self.propagate_positions = propagate_positions + self.ambiguous = ambiguous + self.maybe_placeholders = maybe_placeholders + + self.rule_builders = list(self._init_builders(rules)) + + def _init_builders(self, rules): + propagate_positions = make_propagate_positions(self.propagate_positions) + + for rule in rules: + options = rule.options + keep_all_tokens = options.keep_all_tokens + expand_single_child = options.expand1 + + wrapper_chain = list(filter(None, [ + (expand_single_child and not rule.alias) and ExpandSingleChild, + maybe_create_child_filter(rule.expansion, keep_all_tokens, self.ambiguous, options.empty_indices if self.maybe_placeholders else None), + propagate_positions, + self.ambiguous and maybe_create_ambiguous_expander(self.tree_class, rule.expansion, keep_all_tokens), + self.ambiguous and partial(AmbiguousIntermediateExpander, self.tree_class) + ])) + + yield rule, wrapper_chain + + def create_callback(self, transformer=None): + callbacks = {} + + default_handler = getattr(transformer, '__default__', None) + if default_handler: + def default_callback(data, children): + return default_handler(data, children, None) + else: + default_callback = self.tree_class + + for rule, wrapper_chain in self.rule_builders: + + user_callback_name = rule.alias or rule.options.template_source or rule.origin.name + try: + f = getattr(transformer, user_callback_name) + wrapper = getattr(f, 'visit_wrapper', None) + if wrapper is not None: + f = apply_visit_wrapper(f, user_callback_name, wrapper) + elif isinstance(transformer, Transformer_InPlace): + f = inplace_transformer(f) + except AttributeError: + f = partial(default_callback, 
user_callback_name) + + for w in wrapper_chain: + f = w(f) + + if rule in callbacks: + raise GrammarError("Rule '%s' already exists" % (rule,)) + + callbacks[rule] = f + + return callbacks + + + +class Action: + def __init__(self, name): + self.name = name + def __str__(self): + return self.name + def __repr__(self): + return str(self) + +Shift = Action('Shift') +Reduce = Action('Reduce') + +StateT = TypeVar("StateT") + +class ParseTableBase(Generic[StateT]): + states: Dict[StateT, Dict[str, Tuple]] + start_states: Dict[str, StateT] + end_states: Dict[str, StateT] + + def __init__(self, states, start_states, end_states): + self.states = states + self.start_states = start_states + self.end_states = end_states + + def serialize(self, memo): + tokens = Enumerator() + + states = { + state: {tokens.get(token): ((1, arg.serialize(memo)) if action is Reduce else (0, arg)) + for token, (action, arg) in actions.items()} + for state, actions in self.states.items() + } + + return { + 'tokens': tokens.reversed(), + 'states': states, + 'start_states': self.start_states, + 'end_states': self.end_states, + } + + @classmethod + def deserialize(cls, data, memo): + tokens = data['tokens'] + states = { + state: {tokens[token]: ((Reduce, Rule.deserialize(arg, memo)) if action==1 else (Shift, arg)) + for token, (action, arg) in actions.items()} + for state, actions in data['states'].items() + } + return cls(states, data['start_states'], data['end_states']) + +class ParseTable(ParseTableBase['State']): + #-- + pass + + +class IntParseTable(ParseTableBase[int]): + #-- + + @classmethod + def from_ParseTable(cls, parse_table: ParseTable): + enum = list(parse_table.states) + state_to_idx: Dict['State', int] = {s:i for i,s in enumerate(enum)} + int_states = {} + + for s, la in parse_table.states.items(): + la = {k:(v[0], state_to_idx[v[1]]) if v[0] is Shift else v + for k,v in la.items()} + int_states[ state_to_idx[s] ] = la + + + start_states = {start:state_to_idx[s] for start, s in 
parse_table.start_states.items()} + end_states = {start:state_to_idx[s] for start, s in parse_table.end_states.items()} + return cls(int_states, start_states, end_states) + + + +class ParseConf(Generic[StateT]): + __slots__ = 'parse_table', 'callbacks', 'start', 'start_state', 'end_state', 'states' + + parse_table: ParseTableBase[StateT] + callbacks: ParserCallbacks + start: str + + start_state: StateT + end_state: StateT + states: Dict[StateT, Dict[str, tuple]] + + def __init__(self, parse_table: ParseTableBase[StateT], callbacks: ParserCallbacks, start: str): + self.parse_table = parse_table + + self.start_state = self.parse_table.start_states[start] + self.end_state = self.parse_table.end_states[start] + self.states = self.parse_table.states + + self.callbacks = callbacks + self.start = start + +class ParserState(Generic[StateT]): + __slots__ = 'parse_conf', 'lexer', 'state_stack', 'value_stack' + + parse_conf: ParseConf[StateT] + lexer: LexerThread + state_stack: List[StateT] + value_stack: list + + def __init__(self, parse_conf: ParseConf[StateT], lexer: LexerThread, state_stack=None, value_stack=None): + self.parse_conf = parse_conf + self.lexer = lexer + self.state_stack = state_stack or [self.parse_conf.start_state] + self.value_stack = value_stack or [] + + @property + def position(self) -> StateT: + return self.state_stack[-1] + + ## + + def __eq__(self, other) -> bool: + if not isinstance(other, ParserState): + return NotImplemented + return len(self.state_stack) == len(other.state_stack) and self.position == other.position + + def __copy__(self): + return self.copy() + + def copy(self, deepcopy_values=True) -> 'ParserState[StateT]': + return type(self)( + self.parse_conf, + self.lexer, ## + + copy(self.state_stack), + deepcopy(self.value_stack) if deepcopy_values else copy(self.value_stack), + ) + + def feed_token(self, token: Token, is_end=False) -> Any: + state_stack = self.state_stack + value_stack = self.value_stack + states = self.parse_conf.states 
+ end_state = self.parse_conf.end_state + callbacks = self.parse_conf.callbacks + + while True: + state = state_stack[-1] + try: + action, arg = states[state][token.type] + except KeyError: + expected = {s for s in states[state].keys() if s.isupper()} + raise UnexpectedToken(token, expected, state=self, interactive_parser=None) + + assert arg != end_state + + if action is Shift: + ## + + assert not is_end + state_stack.append(arg) + value_stack.append(token if token.type not in callbacks else callbacks[token.type](token)) + return + else: + ## + + rule = arg + size = len(rule.expansion) + if size: + s = value_stack[-size:] + del state_stack[-size:] + del value_stack[-size:] + else: + s = [] + + value = callbacks[rule](s) if callbacks else s + + _action, new_state = states[state_stack[-1]][rule.origin.name] + assert _action is Shift + state_stack.append(new_state) + value_stack.append(value) + + if is_end and state_stack[-1] == end_state: + return value_stack[-1] + + +class LALR_Parser(Serialize): + def __init__(self, parser_conf: ParserConf, debug: bool=False, strict: bool=False): + analysis = LALR_Analyzer(parser_conf, debug=debug, strict=strict) + analysis.compute_lalr() + callbacks = parser_conf.callbacks + + self._parse_table = analysis.parse_table + self.parser_conf = parser_conf + self.parser = _Parser(analysis.parse_table, callbacks, debug) + + @classmethod + def deserialize(cls, data, memo, callbacks, debug=False): + inst = cls.__new__(cls) + inst._parse_table = IntParseTable.deserialize(data, memo) + inst.parser = _Parser(inst._parse_table, callbacks, debug) + return inst + + def serialize(self, memo: Any = None) -> Dict[str, Any]: + return self._parse_table.serialize(memo) + + def parse_interactive(self, lexer: LexerThread, start: str): + return self.parser.parse(lexer, start, start_interactive=True) + + def parse(self, lexer, start, on_error=None): + try: + return self.parser.parse(lexer, start) + except UnexpectedInput as e: + if on_error is None: + 
raise + + while True: + if isinstance(e, UnexpectedCharacters): + s = e.interactive_parser.lexer_thread.state + p = s.line_ctr.char_pos + + if not on_error(e): + raise e + + if isinstance(e, UnexpectedCharacters): + ## + + if p == s.line_ctr.char_pos: + s.line_ctr.feed(s.text.text[p:p+1]) + + try: + return e.interactive_parser.resume_parse() + except UnexpectedToken as e2: + if (isinstance(e, UnexpectedToken) + and e.token.type == e2.token.type == '$END' + and e.interactive_parser == e2.interactive_parser): + ## + + raise e2 + e = e2 + except UnexpectedCharacters as e2: + e = e2 + + +class _Parser: + parse_table: ParseTableBase + callbacks: ParserCallbacks + debug: bool + + def __init__(self, parse_table: ParseTableBase, callbacks: ParserCallbacks, debug: bool=False): + self.parse_table = parse_table + self.callbacks = callbacks + self.debug = debug + + def parse(self, lexer: LexerThread, start: str, value_stack=None, state_stack=None, start_interactive=False): + parse_conf = ParseConf(self.parse_table, self.callbacks, start) + parser_state = ParserState(parse_conf, lexer, state_stack, value_stack) + if start_interactive: + return InteractiveParser(self, parser_state, parser_state.lexer) + return self.parse_from_state(parser_state) + + + def parse_from_state(self, state: ParserState, last_token: Optional[Token]=None): + #-- + try: + token = last_token + for token in state.lexer.lex(state): + assert token is not None + state.feed_token(token) + + end_token = Token.new_borrow_pos('$END', '', token) if token else Token('$END', '', 0, 1, 1) + return state.feed_token(end_token, True) + except UnexpectedInput as e: + try: + e.interactive_parser = InteractiveParser(self, state, state.lexer) + except NameError: + pass + raise e + except Exception as e: + if self.debug: + print("") + print("STATE STACK DUMP") + print("----------------") + for i, s in enumerate(state.state_stack): + print('%d)' % i , s) + print("") + + raise + + +class InteractiveParser: + #-- + def 
__init__(self, parser, parser_state: ParserState, lexer_thread: LexerThread): + self.parser = parser + self.parser_state = parser_state + self.lexer_thread = lexer_thread + self.result = None + + @property + def lexer_state(self) -> LexerThread: + warnings.warn("lexer_state will be removed in subsequent releases. Use lexer_thread instead.", DeprecationWarning) + return self.lexer_thread + + def feed_token(self, token: Token): + #-- + return self.parser_state.feed_token(token, token.type == '$END') + + def iter_parse(self) -> Iterator[Token]: + #-- + for token in self.lexer_thread.lex(self.parser_state): + yield token + self.result = self.feed_token(token) + + def exhaust_lexer(self) -> List[Token]: + #-- + return list(self.iter_parse()) + + + def feed_eof(self, last_token=None): + #-- + eof = Token.new_borrow_pos('$END', '', last_token) if last_token is not None else self.lexer_thread._Token('$END', '', 0, 1, 1) + return self.feed_token(eof) + + + def __copy__(self): + #-- + return self.copy() + + def copy(self, deepcopy_values=True): + return type(self)( + self.parser, + self.parser_state.copy(deepcopy_values=deepcopy_values), + copy(self.lexer_thread), + ) + + def __eq__(self, other): + if not isinstance(other, InteractiveParser): + return False + + return self.parser_state == other.parser_state and self.lexer_thread == other.lexer_thread + + def as_immutable(self): + #-- + p = copy(self) + return ImmutableInteractiveParser(p.parser, p.parser_state, p.lexer_thread) + + def pretty(self): + #-- + out = ["Parser choices:"] + for k, v in self.choices().items(): + out.append('\t- %s -> %r' % (k, v)) + out.append('stack size: %s' % len(self.parser_state.state_stack)) + return '\n'.join(out) + + def choices(self): + #-- + return self.parser_state.parse_conf.parse_table.states[self.parser_state.position] + + def accepts(self): + #-- + accepts = set() + conf_no_callbacks = copy(self.parser_state.parse_conf) + ## + + ## + + conf_no_callbacks.callbacks = {} + for t in 
self.choices(): + if t.isupper(): ## + + new_cursor = self.copy(deepcopy_values=False) + new_cursor.parser_state.parse_conf = conf_no_callbacks + try: + new_cursor.feed_token(self.lexer_thread._Token(t, '')) + except UnexpectedToken: + pass + else: + accepts.add(t) + return accepts + + def resume_parse(self): + #-- + return self.parser.parse_from_state(self.parser_state, last_token=self.lexer_thread.state.last_token) + + + +class ImmutableInteractiveParser(InteractiveParser): + #-- + + result = None + + def __hash__(self): + return hash((self.parser_state, self.lexer_thread)) + + def feed_token(self, token): + c = copy(self) + c.result = InteractiveParser.feed_token(c, token) + return c + + def exhaust_lexer(self): + #-- + cursor = self.as_mutable() + cursor.exhaust_lexer() + return cursor.as_immutable() + + def as_mutable(self): + #-- + p = copy(self) + return InteractiveParser(p.parser, p.parser_state, p.lexer_thread) + + + +def _wrap_lexer(lexer_class): + future_interface = getattr(lexer_class, '__future_interface__', 0) + if future_interface == 2: + return lexer_class + elif future_interface == 1: + class CustomLexerWrapper1(Lexer): + def __init__(self, lexer_conf): + self.lexer = lexer_class(lexer_conf) + def lex(self, lexer_state, parser_state): + if isinstance(lexer_state.text, TextSlice) and not lexer_state.text.is_complete_text(): + raise TypeError("Interface=1 Custom Lexer don't support TextSlice") + lexer_state.text = lexer_state.text + return self.lexer.lex(lexer_state, parser_state) + return CustomLexerWrapper1 + elif future_interface == 0: + class CustomLexerWrapper0(Lexer): + def __init__(self, lexer_conf): + self.lexer = lexer_class(lexer_conf) + + def lex(self, lexer_state, parser_state): + if isinstance(lexer_state.text, TextSlice): + if not lexer_state.text.is_complete_text(): + raise TypeError("Interface=0 Custom Lexer don't support TextSlice") + return self.lexer.lex(lexer_state.text.text) + return self.lexer.lex(lexer_state.text) + return 
CustomLexerWrapper0 + else: + raise ValueError(f"Unknown __future_interface__ value {future_interface}, integer 0-2 expected") + + +def _deserialize_parsing_frontend(data, memo, lexer_conf, callbacks, options): + parser_conf = ParserConf.deserialize(data['parser_conf'], memo) + cls = (options and options._plugins.get('LALR_Parser')) or LALR_Parser + parser = cls.deserialize(data['parser'], memo, callbacks, options.debug) + parser_conf.callbacks = callbacks + return ParsingFrontend(lexer_conf, parser_conf, options, parser=parser) + + +_parser_creators: 'Dict[str, Callable[[LexerConf, Any, Any], Any]]' = {} + + +class ParsingFrontend(Serialize): + __serialize_fields__ = 'lexer_conf', 'parser_conf', 'parser' + + lexer_conf: LexerConf + parser_conf: ParserConf + options: Any + + def __init__(self, lexer_conf: LexerConf, parser_conf: ParserConf, options, parser=None): + self.parser_conf = parser_conf + self.lexer_conf = lexer_conf + self.options = options + + ## + + if parser: ## + + self.parser = parser + else: + create_parser = _parser_creators.get(parser_conf.parser_type) + assert create_parser is not None, "{} is not supported in standalone mode".format( + parser_conf.parser_type + ) + self.parser = create_parser(lexer_conf, parser_conf, options) + + ## + + lexer_type = lexer_conf.lexer_type + self.skip_lexer = False + if lexer_type in ('dynamic', 'dynamic_complete'): + assert lexer_conf.postlex is None + self.skip_lexer = True + return + + if isinstance(lexer_type, type): + assert issubclass(lexer_type, Lexer) + self.lexer = _wrap_lexer(lexer_type)(lexer_conf) + elif isinstance(lexer_type, str): + create_lexer = { + 'basic': create_basic_lexer, + 'contextual': create_contextual_lexer, + }[lexer_type] + self.lexer = create_lexer(lexer_conf, self.parser, lexer_conf.postlex, options) + else: + raise TypeError("Bad value for lexer_type: {lexer_type}") + + if lexer_conf.postlex: + self.lexer = PostLexConnector(self.lexer, lexer_conf.postlex) + + def _verify_start(self, 
start=None): + if start is None: + start_decls = self.parser_conf.start + if len(start_decls) > 1: + raise ConfigurationError("Lark initialized with more than 1 possible start rule. Must specify which start rule to parse", start_decls) + start ,= start_decls + elif start not in self.parser_conf.start: + raise ConfigurationError("Unknown start rule %s. Must be one of %r" % (start, self.parser_conf.start)) + return start + + def _make_lexer_thread(self, text: Optional[LarkInput]) -> Union[LarkInput, LexerThread, None]: + cls = (self.options and self.options._plugins.get('LexerThread')) or LexerThread + if self.skip_lexer: + return text + if text is None: + return cls(self.lexer, None) + if isinstance(text, (str, bytes, TextSlice)): + return cls.from_text(self.lexer, text) + return cls.from_custom_input(self.lexer, text) + + def parse(self, text: Optional[LarkInput], start=None, on_error=None): + if self.lexer_conf.lexer_type in ("dynamic", "dynamic_complete"): + if isinstance(text, TextSlice) and not text.is_complete_text(): + raise TypeError(f"Lexer {self.lexer_conf.lexer_type} does not support text slices.") + + chosen_start = self._verify_start(start) + kw = {} if on_error is None else {'on_error': on_error} + stream = self._make_lexer_thread(text) + return self.parser.parse(stream, chosen_start, **kw) + + def parse_interactive(self, text: Optional[TextOrSlice]=None, start=None): + ## + + ## + + chosen_start = self._verify_start(start) + if self.parser_conf.parser_type != 'lalr': + raise ConfigurationError("parse_interactive() currently only works with parser='lalr' ") + stream = self._make_lexer_thread(text) + return self.parser.parse_interactive(stream, chosen_start) + + +def _validate_frontend_args(parser, lexer) -> None: + assert_config(parser, ('lalr', 'earley', 'cyk')) + if not isinstance(lexer, type): ## + + expected = { + 'lalr': ('basic', 'contextual'), + 'earley': ('basic', 'dynamic', 'dynamic_complete'), + 'cyk': ('basic', ), + }[parser] + 
assert_config(lexer, expected, 'Parser %r does not support lexer %%r, expected one of %%s' % parser) + + +def _get_lexer_callbacks(transformer, terminals): + result = {} + for terminal in terminals: + callback = getattr(transformer, terminal.name, None) + if callback is not None: + result[terminal.name] = callback + return result + +class PostLexConnector: + def __init__(self, lexer, postlexer): + self.lexer = lexer + self.postlexer = postlexer + + def lex(self, lexer_state, parser_state): + i = self.lexer.lex(lexer_state, parser_state) + return self.postlexer.process(i) + + + +def create_basic_lexer(lexer_conf, parser, postlex, options) -> BasicLexer: + cls = (options and options._plugins.get('BasicLexer')) or BasicLexer + return cls(lexer_conf) + +def create_contextual_lexer(lexer_conf: LexerConf, parser, postlex, options) -> ContextualLexer: + cls = (options and options._plugins.get('ContextualLexer')) or ContextualLexer + parse_table: ParseTableBase[int] = parser._parse_table + states: Dict[int, Collection[str]] = {idx:list(t.keys()) for idx, t in parse_table.states.items()} + always_accept: Collection[str] = postlex.always_accept if postlex else () + return cls(lexer_conf, states, always_accept=always_accept) + +def create_lalr_parser(lexer_conf: LexerConf, parser_conf: ParserConf, options=None) -> LALR_Parser: + debug = options.debug if options else False + strict = options.strict if options else False + cls = (options and options._plugins.get('LALR_Parser')) or LALR_Parser + return cls(parser_conf, debug=debug, strict=strict) + +_parser_creators['lalr'] = create_lalr_parser + + + + +class PostLex(ABC): + @abstractmethod + def process(self, stream: Iterator[Token]) -> Iterator[Token]: + return stream + + always_accept: Iterable[str] = () + +class LarkOptions(Serialize): + #-- + + start: List[str] + debug: bool + strict: bool + transformer: 'Optional[Transformer]' + propagate_positions: Union[bool, str] + maybe_placeholders: bool + cache: Union[bool, str] + 
cache_grammar: bool + regex: bool + g_regex_flags: int + keep_all_tokens: bool + tree_class: Optional[Callable[[str, List], Any]] + parser: _ParserArgType + lexer: _LexerArgType + ambiguity: 'Literal["auto", "resolve", "explicit", "forest"]' + postlex: Optional[PostLex] + priority: 'Optional[Literal["auto", "normal", "invert"]]' + lexer_callbacks: Dict[str, Callable[[Token], Token]] + use_bytes: bool + ordered_sets: bool + edit_terminals: Optional[Callable[[TerminalDef], TerminalDef]] + import_paths: 'List[Union[str, Callable[[Union[None, str, PackageResource], str], Tuple[str, str]]]]' + source_path: Optional[str] + + OPTIONS_DOC = r""" + **=== General Options ===** + + start + The start symbol. Either a string, or a list of strings for multiple possible starts (Default: "start") + debug + Display debug information and extra warnings. Use only when debugging (Default: ``False``) + When used with Earley, it generates a forest graph as "sppf.png", if 'dot' is installed. + strict + Throw an exception on any potential ambiguity, including shift/reduce conflicts, and regex collisions. + transformer + Applies the transformer to every parse tree (equivalent to applying it after the parse, but faster) + propagate_positions + Propagates positional attributes into the 'meta' attribute of all tree branches. + Sets attributes: (line, column, end_line, end_column, start_pos, end_pos, + container_line, container_column, container_end_line, container_end_column) + Accepts ``False``, ``True``, or a callable, which will filter which nodes to ignore when propagating. + maybe_placeholders + When ``True``, the ``[]`` operator returns ``None`` when not matched. + When ``False``, ``[]`` behaves like the ``?`` operator, and returns no value at all. + (default= ``True``) + cache + Cache the results of the Lark grammar analysis, for x2 to x3 faster loading. LALR only for now. 
+ + - When ``False``, does nothing (default) + - When ``True``, caches to a temporary file in the local directory + - When given a string, caches to the path pointed by the string + cache_grammar + For use with ``cache`` option. When ``True``, the unanalyzed grammar is also included in the cache. + Useful for classes that require the ``Lark.grammar`` to be present (e.g. Reconstructor). + (default= ``False``) + regex + When True, uses the ``regex`` module instead of the stdlib ``re``. + g_regex_flags + Flags that are applied to all terminals (both regex and strings) + keep_all_tokens + Prevent the tree builder from automagically removing "punctuation" tokens (Default: ``False``) + tree_class + Lark will produce trees comprised of instances of this class instead of the default ``lark.Tree``. + + **=== Algorithm Options ===** + + parser + Decides which parser engine to use. Accepts "earley" or "lalr". (Default: "earley"). + (there is also a "cyk" option for legacy) + lexer + Decides whether or not to use a lexer stage + + - "auto" (default): Choose for me based on the parser + - "basic": Use a basic lexer + - "contextual": Stronger lexer (only works with parser="lalr") + - "dynamic": Flexible and powerful (only with parser="earley") + - "dynamic_complete": Same as dynamic, but tries *every* variation of tokenizing possible. + ambiguity + Decides how to handle ambiguity in the parse. Only relevant if parser="earley" + + - "resolve": The parser will automatically choose the simplest derivation + (it chooses consistently: greedy for tokens, non-greedy for rules) + - "explicit": The parser will return all derivations wrapped in "_ambig" tree nodes (i.e. a forest). + - "forest": The parser will return the root of the shared packed parse forest. + + **=== Misc. / Domain Specific Options ===** + + postlex + Lexer post-processing (Default: ``None``) Only works with the basic and contextual lexers. 
+ priority + How priorities should be evaluated - "auto", ``None``, "normal", "invert" (Default: "auto") + lexer_callbacks + Dictionary of callbacks for the lexer. May alter tokens during lexing. Use with caution. + use_bytes + Accept an input of type ``bytes`` instead of ``str``. + ordered_sets + Should Earley use ordered-sets to achieve stable output (~10% slower than regular sets. Default: True) + edit_terminals + A callback for editing the terminals before parse. + import_paths + A List of either paths or loader functions to specify from where grammars are imported + source_path + Override the source of from where the grammar was loaded. Useful for relative imports and unconventional grammar loading + **=== End of Options ===** + """ + if __doc__: + __doc__ += OPTIONS_DOC + + + ## + + ## + + ## + + ## + + ## + + ## + + _defaults: Dict[str, Any] = { + 'debug': False, + 'strict': False, + 'keep_all_tokens': False, + 'tree_class': None, + 'cache': False, + 'cache_grammar': False, + 'postlex': None, + 'parser': 'earley', + 'lexer': 'auto', + 'transformer': None, + 'start': 'start', + 'priority': 'auto', + 'ambiguity': 'auto', + 'regex': False, + 'propagate_positions': False, + 'lexer_callbacks': {}, + 'maybe_placeholders': True, + 'edit_terminals': None, + 'g_regex_flags': 0, + 'use_bytes': False, + 'ordered_sets': True, + 'import_paths': [], + 'source_path': None, + '_plugins': {}, + } + + def __init__(self, options_dict: Dict[str, Any]) -> None: + o = dict(options_dict) + + options = {} + for name, default in self._defaults.items(): + if name in o: + value = o.pop(name) + if isinstance(default, bool) and name not in ('cache', 'use_bytes', 'propagate_positions'): + value = bool(value) + else: + value = default + + options[name] = value + + if isinstance(options['start'], str): + options['start'] = [options['start']] + + self.__dict__['options'] = options + + + assert_config(self.parser, ('earley', 'lalr', 'cyk', None)) + + if self.parser == 'earley' and 
self.transformer: + raise ConfigurationError('Cannot specify an embedded transformer when using the Earley algorithm. ' + 'Please use your transformer on the resulting parse tree, or use a different algorithm (i.e. LALR)') + + if self.cache_grammar and not self.cache: + raise ConfigurationError('cache_grammar cannot be set when cache is disabled') + + if o: + raise ConfigurationError("Unknown options: %s" % o.keys()) + + def __getattr__(self, name: str) -> Any: + try: + return self.__dict__['options'][name] + except KeyError as e: + raise AttributeError(e) + + def __setattr__(self, name: str, value: str) -> None: + assert_config(name, self.options.keys(), "%r isn't a valid option. Expected one of: %s") + self.options[name] = value + + def serialize(self, memo = None) -> Dict[str, Any]: + return self.options + + @classmethod + def deserialize(cls, data: Dict[str, Any], memo: Dict[int, Union[TerminalDef, Rule]]) -> "LarkOptions": + return cls(data) + + +## + +## + +_LOAD_ALLOWED_OPTIONS = {'postlex', 'transformer', 'lexer_callbacks', 'use_bytes', 'debug', 'g_regex_flags', 'regex', 'propagate_positions', 'tree_class', '_plugins'} + +_VALID_PRIORITY_OPTIONS = ('auto', 'normal', 'invert', None) +_VALID_AMBIGUITY_OPTIONS = ('auto', 'resolve', 'explicit', 'forest') + + +_T = TypeVar('_T', bound="Lark") + +class Lark(Serialize): + #-- + + source_path: str + source_grammar: str + grammar: 'Grammar' + options: LarkOptions + lexer: Lexer + parser: 'ParsingFrontend' + terminals: Collection[TerminalDef] + + __serialize_fields__ = ['parser', 'rules', 'options'] + + def __init__(self, grammar: 'Union[Grammar, str, IO[str]]', **options) -> None: + self.options = LarkOptions(options) + re_module: types.ModuleType + + ## + + if self.options.cache_grammar: + self.__serialize_fields__ = self.__serialize_fields__ + ['grammar'] + + ## + + use_regex = self.options.regex + if use_regex: + if _has_regex: + re_module = regex + else: + raise ImportError('`regex` module must be installed if 
calling `Lark(regex=True)`.') + else: + re_module = re + + ## + + if self.options.source_path is None: + try: + self.source_path = grammar.name ## + + except AttributeError: + self.source_path = '' + else: + self.source_path = self.options.source_path + + ## + + try: + read = grammar.read ## + + except AttributeError: + pass + else: + grammar = read() + + cache_fn = None + cache_sha256 = None + if isinstance(grammar, str): + self.source_grammar = grammar + if self.options.use_bytes: + if not grammar.isascii(): + raise ConfigurationError("Grammar must be ascii only, when use_bytes=True") + + if self.options.cache: + if self.options.parser != 'lalr': + raise ConfigurationError("cache only works with parser='lalr' for now") + + unhashable = ('transformer', 'postlex', 'lexer_callbacks', 'edit_terminals', '_plugins') + options_str = ''.join(k+str(v) for k, v in options.items() if k not in unhashable) + from . import __version__ + s = grammar + options_str + __version__ + str(sys.version_info[:2]) + cache_sha256 = sha256_digest(s) + + if isinstance(self.options.cache, str): + cache_fn = self.options.cache + else: + if self.options.cache is not True: + raise ConfigurationError("cache argument must be bool or str") + + try: + username = getpass.getuser() + except Exception: + ## + + ## + + ## + + username = "unknown" + + + cache_fn = tempfile.gettempdir() + "/.lark_%s_%s_%s_%s_%s.tmp" % ( + "cache_grammar" if self.options.cache_grammar else "cache", username, cache_sha256, *sys.version_info[:2]) + + old_options = self.options + try: + with FS.open(cache_fn, 'rb') as f: + logger.debug('Loading grammar from cache: %s', cache_fn) + ## + + for name in (set(options) - _LOAD_ALLOWED_OPTIONS): + del options[name] + file_sha256 = f.readline().rstrip(b'\n') + cached_used_files = pickle.load(f) # nosec B301 noqa - We're not using this functionality. 
+ if file_sha256 == cache_sha256.encode('utf8') and verify_used_files(cached_used_files): + cached_parser_data = pickle.load(f) # nosec B301 noqa - We're not using this functionality. + self._load(cached_parser_data, **options) + return + except FileNotFoundError: + ## + + pass + except Exception: ## + + logger.exception("Failed to load Lark from cache: %r. We will try to carry on.", cache_fn) + + ## + + ## + + self.options = old_options + + + ## + + self.grammar, used_files = load_grammar(grammar, self.source_path, self.options.import_paths, self.options.keep_all_tokens) + else: + assert isinstance(grammar, Grammar) + self.grammar = grammar + + + if self.options.lexer == 'auto': + if self.options.parser == 'lalr': + self.options.lexer = 'contextual' + elif self.options.parser == 'earley': + if self.options.postlex is not None: + logger.info("postlex can't be used with the dynamic lexer, so we use 'basic' instead. " + "Consider using lalr with contextual instead of earley") + self.options.lexer = 'basic' + else: + self.options.lexer = 'dynamic' + elif self.options.parser == 'cyk': + self.options.lexer = 'basic' + else: + assert False, self.options.parser + lexer = self.options.lexer + if isinstance(lexer, type): + assert issubclass(lexer, Lexer) ## + + else: + assert_config(lexer, ('basic', 'contextual', 'dynamic', 'dynamic_complete')) + if self.options.postlex is not None and 'dynamic' in lexer: + raise ConfigurationError("Can't use postlex with a dynamic lexer. Use basic or contextual instead") + + if self.options.ambiguity == 'auto': + if self.options.parser == 'earley': + self.options.ambiguity = 'resolve' + else: + assert_config(self.options.parser, ('earley', 'cyk'), "%r doesn't support disambiguation. Use one of these parsers instead: %s") + + if self.options.priority == 'auto': + self.options.priority = 'normal' + + if self.options.priority not in _VALID_PRIORITY_OPTIONS: + raise ConfigurationError("invalid priority option: %r. 
Must be one of %r" % (self.options.priority, _VALID_PRIORITY_OPTIONS)) + if self.options.ambiguity not in _VALID_AMBIGUITY_OPTIONS: + raise ConfigurationError("invalid ambiguity option: %r. Must be one of %r" % (self.options.ambiguity, _VALID_AMBIGUITY_OPTIONS)) + + if self.options.parser is None: + terminals_to_keep = '*' ## + + elif self.options.postlex is not None: + terminals_to_keep = set(self.options.postlex.always_accept) + else: + terminals_to_keep = set() + + ## + + self.terminals, self.rules, self.ignore_tokens = self.grammar.compile(self.options.start, terminals_to_keep) + + if self.options.edit_terminals: + for t in self.terminals: + self.options.edit_terminals(t) + + self._terminals_dict = {t.name: t for t in self.terminals} + + ## + + if self.options.priority == 'invert': + for rule in self.rules: + if rule.options.priority is not None: + rule.options.priority = -rule.options.priority + for term in self.terminals: + term.priority = -term.priority + ## + + ## + + ## + + elif self.options.priority is None: + for rule in self.rules: + if rule.options.priority is not None: + rule.options.priority = None + for term in self.terminals: + term.priority = 0 + + ## + + self.lexer_conf = LexerConf( + self.terminals, re_module, self.ignore_tokens, self.options.postlex, + self.options.lexer_callbacks, self.options.g_regex_flags, use_bytes=self.options.use_bytes, strict=self.options.strict + ) + + if self.options.parser: + self.parser = self._build_parser() + elif lexer: + self.lexer = self._build_lexer() + + if cache_fn: + logger.debug('Saving grammar to cache: %s', cache_fn) + try: + with FS.open(cache_fn, 'wb') as f: + assert cache_sha256 is not None + f.write(cache_sha256.encode('utf8') + b'\n') + pickle.dump(used_files, f) # nosec B301 noqa - We're not using this functionality. 
+ self.save(f, _LOAD_ALLOWED_OPTIONS) + except IOError as e: + logger.exception("Failed to save Lark to cache: %r.", cache_fn, e) + + if __doc__: + __doc__ += "\n\n" + LarkOptions.OPTIONS_DOC + + def _build_lexer(self, dont_ignore: bool=False) -> BasicLexer: + lexer_conf = self.lexer_conf + if dont_ignore: + from copy import copy + lexer_conf = copy(lexer_conf) + lexer_conf.ignore = () + return BasicLexer(lexer_conf) + + def _prepare_callbacks(self) -> None: + self._callbacks = {} + ## + + if self.options.ambiguity != 'forest': + self._parse_tree_builder = ParseTreeBuilder( + self.rules, + self.options.tree_class or Tree, + self.options.propagate_positions, + self.options.parser != 'lalr' and self.options.ambiguity == 'explicit', + self.options.maybe_placeholders + ) + self._callbacks = self._parse_tree_builder.create_callback(self.options.transformer) + self._callbacks.update(_get_lexer_callbacks(self.options.transformer, self.terminals)) + + def _build_parser(self) -> "ParsingFrontend": + self._prepare_callbacks() + _validate_frontend_args(self.options.parser, self.options.lexer) + parser_conf = ParserConf(self.rules, self._callbacks, self.options.start) + return _construct_parsing_frontend( + self.options.parser, + self.options.lexer, + self.lexer_conf, + parser_conf, + options=self.options + ) + + def save(self, f, exclude_options: Collection[str] = ()) -> None: + #-- + if self.options.parser != 'lalr': + raise NotImplementedError("Lark.save() is only implemented for the LALR(1) parser.") + data, m = self.memo_serialize([TerminalDef, Rule]) + if exclude_options: + data["options"] = {n: v for n, v in data["options"].items() if n not in exclude_options} + pickle.dump({'data': data, 'memo': m}, f, protocol=pickle.HIGHEST_PROTOCOL) # nosec B301 noqa - We're not using this functionality. 
+ + @classmethod + def load(cls: Type[_T], f) -> _T: + #-- + inst = cls.__new__(cls) + return inst._load(f) + + def _deserialize_lexer_conf(self, data: Dict[str, Any], memo: Dict[int, Union[TerminalDef, Rule]], options: LarkOptions) -> LexerConf: + lexer_conf = LexerConf.deserialize(data['lexer_conf'], memo) + lexer_conf.callbacks = options.lexer_callbacks or {} + lexer_conf.re_module = regex if options.regex else re + lexer_conf.use_bytes = options.use_bytes + lexer_conf.g_regex_flags = options.g_regex_flags + lexer_conf.skip_validation = True + lexer_conf.postlex = options.postlex + return lexer_conf + + def _load(self: _T, f: Any, **kwargs) -> _T: + if isinstance(f, dict): + d = f + else: + d = pickle.load(f) # nosec B301 noqa - We're not using this functionality. + memo_json = d['memo'] + data = d['data'] + + assert memo_json + memo = SerializeMemoizer.deserialize(memo_json, {'Rule': Rule, 'TerminalDef': TerminalDef}, {}) + if 'grammar' in data: + self.grammar = Grammar.deserialize(data['grammar'], memo) + options = dict(data['options']) + if (set(kwargs) - _LOAD_ALLOWED_OPTIONS) & set(LarkOptions._defaults): + raise ConfigurationError("Some options are not allowed when loading a Parser: {}" + .format(set(kwargs) - _LOAD_ALLOWED_OPTIONS)) + options.update(kwargs) + self.options = LarkOptions.deserialize(options, memo) + self.rules = [Rule.deserialize(r, memo) for r in data['rules']] + self.source_path = '' + _validate_frontend_args(self.options.parser, self.options.lexer) + self.lexer_conf = self._deserialize_lexer_conf(data['parser'], memo, self.options) + self.terminals = self.lexer_conf.terminals + self._prepare_callbacks() + self._terminals_dict = {t.name: t for t in self.terminals} + self.parser = _deserialize_parsing_frontend( + data['parser'], + memo, + self.lexer_conf, + self._callbacks, + self.options, ## + + ) + return self + + @classmethod + def _load_from_dict(cls, data, memo, **kwargs): + inst = cls.__new__(cls) + return inst._load({'data': data, 
'memo': memo}, **kwargs) + + @classmethod + def open(cls: Type[_T], grammar_filename: str, rel_to: Optional[str]=None, **options) -> _T: + #-- + if rel_to: + basepath = os.path.dirname(rel_to) + grammar_filename = os.path.join(basepath, grammar_filename) + with open(grammar_filename, encoding='utf8') as f: + return cls(f, **options) + + @classmethod + def open_from_package(cls: Type[_T], package: str, grammar_path: str, search_paths: 'Sequence[str]'=[""], **options) -> _T: + #-- + package_loader = FromPackageLoader(package, search_paths) + full_path, text = package_loader(None, grammar_path) + options.setdefault('source_path', full_path) + options.setdefault('import_paths', []) + options['import_paths'].append(package_loader) + return cls(text, **options) + + def __repr__(self): + return 'Lark(open(%r), parser=%r, lexer=%r, ...)' % (self.source_path, self.options.parser, self.options.lexer) + + + def lex(self, text: TextOrSlice, dont_ignore: bool=False) -> Iterator[Token]: + #-- + lexer: Lexer + if not hasattr(self, 'lexer') or dont_ignore: + lexer = self._build_lexer(dont_ignore) + else: + lexer = self.lexer + lexer_thread = LexerThread.from_text(lexer, text) + stream = lexer_thread.lex(None) + if self.options.postlex: + return self.options.postlex.process(stream) + return stream + + def get_terminal(self, name: str) -> TerminalDef: + #-- + return self._terminals_dict[name] + + def parse_interactive(self, text: Optional[LarkInput]=None, start: Optional[str]=None) -> 'InteractiveParser': + #-- + return self.parser.parse_interactive(text, start=start) + + def parse(self, text: LarkInput, start: Optional[str]=None, on_error: 'Optional[Callable[[UnexpectedInput], bool]]'=None) -> 'ParseTree': + #-- + if on_error is not None and self.options.parser != 'lalr': + raise NotImplementedError("The on_error option is only implemented for the LALR(1) parser.") + return self.parser.parse(text, start=start, on_error=on_error) + + + + +class DedentError(LarkError): + pass + 
+class Indenter(PostLex, ABC): + #-- + paren_level: int + indent_level: List[int] + + def __init__(self) -> None: + self.paren_level = 0 + self.indent_level = [0] + assert self.tab_len > 0 + + def handle_NL(self, token: Token) -> Iterator[Token]: + if self.paren_level > 0: + return + + yield token + + indent_str = token.rsplit('\n', 1)[1] ## + + indent = indent_str.count(' ') + indent_str.count('\t') * self.tab_len + + if indent > self.indent_level[-1]: + self.indent_level.append(indent) + yield Token.new_borrow_pos(self.INDENT_type, indent_str, token) + else: + while indent < self.indent_level[-1]: + self.indent_level.pop() + yield Token.new_borrow_pos(self.DEDENT_type, indent_str, token) + + if indent != self.indent_level[-1]: + raise DedentError('Unexpected dedent to column %s. Expected dedent to %s' % (indent, self.indent_level[-1])) + + def _process(self, stream): + token = None + for token in stream: + if token.type == self.NL_type: + yield from self.handle_NL(token) + else: + yield token + + if token.type in self.OPEN_PAREN_types: + self.paren_level += 1 + elif token.type in self.CLOSE_PAREN_types: + self.paren_level -= 1 + assert self.paren_level >= 0 + + while len(self.indent_level) > 1: + self.indent_level.pop() + yield Token.new_borrow_pos(self.DEDENT_type, '', token) if token else Token(self.DEDENT_type, '', 0, 0, 0, 0, 0, 0) + + assert self.indent_level == [0], self.indent_level + + def process(self, stream): + self.paren_level = 0 + self.indent_level = [0] + return self._process(stream) + + ## + + @property + def always_accept(self): + return (self.NL_type,) + + @property + @abstractmethod + def NL_type(self) -> str: + #-- + raise NotImplementedError() + + @property + @abstractmethod + def OPEN_PAREN_types(self) -> List[str]: + #-- + raise NotImplementedError() + + @property + @abstractmethod + def CLOSE_PAREN_types(self) -> List[str]: + #-- + raise NotImplementedError() + + @property + @abstractmethod + def INDENT_type(self) -> str: + #-- + raise 
NotImplementedError() + + @property + @abstractmethod + def DEDENT_type(self) -> str: + #-- + raise NotImplementedError() + + @property + @abstractmethod + def tab_len(self) -> int: + #-- + raise NotImplementedError() + + +class PythonIndenter(Indenter): + #-- + + NL_type = '_NEWLINE' + OPEN_PAREN_types = ['LPAR', 'LSQB', 'LBRACE'] + CLOSE_PAREN_types = ['RPAR', 'RSQB', 'RBRACE'] + INDENT_type = '_INDENT' + DEDENT_type = '_DEDENT' + tab_len = 8 + + +import pickle, zlib, base64 +DATA = ( +{'parser': {'lexer_conf': {'terminals': [{'@': 0}, {'@': 1}, {'@': 2}, {'@': 3}, {'@': 4}, {'@': 5}], 'ignore': [], 'g_regex_flags': 0, 'use_bytes': False, 'lexer_type': 'contextual', '__type__': 'LexerConf'}, 'parser_conf': {'rules': [{'@': 6}, {'@': 7}, {'@': 8}, {'@': 9}, {'@': 10}, {'@': 11}, {'@': 12}, {'@': 13}, {'@': 14}, {'@': 15}, {'@': 16}, {'@': 17}, {'@': 18}, {'@': 19}, {'@': 20}, {'@': 21}, {'@': 22}, {'@': 23}, {'@': 24}], 'start': ['start'], 'parser_type': 'lalr', '__type__': 'ParserConf'}, 'parser': {'tokens': {0: '__ANON_0', 1: 'SPACE', 2: '$END', 3: 'EQUAL', 4: '__ANON_2', 5: 'value_with_spaces', 6: '__parameter_list_plus_0', 7: '__parameter_list_star_1', 8: 'DBLQUOTE', 9: 'bare_value', 10: 'value', 11: '__ANON_1', 12: 'quoted_value', 13: 'key', 14: 'parameter', 15: 'key_value_pair', 16: 'kernel_command_line', 17: 'parameter_list', 18: 'start'}, 'states': {0: {0: (1, {'@': 21}), 1: (1, {'@': 21})}, 1: {2: (1, {'@': 16}), 1: (1, {'@': 16})}, 2: {2: (1, {'@': 17}), 1: (1, {'@': 17})}, 3: {2: (1, {'@': 6})}, 4: {2: (1, {'@': 19}), 1: (1, {'@': 19})}, 5: {1: (1, {'@': 15}), 3: (1, {'@': 15}), 2: (1, {'@': 15})}, 6: {4: (0, 24), 5: (0, 8)}, 7: {1: (0, 0), 6: (0, 22), 7: (0, 20), 2: (1, {'@': 9})}, 8: {8: (0, 12)}, 9: {2: (1, {'@': 7})}, 10: {2: (1, {'@': 23}), 1: (1, {'@': 23})}, 11: {2: (1, {'@': 13}), 1: (1, {'@': 13})}, 12: {2: (1, {'@': 18}), 1: (1, {'@': 18})}, 13: {0: (1, {'@': 22}), 1: (1, {'@': 22})}, 14: {2: (1, {'@': 24}), 1: (1, {'@': 24})}, 15: {}, 16: {9: 
(0, 1), 10: (0, 19), 11: (0, 4), 8: (0, 6), 12: (0, 2)}, 17: {1: (0, 0), 6: (0, 21), 2: (1, {'@': 10})}, 18: {3: (0, 16), 2: (1, {'@': 12}), 1: (1, {'@': 12})}, 19: {2: (1, {'@': 14}), 1: (1, {'@': 14})}, 20: {1: (0, 0), 6: (0, 21), 2: (1, {'@': 8})}, 21: {13: (0, 18), 14: (0, 14), 0: (0, 5), 1: (0, 13), 15: (0, 11)}, 22: {13: (0, 18), 0: (0, 5), 14: (0, 10), 1: (0, 13), 15: (0, 11)}, 23: {16: (0, 3), 13: (0, 18), 14: (0, 7), 1: (0, 0), 0: (0, 5), 6: (0, 22), 15: (0, 11), 17: (0, 9), 18: (0, 15), 7: (0, 17), 2: (1, {'@': 11})}, 24: {8: (1, {'@': 20})}}, 'start_states': {'start': 23}, 'end_states': {'start': 15}}, '__type__': 'ParsingFrontend'}, 'rules': [{'@': 6}, {'@': 7}, {'@': 8}, {'@': 9}, {'@': 10}, {'@': 11}, {'@': 12}, {'@': 13}, {'@': 14}, {'@': 15}, {'@': 16}, {'@': 17}, {'@': 18}, {'@': 19}, {'@': 20}, {'@': 21}, {'@': 22}, {'@': 23}, {'@': 24}], 'options': {'debug': False, 'strict': False, 'keep_all_tokens': False, 'tree_class': None, 'cache': False, 'cache_grammar': False, 'postlex': None, 'parser': 'lalr', 'lexer': 'contextual', 'transformer': None, 'start': ['start'], 'priority': 'normal', 'ambiguity': 'auto', 'regex': False, 'propagate_positions': False, 'lexer_callbacks': {}, 'maybe_placeholders': False, 'edit_terminals': None, 'g_regex_flags': 0, 'use_bytes': False, 'ordered_sets': True, 'import_paths': [], 'source_path': None, '_plugins': {}}, '__type__': 'Lark'} +) +MEMO = ( +{0: {'name': 'SPACE', 'pattern': {'value': ' ', 'flags': [], 'raw': '" "', '__type__': 'PatternStr'}, 'priority': 0, '__type__': 'TerminalDef'}, 1: {'name': 'EQUAL', 'pattern': {'value': '=', 'flags': [], 'raw': '"="', '__type__': 'PatternStr'}, 'priority': 0, '__type__': 'TerminalDef'}, 2: {'name': '__ANON_0', 'pattern': {'value': '[A-Za-z0-9_\\-\\.]+', 'flags': [], 'raw': '/[A-Za-z0-9_\\-\\.]+/', '_width': [1, 18446744073709551616], '__type__': 'PatternRE'}, 'priority': 0, '__type__': 'TerminalDef'}, 3: {'name': 'DBLQUOTE', 'pattern': {'value': '"', 'flags': [], 'raw': 
'"\\""', '__type__': 'PatternStr'}, 'priority': 0, '__type__': 'TerminalDef'}, 4: {'name': '__ANON_1', 'pattern': {'value': '[\\!\\#-\\\\.0-9:-\\@A-Za-z\\[-~]+', 'flags': [], 'raw': '/[\\!\\#-\\\\.0-9:-\\@A-Za-z\\[-~]+/', '_width': [1, 18446744073709551616], '__type__': 'PatternRE'}, 'priority': 0, '__type__': 'TerminalDef'}, 5: {'name': '__ANON_2', 'pattern': {'value': '[\\!\\#-\\\\.0-9:-\\@A-Za-z\\[-~ ]+', 'flags': [], 'raw': '/[\\!\\#-\\\\.0-9:-\\@A-Za-z\\[-~ ]+/', '_width': [1, 18446744073709551616], '__type__': 'PatternRE'}, 'priority': 0, '__type__': 'TerminalDef'}, 6: {'origin': {'name': 'start', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'kernel_command_line', '__type__': 'NonTerminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': True, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 7: {'origin': {'name': 'kernel_command_line', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'parameter_list', '__type__': 'NonTerminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 8: {'origin': {'name': 'parameter_list', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'parameter', '__type__': 'NonTerminal'}, {'name': '__parameter_list_star_1', '__type__': 'NonTerminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 9: {'origin': {'name': 'parameter_list', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'parameter', '__type__': 'NonTerminal'}], 'order': 1, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 10: {'origin': {'name': 
'parameter_list', '__type__': 'NonTerminal'}, 'expansion': [{'name': '__parameter_list_star_1', '__type__': 'NonTerminal'}], 'order': 2, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 11: {'origin': {'name': 'parameter_list', '__type__': 'NonTerminal'}, 'expansion': [], 'order': 3, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 12: {'origin': {'name': 'parameter', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'key', '__type__': 'NonTerminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 13: {'origin': {'name': 'parameter', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'key_value_pair', '__type__': 'NonTerminal'}], 'order': 1, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 14: {'origin': {'name': 'key_value_pair', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'key', '__type__': 'NonTerminal'}, {'name': 'EQUAL', 'filter_out': True, '__type__': 'Terminal'}, {'name': 'value', '__type__': 'NonTerminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 15: {'origin': {'name': 'key', '__type__': 'NonTerminal'}, 'expansion': [{'name': '__ANON_0', 'filter_out': False, '__type__': 'Terminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 
'RuleOptions'}, '__type__': 'Rule'}, 16: {'origin': {'name': 'value', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'bare_value', '__type__': 'NonTerminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 17: {'origin': {'name': 'value', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'quoted_value', '__type__': 'NonTerminal'}], 'order': 1, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 18: {'origin': {'name': 'quoted_value', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'DBLQUOTE', 'filter_out': True, '__type__': 'Terminal'}, {'name': 'value_with_spaces', '__type__': 'NonTerminal'}, {'name': 'DBLQUOTE', 'filter_out': True, '__type__': 'Terminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 19: {'origin': {'name': 'bare_value', '__type__': 'NonTerminal'}, 'expansion': [{'name': '__ANON_1', 'filter_out': False, '__type__': 'Terminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 20: {'origin': {'name': 'value_with_spaces', '__type__': 'NonTerminal'}, 'expansion': [{'name': '__ANON_2', 'filter_out': False, '__type__': 'Terminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 21: {'origin': {'name': '__parameter_list_plus_0', '__type__': 'NonTerminal'}, 'expansion': [{'name': 'SPACE', 'filter_out': True, '__type__': 
'Terminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 22: {'origin': {'name': '__parameter_list_plus_0', '__type__': 'NonTerminal'}, 'expansion': [{'name': '__parameter_list_plus_0', '__type__': 'NonTerminal'}, {'name': 'SPACE', 'filter_out': True, '__type__': 'Terminal'}], 'order': 1, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 23: {'origin': {'name': '__parameter_list_star_1', '__type__': 'NonTerminal'}, 'expansion': [{'name': '__parameter_list_plus_0', '__type__': 'NonTerminal'}, {'name': 'parameter', '__type__': 'NonTerminal'}], 'order': 0, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}, 24: {'origin': {'name': '__parameter_list_star_1', '__type__': 'NonTerminal'}, 'expansion': [{'name': '__parameter_list_star_1', '__type__': 'NonTerminal'}, {'name': '__parameter_list_plus_0', '__type__': 'NonTerminal'}, {'name': 'parameter', '__type__': 'NonTerminal'}], 'order': 1, 'alias': None, 'options': {'keep_all_tokens': False, 'expand1': False, 'priority': None, 'template_source': None, 'empty_indices': (), '__type__': 'RuleOptions'}, '__type__': 'Rule'}} +) +Shift = 0 +Reduce = 1 +def Lark_StandAlone(**kwargs): + return Lark._load_from_dict(DATA, MEMO, **kwargs) diff --git a/ironic/common/kernel_parameters.py b/ironic/common/kernel_parameters.py new file mode 100644 index 0000000000..ca0cc924ef --- /dev/null +++ b/ironic/common/kernel_parameters.py @@ -0,0 +1,159 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from dataclasses import dataclass + +from ironic.common.exception import InvalidParameterValue +from ironic.common.i18n import _ +from ironic.common.kernel_parameter_parser.kernel_parameter_parser \ + import Lark_StandAlone +from ironic.common.kernel_parameter_parser.kernel_parameter_parser \ + import LarkError +from ironic.common.kernel_parameter_parser.kernel_parameter_parser \ + import Transformer +from ironic.common.kernel_parameter_parser.kernel_parameter_parser \ + import UnexpectedInput + + +def sanitize_kernel_command_line(command_line: str) -> str: + """Applies filtering to a command line to sanitize it. + + NOTE: This does not guarantee a correct or safe kernel command line, + for stronger guarantees of correctness and safety use + KernelCommandLine.parse(). + + :param command_line: A string containing a kernel command line or + individual parameters. + :returns: A filtered string which should be safer for use. 
+ """ + return ''.join(c for c in command_line if c not in {'\n', '\r', '\0'}) + + +KernelParameterParser = Lark_StandAlone(debug=True) + + +@dataclass(frozen=True) +class ParameterKey: + key: str + + def __str__(self): + return self.key + + +@dataclass(frozen=True) +class ParameterValue: + value: str + + def __str__(self): + if ' ' in self.value: + return f"\"{self.value}\"" + return self.value + + +@dataclass(frozen=True) +class KernelParameter: + key: ParameterKey + value: ParameterValue + + def __str__(self): + if len(self.value.value) > 0: + return f"{self.key.key}={self.value.value}" + return self.key.key + + +_INIT_ARG_PREAMBLE = " -- " + + +# NOTE(clif): We're handling init args here instead of inside the grammar +# because Lark's stand-alone LALR(1) parser can't handle it. +def _divide_command_line_by_init_args(command_line: str) -> tuple[str, str]: + index = command_line.rfind(_INIT_ARG_PREAMBLE) + if index == -1: + return (command_line, '') + return (command_line[:index], + command_line[index + len(_INIT_ARG_PREAMBLE):]) + + +@dataclass(frozen=True) +class KernelCommandLine: + parameters: dict[str, list[KernelParameter]] + init_args: str + + def __str__(self): + output = ' '.join( + ' '.join(str(param) for param in param_list) + for param_list in self.parameters.values()) + if len(self.init_args) > 0: + output += _INIT_ARG_PREAMBLE + self.init_args + return output + + @classmethod + def parse(cls, command_line: str): + try: + cmd_line, init_args = \ + _divide_command_line_by_init_args(command_line) + tree = KernelParameterParser.parse(cmd_line) + kcl = KernelParameterTransformer().transform(tree) + return KernelCommandLine(kcl.parameters, init_args) + except (LarkError, UnexpectedInput) as e: + raise InvalidParameterValue( + _('Kernel command line did not parse: "%s" -- %s') + % (command_line, str(e))) from None + + +class KernelParameterTransformer(Transformer): + def kernel_command_line(self, items): + # NOTE(clif) adding init arguments to the grammar 
is too much for + # Lark's stand-alone LALR(1) parser. Therefore it isn't part of the + # back-ported grammar. + return KernelCommandLine(items[0], '') + + def parameter_list(self, items): + parameters = {} + for item in items: + if item.key.key in parameters.keys(): + parameters[item.key.key].append(item) + else: + parameters[item.key.key] = [item] + return parameters + + def parameter(self, items): + if isinstance(items[0], ParameterKey): + return KernelParameter(items[0], ParameterValue("")) + return items[0] + + def key_value_pair(self, items): + key = items[0] + value = items[1] + return KernelParameter(key, value) + + def key(self, items): + return ParameterKey(items[0].value) + + def value(self, items): + return ParameterValue(items[0]) + + def quoted_value(self, items): + # Strip " characters from literal. + return items[0].value[1:-1] + + def bare_value(self, items): + return items[0].value + + def value_with_spaces(self, items): + return items[0].value + + def init_suffix(self, items): + return items[0] + + def init_arguments(self, items): + return items[0] diff --git a/ironic/common/pxe_utils.py b/ironic/common/pxe_utils.py index b2b463ca06..ed619435e4 100644 --- a/ironic/common/pxe_utils.py +++ b/ironic/common/pxe_utils.py @@ -972,6 +972,8 @@ def build_pxe_config_options(task, pxe_info, service=False, as kernel command-line arguments. :returns: A dictionary of pxe options to be used in the pxe bootfile template. + + :raises: InvalidParameterValue via get_kernel_append_params """ node = task.node mode = deploy_utils.rescue_or_deploy_mode(node) diff --git a/ironic/conf/conductor.py b/ironic/conf/conductor.py index a36188964a..9bc0e2522d 100644 --- a/ironic/conf/conductor.py +++ b/ironic/conf/conductor.py @@ -546,6 +546,18 @@ 'here are validated as absolute paths and will be rejected' 'if they contain path traversal mechanisms, such as "..".' 
)), + cfg.BoolOpt('disable_kernel_parameter_parsing', + default=False, + # Normally such an option would be mutable, but this is, + # a security guard and operators should not expect to change + # this option under normal circumstances. + mutable=False, + help=_('Disable parsing of kernel parameters. Kernel ' + 'parameter parsing allows Ironic to detect and prevent ' + 'malformed kernel parameters before they are passed to ' + 'nodes. Malformed kernel parameters can pose a ' + 'security risk therefore it is not recommended to ' + 'disable this option unless absolutely necessary.')), ] diff --git a/ironic/drivers/utils.py b/ironic/drivers/utils.py index 882ec9ad8f..8627da6832 100644 --- a/ironic/drivers/utils.py +++ b/ironic/drivers/utils.py @@ -24,6 +24,7 @@ from ironic.common import exception from ironic.common.i18n import _ +from ironic.common import kernel_parameters from ironic.common import states from ironic.common import swift from ironic.conductor import utils @@ -437,11 +438,30 @@ def get_kernel_append_params(node, default): :param node: Node object. :param default: Default value. + + :raises: InvalidParameterValue if kernel_append_params is an invalid + string to append to a kernel command line. """ for location in ('instance_info', 'driver_info'): result = getattr(node, location).get('kernel_append_params') if result is not None: - return result.replace('%default%', default or '') + result = result.replace('%default%', default or '') + + if not CONF.conductor.disable_kernel_parameter_parsing: + # NOTE(clif) Attempt to parse the append params. Failure to + # parse indicates malformed kernel parameters and should be + # rejected. parse() will raise if parsing fails. 
+ try: + kernel_parameters.KernelCommandLine.parse(result) + except exception.InvalidParameterValue: + raise exception.InvalidParameterValue( + _('node\'s %s[\'kernel_append_params\'] contains ' + 'malformed kernel command line') % location) + + # NOTE(clif) Always run basic sanitization on kernel_append_params + result = kernel_parameters.sanitize_kernel_command_line(result) + + return result return default diff --git a/ironic/tests/unit/common/test_kernel_parameters.py b/ironic/tests/unit/common/test_kernel_parameters.py new file mode 100644 index 0000000000..f64f46eb51 --- /dev/null +++ b/ironic/tests/unit/common/test_kernel_parameters.py @@ -0,0 +1,230 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +from ironic.common.exception import InvalidParameterValue +import ironic.common.kernel_parameters as kp +from ironic.tests import base + +from ddt import data +from ddt import ddt +from ddt import unpack + + +def annotate(name, *args): + class AnnotatedList(list): + pass + + al = AnnotatedList([*args]) + al.__name__ = name + return al + + +def generate_invalid_characters_to_test(): + invalid_characters_to_test = [ + chr(c) for c in range(0, 32) + ] + invalid_characters_to_test.extend([ + "\n", + "\r", + chr(127), + ]) + invalid_characters_to_test.extend([ + chr(c) for c in range(128, 160) + ]) + return invalid_characters_to_test + + +INVALID_CHARACTERS = generate_invalid_characters_to_test() + + +class KernelParamTryout(base.TestCase): + def test_if_can_parse(self): + result = kp.KernelCommandLine.parse("quiet") + self.assertIsNotNone(result) + + +@ddt +class KernelParametersTestCase(base.TestCase): + @data( + annotate( + "Filtering newlines", + "quiet\n", + "quiet" + ), + annotate( + "Filtering carraige returns", + "qu\riet", + "quiet" + ), + annotate( + "Filtering NULL", + "\0quiet", + "quiet" + ), + annotate( + "Nothing needs changing - a real valid kernel cmdline", + ("BOOT_IMAGE=(hd5,gpt2)/vmlinuz-6.19.9-200.fc43.x86_64 " + "root=UUID=217c8a40-4956-11f1-9c98-d8bbc1c85452 ro " + "rootflags=subvol=root " + "rd.luks.uuid=luks-3a516752-4956-11f1-aa13-d8bbc1c85452 " + "rhgb quiet rd.driver.blacklist=nouveau,nova_core " + "modprobe.blacklist=nouveau,nova_core"), + ("BOOT_IMAGE=(hd5,gpt2)/vmlinuz-6.19.9-200.fc43.x86_64 " + "root=UUID=217c8a40-4956-11f1-9c98-d8bbc1c85452 ro " + "rootflags=subvol=root " + "rd.luks.uuid=luks-3a516752-4956-11f1-aa13-d8bbc1c85452 " + "rhgb quiet rd.driver.blacklist=nouveau,nova_core " + "modprobe.blacklist=nouveau,nova_core") + ), + ) + @unpack + def test_sanitize_kernel_command_line( + self, command_line: str, expected_result: str): + self.assertEqual( + expected_result, + kp.sanitize_kernel_command_line(command_line)) + + @data( + 
annotate( + "Single key=value pair", + "BOOT_IMAGE=(hd5,gpt2)/vmlinuz-6.19.9-200.fc43.x86_64", + kp.KernelCommandLine({ + 'BOOT_IMAGE': [kp.KernelParameter( + kp.ParameterKey('BOOT_IMAGE'), + kp.ParameterValue( + '(hd5,gpt2)/vmlinuz-6.19.9-200.fc43.x86_64') + )], + }, "") + ), + annotate( + "Single key", + "quiet", + kp.KernelCommandLine({ + 'quiet': [kp.KernelParameter( + kp.ParameterKey('quiet'), + kp.ParameterValue(''), + )], + }, "") + ), + annotate( + "Two parameters", + "quiet BOOT_IMAGE=(hd5,gpt2)/vmlinuz-6.19.9-200.fc43.x86_64", + kp.KernelCommandLine({ + 'quiet': [kp.KernelParameter( + kp.ParameterKey('quiet'), + kp.ParameterValue(''), + )], + 'BOOT_IMAGE': [kp.KernelParameter( + kp.ParameterKey('BOOT_IMAGE'), + kp.ParameterValue( + '(hd5,gpt2)/vmlinuz-6.19.9-200.fc43.x86_64') + )], + }, "") + ), + annotate( + "A real linux kernel cmdline", + ("BOOT_IMAGE=(hd5,gpt2)/vmlinuz-6.19.9-200.fc43.x86_64 " + "root=UUID=217c8a40-4956-11f1-9c98-d8bbc1c85452 ro " + "rootflags=subvol=root " + "rd.luks.uuid=luks-3a516752-4956-11f1-aa13-d8bbc1c85452 " + "rhgb quiet rd.driver.blacklist=nouveau,nova_core " + "modprobe.blacklist=nouveau,nova_core"), + kp.KernelCommandLine({ + 'BOOT_IMAGE': [kp.KernelParameter( + kp.ParameterKey('BOOT_IMAGE'), + kp.ParameterValue( + '(hd5,gpt2)/vmlinuz-6.19.9-200.fc43.x86_64') + )], + 'root': [kp.KernelParameter( + kp.ParameterKey('root'), + kp.ParameterValue( + 'UUID=217c8a40-4956-11f1-9c98-d8bbc1c85452'), + )], + 'ro': [kp.KernelParameter( + kp.ParameterKey('ro'), + kp.ParameterValue(''), + )], + 'rootflags': [kp.KernelParameter( + kp.ParameterKey('rootflags'), + kp.ParameterValue('subvol=root'), + )], + 'rd.luks.uuid': [kp.KernelParameter( + kp.ParameterKey('rd.luks.uuid'), + kp.ParameterValue( + 'luks-3a516752-4956-11f1-aa13-d8bbc1c85452'), + )], + 'rhgb': [kp.KernelParameter( + kp.ParameterKey('rhgb'), + kp.ParameterValue(''), + )], + 'quiet': [kp.KernelParameter( + kp.ParameterKey('quiet'), + kp.ParameterValue(''), + )], + 
'rd.driver.blacklist': [kp.KernelParameter( + kp.ParameterKey('rd.driver.blacklist'), + kp.ParameterValue('nouveau,nova_core'), + )], + 'modprobe.blacklist': [kp.KernelParameter( + kp.ParameterKey('modprobe.blacklist'), + kp.ParameterValue('nouveau,nova_core'), + )], + }, "") + ), + annotate( + "Multiple parameters with the same key", + "initrd=/initramfs-linux.img initrd=ramdisk", + kp.KernelCommandLine({ + 'initrd': [ + kp.KernelParameter( + kp.ParameterKey('initrd'), + kp.ParameterValue('/initramfs-linux.img') + ), + kp.KernelParameter( + kp.ParameterKey('initrd'), + kp.ParameterValue('ramdisk') + ) + ]}, "") + ), + annotate( + "init arguments", + "quiet -- some init args", + kp.KernelCommandLine({ + 'quiet': [kp.KernelParameter( + kp.ParameterKey('quiet'), + kp.ParameterValue(''), + )], + }, "some init args") + ), + ) + @unpack + def test_kernel_command_line_parsing( + self, command_line: str, expected_result: kp.KernelCommandLine): + result = kp.KernelCommandLine.parse(command_line) + # Assert parsing the command line spits out the expected + # object. + self.assertEqual(expected_result, result) + # Assert rendering the object back to a string matches the initial + # command line string. + self.assertEqual(command_line, str(result)) + + @data( + *[annotate( + f"character ordinal {ord(c)} shouldn't parse", + f"ro{c}quiet",) for c in INVALID_CHARACTERS] + ) + @unpack + def test_invalid_kernel_command_lines_fail_to_parse( + self, command_line: str): + self.assertRaises(InvalidParameterValue, + kp.KernelCommandLine.parse, + command_line) diff --git a/ironic/tests/unit/drivers/test_utils.py b/ironic/tests/unit/drivers/test_utils.py index f2e79e8271..dc07f4dca2 100644 --- a/ironic/tests/unit/drivers/test_utils.py +++ b/ironic/tests/unit/drivers/test_utils.py @@ -13,10 +13,14 @@ # License for the specific language governing permissions and limitations # under the License. 
+from dataclasses import dataclass import datetime import os from unittest import mock +from ddt import data +from ddt import ddt +from ddt import unpack from oslo_config import cfg from oslo_utils import timeutils @@ -32,6 +36,15 @@ from ironic.tests.unit.objects import utils as obj_utils +def annotate(name, *args): + class AnnotatedList(list): + pass + + al = AnnotatedList([*args]) + al.__name__ = name + return al + + class UtilsTestCase(db_base.DbTestCase): def setUp(self): @@ -222,6 +235,90 @@ def test_normalize_mac_unicode(self): self.assertEqual("0a1b2c3d4f", mac_clean) +@ddt +class GetKernelAppendParamsTestCase(tests_base.TestCase): + @dataclass(frozen=True) + class FauxTestNode: + instance_info: dict + driver_info: dict + + @data( + annotate( + "valid params in instance_info", + FauxTestNode({'kernel_append_params': 'quiet ro'}, + {}), + '', + 'quiet ro', + False, + False + ), + annotate( + "valid params in driver_info", + FauxTestNode({}, + {'kernel_append_params': 'quiet ro'}), + '', + 'quiet ro', + False, + False + ), + annotate( + "params in default", + FauxTestNode({}, {}), + 'quiet ro', + 'quiet ro', + False, + False + ), + annotate( + "invalid params in instance_info raises", + FauxTestNode({'kernel_append_params': 'bad\nparams'}, {}), + '', + '', + True, + False + + ), + annotate( + "invalid params in driver_info raises", + FauxTestNode({}, {'kernel_append_params': 'bad\nparams'}), + '', + '', + True, + False + ), + annotate( + "parsing disabled - but newline is filtered", + FauxTestNode({}, {'kernel_append_params': 'quiet\n ro'}), + '', + 'quiet ro', + False, + True, + ), + ) + @unpack + def test_get_kernel_append_params( + self, + test_node: FauxTestNode, + default: str, + expected_result: str, + should_raise: bool, + disable_kernel_parameter_parsing: bool): + cfg.CONF.set_override('disable_kernel_parameter_parsing', + disable_kernel_parameter_parsing, + 'conductor') + if should_raise: + self.assertRaises( + exception.InvalidParameterValue, + 
driver_utils.get_kernel_append_params, + test_node, + default) + else: + self.assertEqual( + expected_result, + driver_utils.get_kernel_append_params(test_node, + default)) + + class UtilsRamdiskLogsTestCase(tests_base.TestCase): def setUp(self): diff --git a/releasenotes/notes/sanitize-kernel-append-params-8b2953a9d903d0f6.yaml b/releasenotes/notes/sanitize-kernel-append-params-8b2953a9d903d0f6.yaml new file mode 100644 index 0000000000..90e136f769 --- /dev/null +++ b/releasenotes/notes/sanitize-kernel-append-params-8b2953a9d903d0f6.yaml @@ -0,0 +1,14 @@ +--- +security: + - | + Fixes a security issue where a malicious 'kernel_append_params' in a node's + 'instance_info' or 'driver_info' could allow an attacker to take control of + a node's initial boot through boot script injection. The fix includes + sanitization of 'kernel_append_params' to prevent such injection. Strict + parsing of kernel parameters is now in place as well and enabled by + default. If an operator needs to disable such strict parsing they may do + so by setting the configuration option + conductor.disable_kernel_parameter_parsing to 'True'. However, this is + discouraged as it weakens the security posture of Ironic. This fix + addresses CVE-2026-46447. This back-port utilizes a generated, stand-alone + lark parser governed by the MPL v2.0 license. From fbe121c2a505b2217532d891ca27e253b777216a Mon Sep 17 00:00:00 2001 From: Julia Kreger Date: Wed, 20 May 2026 15:31:38 -0700 Subject: [PATCH 2/3] security: directory traversal ISO9660 support A vulnerability was identified in Ironic's handling of ISO images where Ironic contains support to patch ISO9660 virtual media contents to include key data items like configuration drive data and other required metadata. The issue was that the Ironic service trusted that the submitted contents were valid, so a directory traversal attempt could be embedded within modified configuration drive ISO contents submitted to Ironic.
This is a case where an attacker would take a path in an ISO, and attempt to directly modify it to reach another path on the filesystem which was within the confines and path structure they were working with. i.e., while "../foo" is not a valid file or directory name in ISO9660, it can still be represented, injected, and read by the pycdlib library. The code on all paths which perform this type of ISO content interaction has been patched to explicitly check the path for traversal attempts and internally raise an InvalidContent exception. Impacted features: * Virtual Media ISO patching code path for pre-generated deployment ISOs as opposed to Ironic generated ISOs from a kernel/ramdisk. * Anaconda deployment interface where a user could impact the resulting pathing on the node being deployed at deploy time. Related-Bug: 2148333 Change-Id: I09ba308dc5088260594a502d43413d2069616703 Signed-off-by: Julia Kreger --- ironic/common/exception.py | 6 + ironic/common/images.py | 11 +- ironic/common/kickstart_utils.py | 5 + ironic/common/utils.py | 35 ++++++ ironic/tests/unit/common/test_images.py | 111 ++++++++++++++++++ .../tests/unit/common/test_kickstart_utils.py | 83 +++++++++++++ .../notes/bug-2148333-b3a74b813eea7dab.yaml | 17 +++ 7 files changed, 266 insertions(+), 2 deletions(-) create mode 100644 releasenotes/notes/bug-2148333-b3a74b813eea7dab.yaml diff --git a/ironic/common/exception.py b/ironic/common/exception.py index 3d00882b30..23b3a23a0d 100644 --- a/ironic/common/exception.py +++ b/ironic/common/exception.py @@ -1140,3 +1140,9 @@ class RuleActionExecutionFailure(InspectionRuleExecutionFailure): class RuleConditionCheckFailure(InspectionRuleExecutionFailure): """Raised when an inspection rule condition fails during execution.""" _msg_fmt = _("Inspection rule condition check failed.
Reason: %(reason)s") + + +class InvalidContent(Invalid): + """Invalid or malicious content has been provided to the conductor.""" + _msg_fmt = _("Invalid or potentially malicious content has been provided " + "to the conductor and the conductor will not proceed.") diff --git a/ironic/common/images.py b/ironic/common/images.py index 3b1915ee53..647cf66efe 100644 --- a/ironic/common/images.py +++ b/ironic/common/images.py @@ -828,15 +828,22 @@ def _extract_iso(extract_iso, extract_dir): iso.open(extract_iso) for dirname, dirlist, filelist in iso.walk(iso_path='/'): + # NOTE(TheJulia): This code is confusing because the way the + # walk method returns data, basically a list of tuples to + # provide a structural mapping which a consumer of the walk + # data can use to understand the structure. dir_path = dirname.lstrip('/') + utils.check_iso_path(extract_dir, dir_path) for dir_iso in dirlist: + utils.check_iso_path(extract_dir, + os.path.join(dir_path, dir_iso)) os.makedirs(os.path.join(extract_dir, dir_path, dir_iso)) for file in filelist: - file_path = os.path.join(extract_dir, dirname, file) + utils.check_iso_path(extract_dir, dir_path, file) + file_path = os.path.join(dirname, file) iso.get_file_from_iso( os.path.join(extract_dir, dir_path, file), iso_path=file_path) - iso.close() diff --git a/ironic/common/kickstart_utils.py b/ironic/common/kickstart_utils.py index 317c5620d0..a05f5dfa5e 100644 --- a/ironic/common/kickstart_utils.py +++ b/ironic/common/kickstart_utils.py @@ -51,6 +51,11 @@ def _get_config_drive_dict_from_iso( # server. 
posix_file_path = posix_file_path.lstrip('/') target_file_path = os.path.join(target_path, posix_file_path) + real_target_file_path = os.path.realpath(target_file_path) + if not target_file_path.startswith(real_target_file_path): + LOG.error('Discovered transversal attempt while reading ' + 'the configuration drive contents.') + raise exception.InvalidContent() b_buf = io.BytesIO() iso_reader.get_file_from_iso_fp( iso_path=iso_file_path, outfp=b_buf diff --git a/ironic/common/utils.py b/ironic/common/utils.py index 0182173011..9dff1655de 100644 --- a/ironic/common/utils.py +++ b/ironic/common/utils.py @@ -1148,3 +1148,38 @@ def get_route_source(dest, ignore_link_local=True): except (IndexError, ValueError): LOG.debug('No route to host %(dest)s, route record: %(rec)s', {'dest': dest, 'rec': out}) + + +def check_iso_path(base_folder, folder, file=None): + """Sanity check an ISO path, folder, and file structure. + + :param base_folder: The target folder for path operations. + :param folder: The folder being evaluated in the ISO. + :param file: An optional file to also evaluate for path + transversal attempts. + :raises: InvalidContent when an inconsistency is detected. + """ + if folder.startswith('/'): + # If we're here, we were handed the base folder path with a leading /. + # Any caller of this method should pre-emptively strip it. + raise exception.InvalidContent() + + target_folder = os.path.join(base_folder, folder) + resolved_folder = os.path.realpath(target_folder) + if not resolved_folder.startswith(base_folder): + # supplied folder path has something like ../ in the data set + # or resolution has resulted in a change in the value. + # Possible risk: if the temp folder is being used with a symlink... 
+ LOG.error('ISO path evaluation identified a folder based ' + 'transversal attempt.') + raise exception.InvalidContent() + if file: + target_file_path = os.path.join(resolved_folder, file) + resolved_file_path = os.path.realpath(target_file_path) + # Check that the folder itself doesn't change, + # and then check that the resolved path matches + if (not target_file_path.startswith(resolved_folder) + or target_file_path != resolved_file_path): + LOG.error('ISO path evaluation identified a file name based ' + 'transversal attempt.') + raise exception.InvalidContent() diff --git a/ironic/tests/unit/common/test_images.py b/ironic/tests/unit/common/test_images.py index b51b96dae5..3b1ba44825 100644 --- a/ironic/tests/unit/common/test_images.py +++ b/ironic/tests/unit/common/test_images.py @@ -25,6 +25,7 @@ from oslo_config import cfg from oslo_utils import fileutils from oslo_utils.imageutils import format_inspector as image_format_inspector +import pycdlib from ironic.common import exception from ironic.common.glance_service import service_utils as glance_utils @@ -838,6 +839,116 @@ def test__generate_grub_cfg(self): options) self.assertEqual(expected_cfg, cfg) + @mock.patch.object(os, 'makedirs', autospec=True) + @mock.patch('pycdlib.PyCdlib', autospec=True) + def test__extract_iso(self, mock_pycdlib_cls, mock_makedirs): + mock_iso = mock_pycdlib_cls.return_value + mock_iso.walk.return_value = [ + ('/', ['BOOT'], ['README.TXT']), + ('/BOOT', ['GRUB'], ['BOOTX64.EFI']), + ('/BOOT/GRUB', [], ['GRUB.CFG']), + ] + + images._extract_iso('/path/to/image.iso', '/extract') + + mock_iso.open.assert_called_once_with('/path/to/image.iso') + mock_iso.walk.assert_called_once_with(iso_path='/') + mock_makedirs.assert_any_call( + os.path.join('/extract', '', 'BOOT')) + mock_makedirs.assert_any_call( + os.path.join('/extract', 'BOOT', 'GRUB')) + mock_iso.get_file_from_iso.assert_any_call( + os.path.join('/extract', '', 'README.TXT'), + iso_path=os.path.join('/extract', '/', 
'README.TXT')) + mock_iso.get_file_from_iso.assert_any_call( + os.path.join('/extract', 'BOOT', 'BOOTX64.EFI'), + iso_path=os.path.join('/extract', '/BOOT', + 'BOOTX64.EFI')) + mock_iso.get_file_from_iso.assert_any_call( + os.path.join('/extract', 'BOOT/GRUB', 'GRUB.CFG'), + iso_path=os.path.join('/extract', '/BOOT/GRUB', + 'GRUB.CFG')) + self.assertEqual(3, mock_iso.get_file_from_iso.call_count) + mock_iso.close.assert_called_once() + + @mock.patch('pycdlib.PyCdlib', autospec=True) + def test__extract_iso_empty(self, mock_pycdlib_cls): + mock_iso = mock_pycdlib_cls.return_value + mock_iso.walk.return_value = [ + ('/', [], []), + ] + + images._extract_iso('/path/to/empty.iso', '/extract') + + mock_iso.open.assert_called_once_with('/path/to/empty.iso') + mock_iso.walk.assert_called_once_with(iso_path='/') + mock_iso.get_file_from_iso.assert_not_called() + mock_iso.close.assert_called_once() + + @mock.patch('pycdlib.PyCdlib', autospec=True) + def test__extract_iso_open_fails(self, mock_pycdlib_cls): + mock_iso = mock_pycdlib_cls.return_value + mock_iso.open.side_effect = ( + pycdlib.pycdlibexception.PyCdlibInvalidInput( + msg='Could not open file')) + + self.assertRaises( + pycdlib.pycdlibexception.PyCdlibInvalidInput, + images._extract_iso, + '/path/to/bad.iso', '/extract') + mock_iso.walk.assert_not_called() + mock_iso.close.assert_not_called() + + @mock.patch.object(os, 'makedirs', autospec=True) + @mock.patch('pycdlib.PyCdlib', autospec=True) + def test__extract_iso_invalid_file(self, mock_pycdlib_cls, mock_makedirs): + mock_iso = mock_pycdlib_cls.return_value + mock_iso.walk.return_value = [ + ('/', ['BOOT'], ['README.TXT']), + ('/BOOT', ['GRUB'], ['../TX64.EFI']), + ('/BOOT/GRUB', [], ['GRUB.CFG']), + ] + + self.assertRaises(exception.InvalidContent, + images._extract_iso, + '/path/to/image.iso', '/extract') + + mock_iso.open.assert_called_once_with('/path/to/image.iso') + mock_iso.walk.assert_called_once_with(iso_path='/') + mock_makedirs.assert_any_call( + 
os.path.join('/extract', '', 'BOOT')) + mock_makedirs.assert_any_call( + os.path.join('/extract', 'BOOT', 'GRUB')) + mock_iso.get_file_from_iso.assert_any_call( + os.path.join('/extract', '', 'README.TXT'), + iso_path=os.path.join('/extract', '/', 'README.TXT')) + self.assertEqual(1, mock_iso.get_file_from_iso.call_count) + + @mock.patch.object(os, 'makedirs', autospec=True) + @mock.patch('pycdlib.PyCdlib', autospec=True) + def test__extract_iso_invalid_folder(self, mock_pycdlib_cls, + mock_makedirs): + mock_iso = mock_pycdlib_cls.return_value + mock_iso.walk.return_value = [ + ('/', ['BOOT'], ['README.TXT']), + ('/../T', ['GRUB'], ['BOOTX64.EFI']), + ('/BOOT/GRUB', [], ['GRUB.CFG']), + ] + + self.assertRaises(exception.InvalidContent, + images._extract_iso, + '/path/to/image.iso', '/extract') + + mock_iso.open.assert_called_once_with('/path/to/image.iso') + mock_iso.walk.assert_called_once_with(iso_path='/') + mock_makedirs.assert_any_call( + os.path.join('/extract', '', 'BOOT')) + self.assertEqual(1, mock_makedirs.call_count) + mock_iso.get_file_from_iso.assert_any_call( + os.path.join('/extract', '', 'README.TXT'), + iso_path=os.path.join('/extract', '/', 'README.TXT')) + self.assertEqual(1, mock_iso.get_file_from_iso.call_count) + @mock.patch.object(os.path, 'relpath', autospec=True) @mock.patch.object(os, 'walk', autospec=True) @mock.patch.object(images, '_extract_iso', autospec=True) diff --git a/ironic/tests/unit/common/test_kickstart_utils.py b/ironic/tests/unit/common/test_kickstart_utils.py index 52ea0324d7..125bb470ab 100644 --- a/ironic/tests/unit/common/test_kickstart_utils.py +++ b/ironic/tests/unit/common/test_kickstart_utils.py @@ -17,6 +17,7 @@ from unittest import mock from oslo_config import cfg +import pycdlib from ironic.common import kickstart_utils as ks_utils from ironic.conductor import task_manager @@ -100,3 +101,85 @@ def test_prepare_config_drive_in_swift(self, mock_get): self.assertEqual(expected, ks_utils.prepare_config_drive(task)) 
mock_get.assert_called_with('http://server/fake-configdrive-url', timeout=60) + + @mock.patch.object(pycdlib, 'PyCdlib', autospec=True) + def test_read_iso9600_config_drive(self, mock_pycdlib_cls): + mock_iso = mock_pycdlib_cls.return_value + mock_iso.walk.return_value = [ + ('/', [], ['FILE1.TXT;1']), + ] + mock_record = mock.Mock() + mock_iso.get_record.return_value = mock_record + mock_iso.full_path_from_dirrecord.return_value = ( + '/openstack/latest/user_data' + ) + + def fake_get_file(iso_path, outfp): + outfp.write(b'test user_data') + + mock_iso.get_file_from_iso_fp.side_effect = fake_get_file + + result = ks_utils.read_iso9600_config_drive(b'fake-iso') + + mock_iso.open.assert_called_once() + mock_iso.walk.assert_called_once_with(iso_path='/') + mock_iso.get_record.assert_called_once_with( + iso_path='/FILE1.TXT;1' + ) + mock_iso.full_path_from_dirrecord.assert_called_once_with( + mock_record, rockridge=True + ) + mock_iso.get_file_from_iso_fp.assert_called_once() + mock_iso.close.assert_called_once() + + expected_path = ( + '/var/lib/cloud/seed/config_drive' + '/openstack/latest/user_data' + ) + self.assertIn(expected_path, result) + self.assertEqual('test user_data', result[expected_path]) + + @mock.patch.object(pycdlib, 'PyCdlib', autospec=True) + def test_read_iso9600_config_drive_pycdlib_exception( + self, mock_pycdlib_cls): + mock_iso = mock_pycdlib_cls.return_value + mock_iso.open.side_effect = ( + pycdlib.pycdlibexception.PyCdlibInvalidInput( + msg='bad iso' + ) + ) + + result = ks_utils.read_iso9600_config_drive(b'bad-data') + + mock_iso.open.assert_called_once() + mock_iso.walk.assert_not_called() + self.assertEqual({}, result) + + @mock.patch.object(pycdlib, 'PyCdlib', autospec=True) + def test_read_iso9600_config_drive_invalid_file(self, mock_pycdlib_cls): + mock_iso = mock_pycdlib_cls.return_value + mock_iso.walk.return_value = [ + ('/', [], ['../E1.TXT;1']), + ] + mock_record = mock.Mock() + mock_iso.get_record.return_value = mock_record + 
mock_iso.full_path_from_dirrecord.return_value = ( + '../E1.TXT' + ) + + def fake_get_file(iso_path, outfp): + outfp.write(b'test user_data') + + mock_iso.get_file_from_iso_fp.side_effect = fake_get_file + + returned = ks_utils.read_iso9600_config_drive(b'fake-iso') + self.assertEqual({}, returned) + mock_iso.open.assert_called_once() + mock_iso.walk.assert_called_once_with(iso_path='/') + mock_iso.get_record.assert_called_once_with( + iso_path='/../E1.TXT;1' + ) + mock_iso.full_path_from_dirrecord.assert_called_once_with( + mock_record, rockridge=True + ) + mock_iso.get_file_from_iso_fp.assert_not_called() diff --git a/releasenotes/notes/bug-2148333-b3a74b813eea7dab.yaml b/releasenotes/notes/bug-2148333-b3a74b813eea7dab.yaml new file mode 100644 index 0000000000..f7070f4544 --- /dev/null +++ b/releasenotes/notes/bug-2148333-b3a74b813eea7dab.yaml @@ -0,0 +1,17 @@ +--- +security: + - | + Fixes CVE-2026-48681 which was a lack of file path validation when + interacting with ISO9660 files in the kickstart/anaconda driver, + and administrative deployment ISOs which are admin-provided. + Ironic now explicitly rejects the submitted files when such + a case has been detected. +fixes: + - | + Fixes path handling around ISO9660 file handling as denoted in + `bug 2148333 <https://bugs.launchpad.net/ironic/+bug/2148333>`_ + under CVE-2026-48681 where a malicious user with deployment privileges + could craft an ISO9660 formatted configuration drive or deployment ramdisk + which includes intentional path manipulation which is non-compliant with + the underlying standard. Ironic now explicitly looks for such attempts + and rejects the content. From 769bcc2e40d9e49fba6c7d45a857dc67da789419 Mon Sep 17 00:00:00 2001 From: Julia Kreger Date: Thu, 7 May 2026 08:48:48 -0700 Subject: [PATCH 3/3] security: disable driver_info level pxe_template override A vulnerability report was filed pointing out a flaw in the pxe_template override logic where a direct file path was supplied. 
The original usage context of this minimally documented feature was that an operator, i.e. the owner of the Ironic deployment, could leverage a direct file path to a template on disk. This should instead have utilized the file:/// URL provider, but research suggests this feature has largely not been used. As a result, consensus has been reached amongst security maintainers for the Ironic project to disable and remove this functionality. This issue became a vulnerability for Ironic through the evolution of the usage and Role Based Access Control model, where we began to separate the overall operator of the system from the administrative manager of the system. The resulting vector was that an authenticated and authorized user could potentially request that a sensitive file be sourced as the PXE template. This file would then be written to disk and utilized IF the ironic-conductor service could access it. If the environment was misconfigured, or operating with "flat" networking, the malicious authenticated and authorized user could then guess the underlying file path on the tftpboot/httpboot network boot endpoints, and retrieve the rendered output before the deployment failed and the rendered output was removed. This is tracked as CVE-2026-44917, and the underlying feature is expected to be removed in the 2027.2 release. 
Closes-Bug: 2148319 Change-Id: I52daa344b4d417eee09c28b53703fea792e4367b Signed-off-by: Julia Kreger --- ironic/conf/pxe.py | 7 ++++ ironic/drivers/modules/deploy_utils.py | 36 ++++++++++++++++--- .../unit/drivers/modules/test_deploy_utils.py | 23 ++++++++++++ ...security-bug-2148319-49974afdcd38d9c0.yaml | 28 +++++++++++++++ 4 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/security-bug-2148319-49974afdcd38d9c0.yaml diff --git a/ironic/conf/pxe.py b/ironic/conf/pxe.py index e4c912192d..e84438a47e 100644 --- a/ironic/conf/pxe.py +++ b/ironic/conf/pxe.py @@ -209,6 +209,13 @@ '$pybasedir', 'drivers/modules/initial_grub_cfg.template'), help=_('On ironic-conductor node, the path to the initial grub' 'configuration template for grub network boot.')), + cfg.BoolOpt('enable_insecure_template_override', + default=False, + help=_('If node level pxe_template override is permitted to ' + 'be used in this Ironic deployment. This is an ' + 'insecure pattern filed under CVE-2026-44917 and ' + 'the feature this guards this is expected to be ' + 'removed in Ironic release 2027.2.')), ] diff --git a/ironic/drivers/modules/deploy_utils.py b/ironic/drivers/modules/deploy_utils.py index e0b2a61b9a..2f06f4a520 100644 --- a/ironic/drivers/modules/deploy_utils.py +++ b/ironic/drivers/modules/deploy_utils.py @@ -459,9 +459,22 @@ def get_ipxe_config_template(node): # loaders by architecture as they are all consistent. Where as PXE # could need to be grub for one arch, PXELINUX for another. configured_template = CONF.pxe.ipxe_config_template - override_template = node.driver_info.get('pxe_template') - if override_template: - configured_template = override_template + insecure_override_template = node.driver_info.get('pxe_template') + if CONF.pxe.enable_insecure_template_override: + # TODO(TheJulia): Remove the node level pxe_template setting in + # a future release as it is inhernetly insecure. 
+ if insecure_override_template: + configured_template = insecure_override_template + elif insecure_override_template: + raise exception.InvalidParameterValue(_( + 'The node\'s driver_info field pxe_template override value is ' + 'insecure (CVE-2026-44917) and should not be used. The ' + 'appropriate approach is to utilize [pxe]ipxe_template_by_arch ' + 'configuration in ironic.conf to match the baremetal node\'s ' + 'architecture. Please work with your Ironic operator to remedy ' + 'your usage and configuration. Default templates may be ' + 'leveraged by deleting the pxe_template value in the driver_info ' + 'field.')) return configured_template or get_pxe_config_template(node) @@ -476,7 +489,22 @@ def get_pxe_config_template(node): :param node: A single Node. :returns: The PXE config template file name. """ - config_template = node.driver_info.get("pxe_template", None) + config_template = None + insecure_override_template = node.driver_info.get("pxe_template", None) + if CONF.pxe.enable_insecure_template_override: + # TODO(TheJulia): Remove the node level pxe_template setting in + # a future release as it is inhernetly insecure. + config_template = insecure_override_template + elif insecure_override_template: + raise exception.InvalidParameterValue(_( + 'The node\'s driver_info field pxe_template override value is ' + 'insecure (CVE-2026-44917) and should not be used. The ' + 'appropriate approach is to utilize [pxe]pxe_template_by_arch ' + 'configuration in ironic.conf to match the baremetal node\'s ' + 'architecture. Please work with your Ironic operator to remedy ' + 'your usage and configuration. 
Default templates may be ' + 'leveraged by deleting the pxe_template value in the driver_info ' + 'field.')) if config_template is None: cpu_arch = node.properties.get('cpu_arch') config_template = CONF.pxe.pxe_config_template_by_arch.get(cpu_arch) diff --git a/ironic/tests/unit/drivers/modules/test_deploy_utils.py b/ironic/tests/unit/drivers/modules/test_deploy_utils.py index fb8d64395f..3b0a8d59ef 100644 --- a/ironic/tests/unit/drivers/modules/test_deploy_utils.py +++ b/ironic/tests/unit/drivers/modules/test_deploy_utils.py @@ -424,6 +424,8 @@ def test_get_pxe_config_template_emtpy_property_bios(self): self.assertEqual('bios-template', result) def test_get_pxe_config_template_per_node(self): + cfg.CONF.set_override('enable_insecure_template_override', True, + group='pxe') node = obj_utils.create_test_node( self.context, driver='fake-hardware', driver_info={"pxe_template": "fake-template"}, @@ -431,6 +433,16 @@ def test_get_pxe_config_template_per_node(self): result = utils.get_pxe_config_template(node) self.assertEqual('fake-template', result) + def test_get_pxe_config_template_per_node_disabled(self): + self.assertFalse(cfg.CONF.pxe.enable_insecure_template_override) + node = obj_utils.create_test_node( + self.context, driver='fake-hardware', + driver_info={"pxe_template": "fake-template"}, + ) + self.assertRaisesRegex( + exception.InvalidParameterValue, 'CVE-2026-44917', + utils.get_pxe_config_template, node) + def test_get_ipxe_config_template(self): node = obj_utils.create_test_node( self.context, driver='fake-hardware') @@ -457,12 +469,23 @@ def test_get_ipxe_config_template_none_bios(self): utils.get_ipxe_config_template(node)) def test_get_ipxe_config_template_override_pxe_fallback(self): + cfg.CONF.set_override('enable_insecure_template_override', True, + group='pxe') node = obj_utils.create_test_node( self.context, driver='fake-hardware', driver_info={'pxe_template': 'magical'}) self.assertEqual('magical', utils.get_ipxe_config_template(node)) + def 
test_get_ipxe_config_template_override_pxe_fallback_disabled(self): + self.assertFalse(cfg.CONF.pxe.enable_insecure_template_override) + node = obj_utils.create_test_node( + self.context, driver='fake-hardware', + driver_info={'pxe_template': 'magical'}) + self.assertRaisesRegex( + exception.InvalidParameterValue, 'CVE-2026-44917', + utils.get_ipxe_config_template, node) + @mock.patch('time.sleep', lambda sec: None) class OtherFunctionTestCase(db_base.DbTestCase): diff --git a/releasenotes/notes/security-bug-2148319-49974afdcd38d9c0.yaml b/releasenotes/notes/security-bug-2148319-49974afdcd38d9c0.yaml new file mode 100644 index 0000000000..67edfc8f46 --- /dev/null +++ b/releasenotes/notes/security-bug-2148319-49974afdcd38d9c0.yaml @@ -0,0 +1,28 @@ +--- +security: + - | + A vulnerability was discovered in a minimally documented feature of + Ironic where an absolute path to a ``pxe_template`` override value could + be defined by an authenticated and privileged API user. The Ironic team has + chosen to immediately deprecate and remove this functionality. To provide + an immediate security fix, this functionality is now disabled by default. + The functionality can be re-enabled via the + ``[pxe]enable_insecure_template_override`` configuration option which + was added to ironic.conf with a default value of ``False``. + This issue is tracked as + `bug 2148319 <https://bugs.launchpad.net/ironic/+bug/2148319>`_. +fixes: + - | + Fixes a vulnerability (CVE-2026-44917) which was identified in handling + of pxe_template overrides where an authenticated and authorized user + could request an override template via direct file path which would + bypass file URL handling guards introduced in OSSA-2025-001. This + feature was minimally documented through only a release note, and + does not appear to have actual use. This functionality is being + disabled by default, and will be promptly removed from Ironic's + current development branch. 
+deprecations: + - | + The node ``driver_info`` field value ``pxe_template`` has been + deprecated and is expected to be removed in the future Ironic + 2027.2 release.