diff --git a/aerospike_sdk/ael/_rust_fastpath.py b/aerospike_sdk/ael/_rust_fastpath.py new file mode 100644 index 0000000..7840c77 --- /dev/null +++ b/aerospike_sdk/ael/_rust_fastpath.py @@ -0,0 +1,624 @@ +"""Optional Rust-backed fast path for broad AEL surface parsing. + +This module accelerates parser-heavy AEL by delegating tokenization and +surface parsing to a Rust extension when available, then lowering the +normalized tree back into the same ``FilterExpression`` objects the Python +visitor would produce. + +Unsupported expressions fall back to the existing ANTLR path unchanged. +""" + +from __future__ import annotations + +import ast +import re +from dataclasses import dataclass +from typing import Any, Iterable, List, Optional + +from antlr4 import CommonTokenStream, InputStream +from antlr4.error.ErrorListener import ErrorListener +from aerospike_async import FilterExpression, ListReturnType + +from aerospike_sdk.ael.antlr4.generated.ConditionLexer import ConditionLexer +from aerospike_sdk.ael.antlr4.generated.ConditionParser import ConditionParser +from aerospike_sdk.ael.exceptions import AelParseException +from aerospike_sdk.ael.exp_visitor import ( + ArithOp, + CDTPath, + DeferredArithmetic, + DeferredBin, + ExpressionConditionVisitor, + InferredType, + TypedExpr, + _build_arithmetic, + _contains_deferred, + _get_type_hint, + _infer_element_type, + _is_float_context, + _resolve_for_arithmetic, + _resolve_for_comparison, + _resolve_for_in_list, + _resolve_for_in_value, + _validate_arg_count, + _validate_min_arg_count, + _validate_in_type_compatibility, + _unquote, +) + +try: + from ael_rust_ext import parse_subset_rust as _parse_subset_rust +except Exception: + _parse_subset_rust = None + + +class RustFastPathUnsupported(Exception): + """Raised internally when the Rust fast path cannot handle an expression.""" + + +class _RustFastPathErrorListener(ErrorListener): + """Error listener that raises on first syntax error (no recovery).""" + + def 
syntaxError(self, recognizer, offending_symbol, line, column, msg, e): + raise AelParseException(f"line {line}:{column} {msg}") + + +@dataclass(slots=True) +class _Call: + name: str + args: list[str] + + +@dataclass(slots=True) +class _LoweringState: + placeholder_values: Any + var_types: list[dict[str, InferredType]] + + def push_var_scope(self) -> None: + self.var_types.append({}) + + def pop_var_scope(self) -> None: + self.var_types.pop() + + def set_var_type(self, name: str, var_type: InferredType) -> None: + self.var_types[-1][name] = var_type + + def get_var_type(self, name: str) -> InferredType: + for scope in reversed(self.var_types): + if name in scope: + return scope[name] + return InferredType.UNKNOWN + + +_LIST_OR_MAP_BOOL_RE = re.compile(r"(? bool: + return _parse_subset_rust is not None + + +def try_parse_ael_rust( + expression: str, + placeholder_values: Any = None, +) -> Optional[FilterExpression]: + """Return a Rust-fast-path parse result, or ``None`` if unavailable/unsupported.""" + if _parse_subset_rust is None: + return None + try: + normalized = _parse_subset_rust(expression) + return _lower_normalized(normalized, placeholder_values) + except RustFastPathUnsupported: + return None + except AelParseException: + raise + except Exception: + return None + + +def _lower_normalized(normalized: str, placeholder_values: Any) -> FilterExpression: + state = _LoweringState(placeholder_values=placeholder_values, var_types=[{}]) + expr = _lower_expr(normalized, state) + if isinstance(expr, TypedExpr): + return expr.expr + if isinstance(expr, CDTPath): + return expr.to_expression() + if isinstance(expr, (DeferredBin, DeferredArithmetic)): + return _resolve_for_arithmetic(expr) + if isinstance(expr, FilterExpression): + return expr + raise RustFastPathUnsupported(f"unsupported normalized root: {normalized}") + + +def _lower_expr(text: str, state: _LoweringState) -> Any: + call = _parse_call(text) + name = call.name + args = call.args + + if name == "pyhex": + 
raw = _decode_pyhex(_expect_arity(name, args, 1)[0]) + return _parse_expr_via_python_visitor(raw, state.placeholder_values) + + if name == "int": + value = int(_expect_arity(name, args, 1)[0], 0) + return TypedExpr(FilterExpression.int_val(value), InferredType.INT, value=value) + if name == "float": + value = float(_expect_arity(name, args, 1)[0]) + return TypedExpr(FilterExpression.float_val(value), InferredType.FLOAT, value=value) + if name == "castint": + value = _cast_numeric_literal( + _expect_arity(name, args, 1)[0], + state, + InferredType.INT, + ) + return TypedExpr(FilterExpression.int_val(value), InferredType.INT, value=value) + if name == "castfloat": + value = _cast_numeric_literal( + _expect_arity(name, args, 1)[0], + state, + InferredType.FLOAT, + ) + return TypedExpr(FilterExpression.float_val(value), InferredType.FLOAT, value=value) + if name == "bool": + raw = _expect_arity(name, args, 1)[0].lower() + if raw not in {"true", "false"}: + raise RustFastPathUnsupported(raw) + value = raw == "true" + return TypedExpr(FilterExpression.bool_val(value), InferredType.BOOL, value=value) + if name == "str": + value = _unquote(_expect_arity(name, args, 1)[0]) + return TypedExpr(FilterExpression.string_val(value), InferredType.STRING, value=value) + if name == "ph": + raw = _expect_arity(name, args, 1)[0] + if state.placeholder_values is None: + raise AelParseException("Placeholder used but no placeholder values provided") + if not raw.startswith("?"): + raise RustFastPathUnsupported(raw) + return _placeholder_to_expr(state.placeholder_values.get(int(raw[1:]))) + if name == "path": + return _lower_path(_expect_arity(name, args, 1)[0], state) + if name == "list": + values = _parse_python_like_literal(_expect_arity(name, args, 1)[0]) + if not isinstance(values, list): + raise RustFastPathUnsupported("list literal") + return TypedExpr(FilterExpression.list_val(values), InferredType.LIST, values) + if name == "map": + value = 
_parse_python_like_literal(_expect_arity(name, args, 1)[0]) + if not isinstance(value, dict): + raise RustFastPathUnsupported("map literal") + return TypedExpr(FilterExpression.map_val(value), InferredType.MAP) + if name == "var": + var_name = _expect_arity(name, args, 1)[0] + expr = FilterExpression.var(var_name) + var_type = state.get_var_type(var_name) + if var_type != InferredType.UNKNOWN: + return TypedExpr(expr, var_type) + return expr + if name == "let": + if len(args) < 2: + raise RustFastPathUnsupported("let arity") + state.push_var_scope() + try: + definitions: list[FilterExpression] = [] + for definition_text in args[:-1]: + definition = _parse_call(definition_text) + if definition.name != "def" or len(definition.args) != 2: + raise RustFastPathUnsupported("let definition") + var_name = definition.args[0] + value_expr = _lower_expr(definition.args[1], state) + state.set_var_type(var_name, _get_type_hint(value_expr)) + definitions.append(FilterExpression.def_(var_name, _resolve_for_arithmetic(value_expr))) + action_expr = _resolve_for_arithmetic(_lower_expr(args[-1], state)) + finally: + state.pop_var_scope() + return FilterExpression.exp_let(definitions + [action_expr]) + if name == "when": + if len(args) < 3 or len(args) % 2 == 0: + raise RustFastPathUnsupported("when arity") + cond_exprs: list[FilterExpression] = [] + for index in range(0, len(args) - 1, 2): + condition = _resolve_for_arithmetic(_lower_expr(args[index], state)) + action = _resolve_for_arithmetic(_lower_expr(args[index + 1], state)) + cond_exprs.append(condition) + cond_exprs.append(action) + cond_exprs.append(_resolve_for_arithmetic(_lower_expr(args[-1], state))) + return FilterExpression.cond(cond_exprs) + if name == "def": + raise RustFastPathUnsupported("definition outside let") + + if name == "abs": + values = [_lower_expr(arg, state) for arg in _expect_arity(name, args, 1)] + if _contains_deferred(values[0]): + return DeferredArithmetic(ArithOp.ABS, values) + result_type = 
_get_type_hint(values[0]) + if result_type not in (InferredType.INT, InferredType.FLOAT): + result_type = InferredType.INT + return TypedExpr( + FilterExpression.num_abs(_resolve_for_arithmetic(values[0])), + result_type, + ) + if name == "ceil": + values = [_lower_expr(arg, state) for arg in _expect_arity(name, args, 1)] + return TypedExpr( + FilterExpression.num_ceil(_resolve_for_arithmetic(values[0], has_float=True)), + InferredType.FLOAT, + ) + if name == "floor": + values = [_lower_expr(arg, state) for arg in _expect_arity(name, args, 1)] + return TypedExpr( + FilterExpression.num_floor(_resolve_for_arithmetic(values[0], has_float=True)), + InferredType.FLOAT, + ) + if name == "log": + values = [_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)] + return TypedExpr( + FilterExpression.num_log( + _resolve_for_arithmetic(values[0], has_float=True), + _resolve_for_arithmetic(values[1], has_float=True), + ), + InferredType.FLOAT, + ) + if name == "pow": + values = [_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)] + return TypedExpr( + FilterExpression.num_pow( + _resolve_for_arithmetic(values[0], has_float=True), + _resolve_for_arithmetic(values[1], has_float=True), + ), + InferredType.FLOAT, + ) + if name == "max": + values = [_lower_expr(arg, state) for arg in args] + _validate_min_arg_count(name, values, 2) + if any(_contains_deferred(v) for v in values): + return DeferredArithmetic(ArithOp.MAX, values) + has_float = any(_is_float_context(v) for v in values) + result_type = InferredType.FLOAT if has_float else InferredType.INT + return TypedExpr( + FilterExpression.max([ + _resolve_for_arithmetic(v, has_float=has_float) for v in values + ]), + result_type, + ) + if name == "min": + values = [_lower_expr(arg, state) for arg in args] + _validate_min_arg_count(name, values, 2) + if any(_contains_deferred(v) for v in values): + return DeferredArithmetic(ArithOp.MIN, values) + has_float = any(_is_float_context(v) for v in values) + 
result_type = InferredType.FLOAT if has_float else InferredType.INT + return TypedExpr( + FilterExpression.min([ + _resolve_for_arithmetic(v, has_float=has_float) for v in values + ]), + result_type, + ) + if name == "countOneBits": + values = [_lower_expr(arg, state) for arg in _expect_arity(name, args, 1)] + return TypedExpr( + FilterExpression.int_count(_resolve_for_arithmetic(values[0], has_float=False)), + InferredType.INT, + ) + if name == "findBitLeft": + values = [_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)] + return TypedExpr( + FilterExpression.int_lscan( + _resolve_for_arithmetic(values[0], has_float=False), + _resolve_for_arithmetic(values[1], has_float=False), + ), + InferredType.INT, + ) + if name == "findBitRight": + values = [_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)] + return TypedExpr( + FilterExpression.int_rscan( + _resolve_for_arithmetic(values[0], has_float=False), + _resolve_for_arithmetic(values[1], has_float=False), + ), + InferredType.INT, + ) + + if name in {"add", "sub", "mul", "div", "mod"}: + left, right = (_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)) + op = { + "add": ArithOp.ADD, + "sub": ArithOp.SUB, + "mul": ArithOp.MUL, + "div": ArithOp.DIV, + "mod": ArithOp.MOD, + }[name] + return _build_arithmetic(op, left, right) + + if name == "bitand": + left, right = (_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)) + return FilterExpression.int_and([ + _resolve_for_arithmetic(left, has_float=False), + _resolve_for_arithmetic(right, has_float=False), + ]) + if name == "bitor": + left, right = (_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)) + return FilterExpression.int_or([ + _resolve_for_arithmetic(left, has_float=False), + _resolve_for_arithmetic(right, has_float=False), + ]) + if name == "bitxor": + left, right = (_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)) + return FilterExpression.int_xor([ + _resolve_for_arithmetic(left, 
has_float=False), + _resolve_for_arithmetic(right, has_float=False), + ]) + if name == "bitnot": + value = _lower_expr(_expect_arity(name, args, 1)[0], state) + return FilterExpression.int_not(_resolve_for_arithmetic(value, has_float=False)) + if name == "lshift": + left, right = (_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)) + return FilterExpression.int_lshift( + _resolve_for_arithmetic(left, has_float=False), + _resolve_for_arithmetic(right, has_float=False), + ) + if name == "arshift": + left, right = (_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)) + return FilterExpression.int_arshift( + _resolve_for_arithmetic(left, has_float=False), + _resolve_for_arithmetic(right, has_float=False), + ) + if name == "rshift": + left, right = (_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)) + return FilterExpression.int_rshift( + _resolve_for_arithmetic(left, has_float=False), + _resolve_for_arithmetic(right, has_float=False), + ) + + if name == "neg": + inner = _lower_expr(_expect_arity(name, args, 1)[0], state) + if isinstance(inner, TypedExpr) and inner.type_hint == InferredType.INT: + value = -int(inner.value if inner.value is not None else 0) + return TypedExpr(FilterExpression.int_val(value), InferredType.INT, value=value) + if isinstance(inner, TypedExpr) and inner.type_hint == InferredType.FLOAT: + value = -float(inner.value if inner.value is not None else 0.0) + return TypedExpr(FilterExpression.float_val(value), InferredType.FLOAT, value=value) + minus_one = TypedExpr(FilterExpression.int_val(-1), InferredType.INT, value=-1) + return DeferredArithmetic(ArithOp.MUL, [inner, minus_one]) + + if name == "pos": + return _lower_expr(_expect_arity(name, args, 1)[0], state) + + if name in {"eq", "ne", "gt", "ge", "lt", "le"}: + left, right = (_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)) + resolved_left, resolved_right = _resolve_for_comparison(left, right) + return { + "eq": FilterExpression.eq, + 
"ne": FilterExpression.ne, + "gt": FilterExpression.gt, + "ge": FilterExpression.ge, + "lt": FilterExpression.lt, + "le": FilterExpression.le, + }[name](resolved_left, resolved_right) + + if name == "in": + left, right = (_lower_expr(arg, state) for arg in _expect_arity(name, args, 2)) + resolved_list = _resolve_for_in_list(right) + left_hint = _get_type_hint(left) + if left_hint != InferredType.UNKNOWN: + if isinstance(right, TypedExpr) and right.type_hint == InferredType.LIST: + if isinstance(right.value, list) and right.value: + element_type = _infer_element_type(right.value) + if element_type != InferredType.UNKNOWN: + _validate_in_type_compatibility(left_hint, element_type) + resolved_value = _resolve_for_in_value(left, right) + return FilterExpression.list_get_by_value( + ListReturnType.EXISTS, + resolved_value, + resolved_list, + [], + ) + + if name == "and": + exprs = [_coerce_bool_expr(_lower_expr(arg, state)) for arg in args] + return FilterExpression.and_(exprs) + if name == "or": + exprs = [_coerce_bool_expr(_lower_expr(arg, state)) for arg in args] + return FilterExpression.or_(exprs) + if name == "not": + expr = _coerce_bool_expr(_lower_expr(_expect_arity(name, args, 1)[0], state)) + return FilterExpression.not_(expr) + if name == "exclusive": + parts = [_coerce_bool_expr(_lower_expr(arg, state)) for arg in args] + if len(parts) < 2: + raise RustFastPathUnsupported("exclusive arity") + result = parts[0] + for expr in parts[1:]: + result = FilterExpression.xor([result, expr]) + return result + + raise RustFastPathUnsupported(name) + + +def _coerce_bool_expr(expr: Any) -> FilterExpression: + if isinstance(expr, DeferredBin): + return expr.to_expression(InferredType.BOOL) + if isinstance(expr, DeferredArithmetic): + return expr.to_expression(InferredType.BOOL) + if isinstance(expr, TypedExpr): + return expr.expr + if isinstance(expr, FilterExpression): + return expr + raise RustFastPathUnsupported(f"bool coercion for {type(expr).__name__}") + + +def 
_lower_path(raw: str, state: _LoweringState) -> Any: + if raw.startswith("${"): + return _parse_expr_via_python_visitor(raw, state.placeholder_values) + if not raw.startswith("$."): + raise RustFastPathUnsupported(raw) + body = raw[2:] + if body == "deviceSize()" or body == "memorySize()" or body == "recordSize()": + return TypedExpr(FilterExpression.device_size(), InferredType.INT) + if body == "isTombstone()": + return TypedExpr(FilterExpression.is_tombstone(), InferredType.BOOL) + if body == "keyExists()": + return TypedExpr(FilterExpression.key_exists(), InferredType.BOOL) + if body == "lastUpdate()": + return TypedExpr(FilterExpression.last_update(), InferredType.INT) + if body == "sinceUpdate()": + return TypedExpr(FilterExpression.since_update(), InferredType.INT) + if body == "setName()": + return TypedExpr(FilterExpression.set_name(), InferredType.STRING) + if body == "ttl()": + return TypedExpr(FilterExpression.ttl(), InferredType.INT) + if body == "voidTime()": + return TypedExpr(FilterExpression.void_time(), InferredType.INT) + if body.startswith("digestModulo(") and body.endswith(")"): + value = int(body[len("digestModulo("):-1], 0) + return TypedExpr(FilterExpression.digest_modulo(value), InferredType.INT) + + if not body or "." 
in body or "[" in body or "]" in body or "(" in body or ")" in body: + return _parse_expr_via_python_visitor(raw, state.placeholder_values) + return DeferredBin(body) + + +def _placeholder_to_expr(value: Any) -> TypedExpr: + if isinstance(value, bool): + return TypedExpr(FilterExpression.bool_val(value), InferredType.BOOL) + if isinstance(value, int): + return TypedExpr(FilterExpression.int_val(value), InferredType.INT) + if isinstance(value, float): + return TypedExpr(FilterExpression.float_val(value), InferredType.FLOAT) + if isinstance(value, str): + return TypedExpr(FilterExpression.string_val(value), InferredType.STRING, value=value) + if isinstance(value, bytes): + return TypedExpr(FilterExpression.blob_val(list(value)), InferredType.UNKNOWN, value=value) + if isinstance(value, (list, tuple)): + vals = list(value) + return TypedExpr(FilterExpression.list_val(vals), InferredType.LIST, value=vals) + if isinstance(value, dict): + return TypedExpr(FilterExpression.map_val(value), InferredType.UNKNOWN, value=value) + raise AelParseException(f"Unsupported placeholder value type: {type(value).__name__}") + + +def _cast_numeric_literal(raw_expr: str, state: _LoweringState, target_type: InferredType) -> int | float: + parsed = _parse_numeric_literal_call(raw_expr) + if parsed is not None: + return int(parsed) if target_type == InferredType.INT else float(parsed) + + expr = _lower_expr(raw_expr, state) + if not isinstance(expr, TypedExpr): + raise RustFastPathUnsupported("operand cast") + if expr.type_hint not in {InferredType.INT, InferredType.FLOAT}: + raise RustFastPathUnsupported("operand cast") + if expr.value is None: + raise RustFastPathUnsupported("operand cast") + if target_type == InferredType.INT: + return int(expr.value) + if target_type == InferredType.FLOAT: + return float(expr.value) + raise RustFastPathUnsupported("operand cast") + + +def _parse_numeric_literal_call(text: str) -> int | float | None: + if text.startswith("int(") and text.endswith(")"): + 
return int(text[4:-1], 0) + if text.startswith("float(") and text.endswith(")"): + return float(text[6:-1]) + return None + + +def _expect_arity(name: str, args: list[str], expected: int) -> list[str]: + if len(args) != expected: + raise RustFastPathUnsupported(f"{name} expects {expected} args") + return args + + +def _parse_call(text: str) -> _Call: + open_idx = text.find("(") + if open_idx <= 0 or not text.endswith(")"): + raise RustFastPathUnsupported(text) + name = text[:open_idx] + inner = text[open_idx + 1:-1] + args = _split_top_level(inner) + return _Call(name=name, args=args) + + +def _split_top_level(text: str) -> list[str]: + if not text: + return [] + parts: list[str] = [] + depth_paren = 0 + depth_bracket = 0 + depth_brace = 0 + quote: Optional[str] = None + start = 0 + for i, ch in enumerate(text): + if quote is not None: + if ch == quote: + quote = None + continue + if ch in {"'", '"'}: + quote = ch + continue + if ch == "(": + depth_paren += 1 + continue + if ch == ")": + depth_paren -= 1 + continue + if ch == "[": + depth_bracket += 1 + continue + if ch == "]": + depth_bracket -= 1 + continue + if ch == "{": + depth_brace += 1 + continue + if ch == "}": + depth_brace -= 1 + continue + if ch == "," and depth_paren == 0 and depth_bracket == 0 and depth_brace == 0: + parts.append(text[start:i]) + start = i + 1 + parts.append(text[start:]) + return parts + + +def _parse_python_like_literal(text: str) -> Any: + try: + pythonish = _LIST_OR_MAP_BOOL_RE.sub( + lambda m: {"true": "True", "false": "False"}[m.group(0)], + text, + ) + return ast.literal_eval(pythonish) + except Exception as e: + raise RustFastPathUnsupported(f"literal parse failed: {text}") from e + + +def _decode_pyhex(text: str) -> str: + try: + return bytes.fromhex(text).decode("utf-8") + except Exception as e: + raise RustFastPathUnsupported(f"invalid opaque payload: {text}") from e + + +def _parse_expr_via_python_visitor(expression: str, placeholder_values: Any) -> Any: + try: + 
input_stream = InputStream(expression) + lexer = ConditionLexer(input_stream) + lexer.removeErrorListeners() + error_listener = _RustFastPathErrorListener() + lexer.addErrorListener(error_listener) + + token_stream = CommonTokenStream(lexer) + parser = ConditionParser(token_stream) + parser.removeErrorListeners() + parser.addErrorListener(error_listener) + + parse_tree = parser.parse() + visitor = ExpressionConditionVisitor(placeholder_values=placeholder_values) + result = visitor.visit(parse_tree.expression()) + if result is None: + raise RustFastPathUnsupported(expression) + return result + except AelParseException: + raise + except RustFastPathUnsupported: + raise + except Exception as e: + raise RustFastPathUnsupported(expression) from e diff --git a/aerospike_sdk/ael/parser.py b/aerospike_sdk/ael/parser.py index 36d5b53..a8d6cf4 100644 --- a/aerospike_sdk/ael/parser.py +++ b/aerospike_sdk/ael/parser.py @@ -56,6 +56,7 @@ MapValuePart, ) from aerospike_sdk.ael.filter_gen import FilterGenerator, IndexContext, ParseResult +from aerospike_sdk.ael._rust_fastpath import try_parse_ael_rust class _AELParseErrorListener(ErrorListener): @@ -163,6 +164,9 @@ def parse_ael(expression: str, *args: Any) -> FilterExpression: if _parser is None: _parser = AELParser() placeholder_values = PlaceholderValues(*args) if args else None + rust_result = try_parse_ael_rust(expression, placeholder_values) + if rust_result is not None: + return rust_result return _parser.parse(expression, placeholder_values) diff --git a/docs/benchmarks/ael-rust-fastpath-results.md b/docs/benchmarks/ael-rust-fastpath-results.md new file mode 100644 index 0000000..bc25c3f --- /dev/null +++ b/docs/benchmarks/ael-rust-fastpath-results.md @@ -0,0 +1,68 @@ +# AEL Rust Fast Path Results + +This benchmark note captures the current behavior and performance of the Rust-backed AEL fast path after native lowering for `let(...)`, `when(...)`, and variable references. 
+ +## Verification + +- `cargo test --manifest-path rust/ael_rust_ext/Cargo.toml --target-dir /tmp/ael_rust_ext-target`: `5 passed` +- `pytest tests/unit -q`: `1581 passed` +- `pytest tests/integration -q`: `979 passed, 37 skipped, 2 xfailed` + +## Parser Benchmark + +![AEL parser throughput](ael-rust-parser-throughput.svg) + +![AEL parser latency](ael-rust-parser-latency.svg) + +### Repo expression corpus + +- Raw corpus size: `868` +- Supported by Rust fast path: `750` +- Raw skips: `118` +- Skip classification: `28` placeholder-only cases with no values supplied, `90` expressions that Python itself rejects +- Real valid-expression fallbacks: `0` +- Mismatches vs Python: `0` +- Python parser throughput: `4,798.66 expr/s` +- Rust fast path throughput: `12,988.76 expr/s` +- Python parser time: `208.39 us/expr` +- Rust fast path time: `76.99 us/expr` +- Speedup: `2.71x` + +### Literal `parse_ael(...)` callsites + +- Corpus size: `596` +- Supported by Rust fast path: `596/596` +- Mismatches vs Python: `0` +- Python parser throughput: `4,496.22 expr/s` +- Rust fast path throughput: `9,842.60 expr/s` +- Python parser time: `222.41 us/expr` +- Rust fast path time: `101.60 us/expr` +- Speedup: `2.19x` + +## Live Workload Benchmark + +Environment: + +- Aerospike server: Docker `aerospike/aerospike-server:8.1.0.2` +- Address: `127.0.0.1:3000` +- Workload: `RU,50` +- Command: `python -m benchmarks.benchmark -k 100 -z 4 -w RU,50 -d 3 --warmup 0 --cooldown 0` + +Results: + +- Total TPS: `8389` +- Read TPS: `4161` +- Write TPS: `4228` +- Latency: `p50=0.4ms p90=0.9ms p99=1.4ms p99.9=1.6ms max=1.8ms` +- Peak RSS: `48.8 MB` + +Comparison run: + +- PAC async: `9016 TPS`, `p99=1.4ms`, `RSS=42.3MB` +- PSDK async: `8486 TPS`, `p99=1.4ms`, `RSS=49.0MB` +- PSDK sim-sync: `3007 TPS`, `p99=0.7ms`, `RSS=49.2MB` + +Interpretation: + +- The parser work is materially faster. 
+- The live `RU,50` benchmark does not show a major end-to-end TPS change because it is primarily request/response bound rather than parser bound. diff --git a/docs/benchmarks/ael-rust-parser-latency.svg b/docs/benchmarks/ael-rust-parser-latency.svg new file mode 100644 index 0000000..02b6fd9 --- /dev/null +++ b/docs/benchmarks/ael-rust-parser-latency.svg @@ -0,0 +1,27 @@ + + AEL Rust Parser Latency + Microseconds per expression for Python and Rust AEL parsing on the repo corpus and literal parse_ael callsites. + + + AEL Parser Latency + Lower is better. Measured as microseconds spent per expression. + + + Python parser + + Rust fast path + + Repo expression corpus + Average time per successfully handled expression + + Python: 208.4 us + + Rust: 77.0 us + + Literal parse_ael(...) callsites + Average time per supported callsite expression + + Python: 222.4 us + + Rust: 101.6 us + diff --git a/docs/benchmarks/ael-rust-parser-throughput.svg b/docs/benchmarks/ael-rust-parser-throughput.svg new file mode 100644 index 0000000..4f7c2e1 --- /dev/null +++ b/docs/benchmarks/ael-rust-parser-throughput.svg @@ -0,0 +1,29 @@ + + AEL Rust Parser Throughput + Expressions per second for Python and Rust AEL parsing on the repo corpus and literal parse_ael callsites. + + + AEL Parser Throughput + Higher is better. Measured as expressions handled per second. + + + Python parser + + Rust fast path + + Repo expression corpus + 750 valid expressions, 0 mismatches, 0 real fallbacks + + Python: 4,799 expr/s + + Rust: 12,989 expr/s + 2.71x faster + + Literal parse_ael(...) 
callsites + 596/596 supported, 0 mismatches + + Python: 4,496 expr/s + + Rust: 9,843 expr/s + 2.19x faster + diff --git a/rust/ael_rust_ext/Cargo.toml b/rust/ael_rust_ext/Cargo.toml new file mode 100644 index 0000000..28fc146 --- /dev/null +++ b/rust/ael_rust_ext/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "ael_rust_ext" +version = "0.1.0" +edition = "2021" + +[lib] +name = "ael_rust_ext" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.21.2", features = ["extension-module"] } diff --git a/rust/ael_rust_ext/src/lib.rs b/rust/ael_rust_ext/src/lib.rs new file mode 100644 index 0000000..7acf5e1 --- /dev/null +++ b/rust/ael_rust_ext/src/lib.rs @@ -0,0 +1,1013 @@ +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq)] +enum TokenKind { + Eof, + LParen, + RParen, + Comma, + Assign, + Arrow, + CastInt, + CastFloat, + Amp, + Pipe, + Caret, + Plus, + Minus, + Star, + Slash, + Percent, + Tilde, + Gt, + Ge, + Lt, + Le, + Eq, + Ne, + Pow, + LShift, + ArShift, + RShift, + Int, + Float, + Bool, + String, + Placeholder, + Variable, + Path, + List, + Map, + And, + Or, + Not, + Exclusive, + In, + Let, + When, + Then, + Default, + Identifier, +} + +#[derive(Clone, Debug)] +struct Token { + kind: TokenKind, + text: String, +} + +fn compact_source(text: &str) -> String { + let bytes = text.as_bytes(); + let mut out = String::with_capacity(text.len()); + let mut i = 0usize; + let mut quote: Option = None; + while i < bytes.len() { + let ch = bytes[i]; + if let Some(q) = quote { + out.push(ch as char); + if ch == q { + quote = None; + } + i += 1; + continue; + } + if ch == b'\'' || ch == b'"' { + quote = Some(ch); + out.push(ch as char); + i += 1; + continue; + } + if !(ch as char).is_whitespace() { + out.push(ch as char); + } + i += 1; + } + out +} + +struct Tokenizer<'a> { + text: &'a str, + bytes: &'a [u8], + pos: usize, +} + +impl<'a> Tokenizer<'a> { + fn new(text: &'a str) -> Self { + Self { + text, + bytes: 
text.as_bytes(),
            pos: 0,
        }
    }

    /// Byte at `pos + offset`, or `None` past the end of input.
    fn peek(&self, offset: usize) -> Option<u8> {
        self.bytes.get(self.pos + offset).copied()
    }

    fn advance(&mut self, count: usize) {
        self.pos += count;
    }

    /// Owned copy of the source text in `[start, end)`.
    fn slice(&self, start: usize, end: usize) -> String {
        self.text[start..end].to_string()
    }

    fn skip_ws(&mut self) {
        while matches!(self.peek(0), Some(ch) if (ch as char).is_whitespace()) {
            self.advance(1);
        }
    }

    /// Scan a quoted string starting at the opening quote. The returned text
    /// keeps both quote characters. No escape sequences are recognized.
    fn scan_quoted(&mut self, quote: u8) -> Result<String, String> {
        let start = self.pos;
        self.advance(1);
        while self.pos < self.bytes.len() {
            if self.peek(0) == Some(quote) {
                self.advance(1);
                return Ok(self.slice(start, self.pos));
            }
            self.advance(1);
        }
        Err("unterminated string".to_string())
    }

    /// Scan a bracketed region starting at `open_ch`, honoring nesting and
    /// skipping quoted strings; the delimiters are kept in the result.
    fn scan_balanced(&mut self, open_ch: u8, close_ch: u8) -> Result<String, String> {
        let start = self.pos;
        let mut depth = 0i32;
        while self.pos < self.bytes.len() {
            let ch = self.peek(0).unwrap();
            if ch == b'\'' || ch == b'"' {
                self.scan_quoted(ch)?;
                continue;
            }
            if ch == open_ch {
                depth += 1;
            } else if ch == close_ch {
                depth -= 1;
                if depth == 0 {
                    self.advance(1);
                    return Ok(self.slice(start, self.pos));
                }
            }
            self.advance(1);
        }
        Err("unterminated balanced token".to_string())
    }

    /// Scan an integer or float literal: decimal, hex `0x…`, binary `0b…`,
    /// and the `.37` / `1.23` float forms.
    fn scan_number(&mut self) -> Result<Token, String> {
        fn is_hex(ch: u8) -> bool {
            ch.is_ascii_digit() || matches!(ch, b'a'..=b'f' | b'A'..=b'F')
        }
        let start = self.pos;
        if self.peek(0) == Some(b'0') && matches!(self.peek(1), Some(b'x' | b'X')) {
            self.advance(2);
            let first = self
                .peek(0)
                .ok_or_else(|| "invalid hex literal".to_string())?;
            if !is_hex(first) {
                return Err("invalid hex literal".to_string());
            }
            while matches!(self.peek(0), Some(ch) if is_hex(ch)) {
                self.advance(1);
            }
            return Ok(Token {
                kind: TokenKind::Int,
                text: self.slice(start, self.pos),
            });
        }
        if self.peek(0) == Some(b'0') && matches!(self.peek(1), Some(b'b' | b'B')) {
            self.advance(2);
            let first = self
                .peek(0)
                .ok_or_else(|| "invalid binary literal".to_string())?;
            if first != b'0' && first != b'1' {
                return Err("invalid binary literal".to_string());
            }
            while matches!(self.peek(0), Some(b'0' | b'1')) {
                self.advance(1);
            }
            return Ok(Token {
                kind: TokenKind::Int,
                text: self.slice(start, self.pos),
            });
        }
        if self.peek(0) == Some(b'.') {
            // Leading-dot float such as `.37`.
            self.advance(1);
            while matches!(self.peek(0), Some(ch) if ch.is_ascii_digit()) {
                self.advance(1);
            }
            return Ok(Token {
                kind: TokenKind::Float,
                text: self.slice(start, self.pos),
            });
        }
        while matches!(self.peek(0), Some(ch) if ch.is_ascii_digit()) {
            self.advance(1);
        }
        // A trailing dot only makes this a float when a digit follows, so
        // `28.asFloat()` still lexes as the int `28` followed by a cast token.
        if self.peek(0) == Some(b'.') && matches!(self.peek(1), Some(ch) if ch.is_ascii_digit()) {
            self.advance(1);
            while matches!(self.peek(0), Some(ch) if ch.is_ascii_digit()) {
                self.advance(1);
            }
            return Ok(Token {
                kind: TokenKind::Float,
                text: self.slice(start, self.pos),
            });
        }
        Ok(Token {
            kind: TokenKind::Int,
            text: self.slice(start, self.pos),
        })
    }

    /// Scan an identifier or keyword; keywords match case-insensitively but
    /// the token keeps the original spelling.
    fn scan_identifier(&mut self) -> Result<Token, String> {
        let start = self.pos;
        while matches!(self.peek(0), Some(ch) if (ch as char).is_ascii_alphanumeric() || ch == b'_')
        {
            self.advance(1);
        }
        let text = self.slice(start, self.pos);
        let kind = match text.to_ascii_lowercase().as_str() {
            "in" => TokenKind::In,
            "and" => TokenKind::And,
            "or" => TokenKind::Or,
            "not" => TokenKind::Not,
            "exclusive" => TokenKind::Exclusive,
            "true" | "false" => TokenKind::Bool,
            "when" => TokenKind::When,
            "let" => TokenKind::Let,
            "then" => TokenKind::Then,
            "default" => TokenKind::Default,
            _ => TokenKind::Identifier,
        };
        Ok(Token { kind, text })
    }

    /// Scan a `${name}` variable reference; the token text is the bare name.
    fn scan_variable(&mut self) -> Result<Token, String> {
        self.advance(2);
        if !matches!(self.peek(0), Some(ch) if (ch as char).is_ascii_alphabetic() || ch == b'_') {
            return Err("invalid variable".to_string());
        }
        let start = self.pos;
        while matches!(self.peek(0), Some(ch) if (ch as char).is_ascii_alphanumeric() || ch == b'_')
        {
            self.advance(1);
        }
        let text = self.slice(start, self.pos);
        if self.peek(0) != Some(b'}') {
            return Err("invalid variable".to_string());
        }
        self.advance(1);
        Ok(Token {
            kind: TokenKind::Variable,
            text,
        })
    }

    /// Scan a `?N` placeholder; the token text keeps the `?` prefix.
    fn scan_placeholder(&mut self) -> Result<Token, String> {
        let start = self.pos;
        self.advance(1);
        if !matches!(self.peek(0), Some(ch) if ch.is_ascii_digit()) {
            return Err("invalid placeholder".to_string());
        }
        while matches!(self.peek(0), Some(ch) if ch.is_ascii_digit()) {
            self.advance(1);
        }
        Ok(Token {
            kind: TokenKind::Placeholder,
            text: self.slice(start, self.pos),
        })
    }

    /// Scan a `[...]` list or `{...}` map literal as one opaque token, with
    /// its source text compacted.
    fn scan_collection(&mut self) -> Result<Token, String> {
        let token = if self.peek(0) == Some(b'[') {
            Token {
                kind: TokenKind::List,
                text: compact_source(&self.scan_balanced(b'[', b']')?),
            }
        } else {
            Token {
                kind: TokenKind::Map,
                text: compact_source(&self.scan_balanced(b'{', b'}')?),
            }
        };
        Ok(token)
    }

    /// Scan a `$...` CDT path, consuming nested parens/brackets/braces and
    /// quoted segments, and stopping at a top-level comma, operator, `)`, or
    /// whitespace.
    fn scan_path(&mut self) -> Result<Token, String> {
        let start = self.pos;
        let mut paren = 0i32;
        let mut bracket = 0i32;
        let mut brace = 0i32;
        while self.pos < self.bytes.len() {
            let ch = self.peek(0).unwrap();
            if ch == b'\'' || ch == b'"' {
                self.scan_quoted(ch)?;
                continue;
            }
            match ch {
                b'(' => paren += 1,
                b')' => {
                    // An unbalanced `)` terminates the path (it belongs to an
                    // enclosing call); a balanced one is part of the path.
                    if paren == 0 && bracket == 0 && brace == 0 {
                        break;
                    }
                    paren -= 1;
                }
                b'[' => bracket += 1,
                b']' => bracket -= 1,
                b'{' => brace += 1,
                b'}' => brace -= 1,
                _ if paren == 0 && bracket == 0 && brace == 0 => {
                    if matches!(
                        ch,
                        b',' | b'+'
                            | b'-'
                            | b'*'
                            | b'/'
                            | b'%'
                            | b'&'
                            | b'|'
                            | b'^'
                            | b'>'
                            | b'<'
                            | b'='
                            | b'!'
                    ) || (ch as char).is_whitespace()
                    {
                        break;
                    }
                }
                _ => {}
            }
            self.advance(1);
        }
        Ok(Token {
            kind: TokenKind::Path,
            text: compact_source(&self.slice(start, self.pos)),
        })
    }

    /// Emit an operator token of `kind` covering the next `len` bytes.
    fn op_token(&mut self, len: usize, kind: TokenKind) -> Token {
        let text = self.slice(self.pos, self.pos + len);
        self.advance(len);
        Token { kind, text }
    }

    /// Produce the next token, or `Eof` at end of input.
    fn next_token(&mut self) -> Result<Token, String> {
        self.skip_ws();
        let ch = match self.peek(0) {
            Some(ch) => ch,
            None => {
                return Ok(Token {
                    kind: TokenKind::Eof,
                    text: String::new(),
                })
            }
        };
        // `self.text` is a `&'a str` field, so this slice outlives the
        // mutable calls below.
        let rest = &self.text[self.pos..];

        // Longest operators first so `>>` never shadows `>>>`, `>` never
        // shadows `>=`, etc.
        if rest.starts_with(">>>") {
            return Ok(self.op_token(3, TokenKind::RShift));
        }
        for (pat, kind) in [
            ("<<", TokenKind::LShift),
            (">>", TokenKind::ArShift),
            ("**", TokenKind::Pow),
            (">=", TokenKind::Ge),
            ("<=", TokenKind::Le),
            ("==", TokenKind::Eq),
            ("!=", TokenKind::Ne),
            ("=>", TokenKind::Arrow),
        ] {
            if rest.starts_with(pat) {
                return Ok(self.op_token(2, kind));
            }
        }
        let single = match ch {
            b'(' => Some(TokenKind::LParen),
            b')' => Some(TokenKind::RParen),
            b',' => Some(TokenKind::Comma),
            b'=' => Some(TokenKind::Assign),
            b'&' => Some(TokenKind::Amp),
            b'|' => Some(TokenKind::Pipe),
            b'^' => Some(TokenKind::Caret),
            b'+' => Some(TokenKind::Plus),
            b'-' => Some(TokenKind::Minus),
            b'*' => Some(TokenKind::Star),
            b'/' => Some(TokenKind::Slash),
            b'%' => Some(TokenKind::Percent),
            b'~' => Some(TokenKind::Tilde),
            b'>' => Some(TokenKind::Gt),
            b'<' => Some(TokenKind::Lt),
            _ => None,
        };
        if let Some(kind) = single {
            return Ok(self.op_token(1, kind));
        }
        // Numeric-cast suffixes are only reachable for a leading `.`, after a
        // number token has already been emitted.
        if rest.starts_with(".asInt()") {
            return Ok(self.op_token(".asInt()".len(), TokenKind::CastInt));
        }
        if rest.starts_with(".asFloat()") {
            return Ok(self.op_token(".asFloat()".len(), TokenKind::CastFloat));
        }
        match ch {
            b'\'' | b'"' => {
                let text = self.scan_quoted(ch)?;
                Ok(Token {
                    kind: TokenKind::String,
                    text,
                })
            }
            b'?' => self.scan_placeholder(),
            b'$' if rest.starts_with("${") => self.scan_variable(),
            b'$' => self.scan_path(),
            b'[' | b'{' => self.scan_collection(),
            _ if ch.is_ascii_digit()
                || (ch == b'.' && matches!(self.peek(1), Some(d) if d.is_ascii_digit())) =>
            {
                self.scan_number()
            }
            _ if (ch as char).is_ascii_alphabetic() => self.scan_identifier(),
            _ => Err(format!("unexpected character {:?}", ch as char)),
        }
    }
}

/// Recursive-descent parser that re-emits the token stream as the normalized
/// `name(arg,...)` call tree consumed by the Python lowering layer.
struct Parser<'a> {
    tokenizer: Tokenizer<'a>,
    current: Token,
}

impl<'a> Parser<'a> {
    fn new(text: &'a str) -> Result<Self, String> {
        let mut tokenizer = Tokenizer::new(text);
        let current = tokenizer.next_token()?;
        Ok(Self { tokenizer, current })
    }

    fn advance(&mut self) -> Result<(), String> {
        self.current = self.tokenizer.next_token()?;
        Ok(())
    }

    /// Consume the current token, failing unless it has the expected kind.
    fn eat(&mut self, kind: TokenKind) -> Result<Token, String> {
        if self.current.kind != kind {
            return Err(format!("expected {:?}, got {:?}", kind, self.current.kind));
        }
        let tok = self.current.clone();
        self.advance()?;
        Ok(tok)
    }

    /// Parse one complete expression and require all input to be consumed.
    fn parse(&mut self) -> Result<String, String> {
        let out = self.parse_or()?;
        self.eat(TokenKind::Eof)?;
        Ok(out)
    }

    fn parse_or(&mut self) -> Result<String, String> {
        let mut parts = vec![self.parse_and()?];
        while self.current.kind == TokenKind::Or {
            self.eat(TokenKind::Or)?;
            parts.push(self.parse_and()?);
        }
        if parts.len() == 1 {
            Ok(parts.remove(0))
        } else {
            Ok(format!("or({})", parts.join(",")))
        }
    }

    fn parse_and(&mut self) -> Result<String, String> {
        let mut parts = vec![self.parse_cmp()?];
        while self.current.kind == TokenKind::And {
            self.eat(TokenKind::And)?;
            parts.push(self.parse_cmp()?);
        }
        if parts.len() == 1 {
            Ok(parts.remove(0))
        } else {
            Ok(format!("and({})", parts.join(",")))
        }
    }

    /// Comparisons are non-associative: at most one operator per level.
    fn parse_cmp(&mut self) -> Result<String, String> {
        let left = self.parse_bitwise()?;
        let name = match self.current.kind {
            TokenKind::Gt => "gt",
            TokenKind::Ge => "ge",
            TokenKind::Lt => "lt",
            TokenKind::Le => "le",
            TokenKind::Eq => "eq",
            TokenKind::Ne => "ne",
            TokenKind::In => "in",
            _ => return Ok(left),
        };
        self.advance()?;
        let right = self.parse_bitwise()?;
        Ok(format!("{name}({left},{right})"))
    }

    /// Shared left-associative binary-operator loop: fold any operator in
    /// `ops` over operands produced by `next`.
    fn parse_binary_level<F>(
        &mut self,
        ops: &[(TokenKind, &str)],
        next: F,
    ) -> Result<String, String>
    where
        F: Fn(&mut Self) -> Result<String, String>,
    {
        let mut node = next(self)?;
        'outer: loop {
            for (kind, name) in ops {
                if self.current.kind == *kind {
                    self.advance()?;
                    let right = next(self)?;
                    node = format!("{name}({node},{right})");
                    continue 'outer;
                }
            }
            return Ok(node);
        }
    }

    fn parse_bitwise(&mut self) -> Result<String, String> {
        self.parse_binary_level(
            &[
                (TokenKind::Amp, "bitand"),
                (TokenKind::Pipe, "bitor"),
                (TokenKind::Caret, "bitxor"),
            ],
            Self::parse_shift,
        )
    }

    fn parse_shift(&mut self) -> Result<String, String> {
        self.parse_binary_level(
            &[
                (TokenKind::LShift, "lshift"),
                (TokenKind::ArShift, "arshift"),
                (TokenKind::RShift, "rshift"),
            ],
            Self::parse_add,
        )
    }

    fn parse_add(&mut self) -> Result<String, String> {
        self.parse_binary_level(
            &[(TokenKind::Plus, "add"), (TokenKind::Minus, "sub")],
            Self::parse_mul,
        )
    }

    fn parse_mul(&mut self) -> Result<String, String> {
        self.parse_binary_level(
            &[
                (TokenKind::Star, "mul"),
                (TokenKind::Slash, "div"),
                (TokenKind::Percent, "mod"),
            ],
            Self::parse_power,
        )
    }

    /// `**` is right-associative, hence the recursive right operand.
    fn parse_power(&mut self) -> Result<String, String> {
        let node = self.parse_unary()?;
        if self.current.kind == TokenKind::Pow {
            self.eat(TokenKind::Pow)?;
            let right = self.parse_power()?;
            Ok(format!("pow({node},{right})"))
        } else {
            Ok(node)
        }
    }

    fn parse_unary(&mut self) -> Result<String, String> {
        match self.current.kind {
            TokenKind::Plus => {
                self.eat(TokenKind::Plus)?;
                Ok(format!("pos({})", self.parse_unary()?))
            }
            TokenKind::Minus => {
                self.eat(TokenKind::Minus)?;
                Ok(format!("neg({})", self.parse_unary()?))
            }
            TokenKind::Tilde => {
                self.eat(TokenKind::Tilde)?;
                Ok(format!("bitnot({})", self.parse_unary()?))
            }
            _ => self.parse_primary(),
        }
    }

    /// Parse an atom: parenthesized expression, builtin form, generic call,
    /// literal, placeholder, variable, path, or collection.
    fn parse_primary(&mut self) -> Result<String, String> {
        match self.current.kind.clone() {
            TokenKind::LParen => {
                self.eat(TokenKind::LParen)?;
                let inner = self.parse_or()?;
                self.eat(TokenKind::RParen)?;
                Ok(inner)
            }
            TokenKind::Not => {
                self.eat(TokenKind::Not)?;
                self.eat(TokenKind::LParen)?;
                let inner = self.parse_or()?;
                self.eat(TokenKind::RParen)?;
                Ok(format!("not({inner})"))
            }
            TokenKind::Exclusive => {
                self.eat(TokenKind::Exclusive)?;
                self.eat(TokenKind::LParen)?;
                // `exclusive` requires at least two operands.
                let mut parts = vec![self.parse_or()?];
                self.eat(TokenKind::Comma)?;
                parts.push(self.parse_or()?);
                while self.current.kind == TokenKind::Comma {
                    self.eat(TokenKind::Comma)?;
                    parts.push(self.parse_or()?);
                }
                self.eat(TokenKind::RParen)?;
                Ok(format!("exclusive({})", parts.join(",")))
            }
            TokenKind::Let => self.parse_let_expression(),
            TokenKind::When => self.parse_when_expression(),
            TokenKind::Identifier => {
                // Generic function call: name(arg, ...).
                let name = self.eat(TokenKind::Identifier)?.text;
                self.eat(TokenKind::LParen)?;
                let mut args = Vec::new();
                if self.current.kind != TokenKind::RParen {
                    args.push(self.parse_or()?);
                    while self.current.kind == TokenKind::Comma {
                        self.eat(TokenKind::Comma)?;
                        args.push(self.parse_or()?);
                    }
                }
                self.eat(TokenKind::RParen)?;
                Ok(format!("{name}({})", args.join(",")))
            }
            TokenKind::Int => self.parse_number_with_optional_cast(TokenKind::Int),
            TokenKind::Float => self.parse_number_with_optional_cast(TokenKind::Float),
            TokenKind::Bool => Ok(format!("bool({})", self.eat(TokenKind::Bool)?.text)),
            TokenKind::String => Ok(format!("str({})", self.eat(TokenKind::String)?.text)),
            TokenKind::Placeholder => Ok(format!("ph({})", self.eat(TokenKind::Placeholder)?.text)),
            TokenKind::Variable => Ok(format!("var({})", self.eat(TokenKind::Variable)?.text)),
            TokenKind::Path => Ok(format!("path({})", self.eat(TokenKind::Path)?.text)),
            TokenKind::List => Ok(format!("list({})", self.eat(TokenKind::List)?.text)),
            TokenKind::Map => Ok(format!("map({})", self.eat(TokenKind::Map)?.text)),
            other => Err(format!("unexpected token {:?}", other)),
        }
    }

    /// `let (a = expr, ...) then (body)` → `let(def(a,expr),...,body)`.
    fn parse_let_expression(&mut self) -> Result<String, String> {
        self.eat(TokenKind::Let)?;
        self.eat(TokenKind::LParen)?;
        let mut parts = vec![self.parse_variable_definition()?];
        while self.current.kind == TokenKind::Comma {
            self.eat(TokenKind::Comma)?;
            parts.push(self.parse_variable_definition()?);
        }
        self.eat(TokenKind::RParen)?;
        self.eat(TokenKind::Then)?;
        self.eat(TokenKind::LParen)?;
        parts.push(self.parse_or()?);
        self.eat(TokenKind::RParen)?;
        Ok(format!("let({})", parts.join(",")))
    }

    fn parse_variable_definition(&mut self) -> Result<String, String> {
        let name = self.eat(TokenKind::Identifier)?.text;
        self.eat(TokenKind::Assign)?;
        let value = self.parse_or()?;
        Ok(format!("def({name},{value})"))
    }

    /// `when(c => a, ..., default => d)` → `when(c,a,...,d)`. The `default`
    /// arm is mandatory and terminates the clause list.
    fn parse_when_expression(&mut self) -> Result<String, String> {
        self.eat(TokenKind::When)?;
        self.eat(TokenKind::LParen)?;
        let mut parts = Vec::new();
        loop {
            if self.current.kind == TokenKind::Default {
                self.eat(TokenKind::Default)?;
                self.eat(TokenKind::Arrow)?;
                parts.push(self.parse_or()?);
                self.eat(TokenKind::RParen)?;
                return Ok(format!("when({})", parts.join(",")));
            }
            let condition = self.parse_or()?;
            self.eat(TokenKind::Arrow)?;
            let action = self.parse_or()?;
            parts.push(condition);
            parts.push(action);
            self.eat(TokenKind::Comma)?;
        }
    }

    /// Wrap a numeric literal in `castint(...)`/`castfloat(...)` when a
    /// `.asInt()`/`.asFloat()` suffix follows it.
    fn parse_number_with_optional_cast(&mut self, kind: TokenKind) -> Result<String, String> {
        let node = match kind {
            TokenKind::Int => format!("int({})", self.eat(TokenKind::Int)?.text),
            TokenKind::Float => format!("float({})", self.eat(TokenKind::Float)?.text),
            _ => unreachable!(),
        };
        match self.current.kind {
            TokenKind::CastInt => {
                self.eat(TokenKind::CastInt)?;
                Ok(format!("castint({node})"))
            }
            TokenKind::CastFloat => {
                self.eat(TokenKind::CastFloat)?;
                Ok(format!("castfloat({node})"))
            }
            _ => Ok(node),
        }
    }
}

/// Python entry point: parse `text` into the normalized call-tree string,
/// mapping parser errors to `ValueError`.
#[pyfunction]
fn parse_subset_rust(text: &str) -> PyResult<String> {
    Parser::new(text)
        .and_then(|mut parser| parser.parse())
        .map_err(PyValueError::new_err)
}

#[pymodule]
fn ael_rust_ext(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(parse_subset_rust, m)?)?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::Parser;

    fn parse(text: &str) -> String {
        Parser::new(text)
            .and_then(|mut parser| parser.parse())
            .expect("parser should succeed")
    }

    #[test]
    fn parses_generic_function_calls() {
        assert_eq!(
            parse("max($.a, $.b) > 10"),
            "gt(max(path($.a),path($.b)),int(10))",
        );
    }

    #[test]
    fn parses_bitwise_and_scan_functions() {
        assert_eq!(
            parse("countOneBits($.a & $.b) > 3"),
            "gt(countOneBits(bitand(path($.a),path($.b))),int(3))",
        );
    }

    #[test]
    fn parses_when_as_opaque_operand() {
        let expr = 
"when($.who == 1 => \"bob\", default => \"other\") == $.state"; + assert_eq!( + parse(expr), + "eq(when(eq(path($.who),int(1)),str(\"bob\"),str(\"other\")),path($.state))", + ); + } + + #[test] + fn parses_let_as_opaque_operand() { + let expr = "let (x = 1, y = ${x} + 1) then (${x} + ${y}) > 2"; + assert_eq!( + parse(expr), + "gt(let(def(x,int(1)),def(y,add(var(x),int(1))),add(var(x),var(y))),int(2))", + ); + } + + #[test] + fn parses_operand_casts() { + assert_eq!( + parse("28.asFloat() == 28.0"), + "eq(castfloat(int(28)),float(28.0))", + ); + assert_eq!(parse(".37.asInt() == 0"), "eq(castint(float(.37)),int(0))",); + } +} diff --git a/tests/unit/ael_rust_fastpath_test.py b/tests/unit/ael_rust_fastpath_test.py new file mode 100644 index 0000000..ff19f8a --- /dev/null +++ b/tests/unit/ael_rust_fastpath_test.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import ast +from pathlib import Path +from typing import Any + +import pytest + +from aerospike_sdk.ael import parser as parser_mod +from aerospike_sdk.ael._rust_fastpath import rust_fastpath_available, try_parse_ael_rust +from aerospike_sdk.ael.exceptions import AelParseException +from tools.ael_rust_benchmark import extract_repo_expressions + + +def _base64(expr) -> str: + return expr.base64() + + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[2] + + +def _call_name(node: ast.AST) -> str | None: + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + return node.attr + return None + + +def _literal_parse_call_cases() -> list[tuple[str, str, tuple[Any, ...]]]: + root = _repo_root() + cases: list[tuple[str, str, tuple[Any, ...]]] = [] + seen: set[tuple[str, str]] = set() + + for path in list((root / "tests").glob("**/*.py")) + list((root / "aerospike_sdk").glob("**/*.py")): + try: + tree = ast.parse(path.read_text(), filename=str(path)) + except Exception: + continue + + rel = path.relative_to(root) + for node in ast.walk(tree): + if not 
isinstance(node, ast.Call): + continue + if _call_name(node.func) != "parse_ael": + continue + if not node.args: + continue + first = node.args[0] + if not isinstance(first, ast.Constant) or not isinstance(first.value, str): + continue + + try: + values = tuple(ast.literal_eval(arg) for arg in node.args[1:]) + except Exception: + continue + + expr = first.value + key = (expr, repr(values)) + if key in seen: + continue + seen.add(key) + cases.append((f"{rel}:{node.lineno}", expr, values)) + + return cases + + +class TestAelRustFastPath: + def test_simple_comparison_matches_python_parser(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "gt(path($.age),int(18))", + ) + baseline = parser_mod.AELParser.parse("$.age > 18") + fast = parser_mod.parse_ael("$.age > 18") + assert _base64(fast) == _base64(baseline) + + def test_boolean_chain_matches_python_parser(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "and(gt(path($.age),int(18)),eq(path($.active),bool(true)))", + ) + expr = "$.age > 18 and $.active == true" + baseline = parser_mod.AELParser.parse(expr) + fast = parser_mod.parse_ael(expr) + assert _base64(fast) == _base64(baseline) + + def test_placeholder_matches_python_parser(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "eq(path($.age),ph(?0))", + ) + placeholder_values = parser_mod.PlaceholderValues(21) + baseline = parser_mod.AELParser.parse("$.age == ?0", placeholder_values) + fast = parser_mod.parse_ael("$.age == ?0", 21) + assert _base64(fast) == _base64(baseline) + + def test_in_list_matches_python_parser(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "in(path($.country),list(['US','EG']))", + ) + expr = "$.country in ['US', 'EG']" + baseline = 
parser_mod.AELParser.parse(expr) + fast = parser_mod.parse_ael(expr) + assert _base64(fast) == _base64(baseline) + + def test_in_with_non_list_right_operand_matches_python_error(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "in(path($.name),str(\"Bob\"))", + ) + expr = '$.name in "Bob"' + with pytest.raises(AelParseException, match="IN operation requires a List"): + parser_mod.parse_ael(expr) + + def test_negative_numeric_literal_matches_python_parser(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "eq(path($.x),neg(int(5)))", + ) + expr = "$.x == -5" + baseline = parser_mod.AELParser.parse(expr) + fast = parser_mod.parse_ael(expr) + assert _base64(fast) == _base64(baseline) + + def test_double_negative_numeric_literal_matches_python_parser(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "eq(neg(neg(int(5))),int(5))", + ) + expr = "--5 == 5" + baseline = parser_mod.AELParser.parse(expr) + fast = parser_mod.parse_ael(expr) + assert _base64(fast) == _base64(baseline) + + def test_complex_cdt_path_is_lowered_by_rust_fast_path(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "eq(path($.listBin1.[0].get(type:INT)),int(100))", + ) + expr = "$.listBin1.[0].get(type: INT) == 100" + baseline = parser_mod.AELParser.parse(expr) + fast = try_parse_ael_rust(expr) + assert fast is not None + assert _base64(fast) == _base64(baseline) + + def test_standalone_cdt_path_is_lowered_by_rust_fast_path(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "path($.mapBin1.{a-c})", + ) + expr = "$.mapBin1.{a-c}" + baseline = parser_mod.AELParser.parse(expr) + fast = try_parse_ael_rust(expr) + assert fast is not None + assert 
_base64(fast) == _base64(baseline) + + def test_numeric_operand_cast_is_lowered_by_rust_fast_path(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "eq(castfloat(int(28)),float(28.0))", + ) + expr = "28.asFloat() == 28.0" + baseline = parser_mod.AELParser.parse(expr) + fast = try_parse_ael_rust(expr) + assert fast is not None + assert _base64(fast) == _base64(baseline) + + def test_large_negative_operand_cast_is_lowered_by_rust_fast_path(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "lt(neg(castfloat(int(9223372036854775808))),float(0.0))", + ) + expr = "-9223372036854775808.asFloat() < 0.0" + baseline = parser_mod.AELParser.parse(expr) + fast = try_parse_ael_rust(expr) + assert fast is not None + assert _base64(fast) == _base64(baseline) + + def test_math_function_with_deferred_args_is_lowered_by_rust_fast_path(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "gt(max(path($.a),path($.b)),float(0.0))", + ) + expr = "max($.a, $.b) > 0.0" + baseline = parser_mod.AELParser.parse(expr) + fast = try_parse_ael_rust(expr) + assert fast is not None + assert _base64(fast) == _base64(baseline) + + def test_bitwise_expression_is_lowered_by_rust_fast_path(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "gt(countOneBits(bitand(path($.a),path($.b))),int(3))", + ) + expr = "countOneBits($.a & $.b) > 3" + baseline = parser_mod.AELParser.parse(expr) + fast = try_parse_ael_rust(expr) + assert fast is not None + assert _base64(fast) == _base64(baseline) + + def test_when_control_structure_is_lowered_natively(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: 
'eq(when(eq(path($.who),int(1)),str("bob"),str("other")),path($.state))', + ) + expr = 'when($.who == 1 => "bob", default => "other") == $.state' + baseline = parser_mod.AELParser.parse(expr) + fast = try_parse_ael_rust(expr) + assert fast is not None + assert _base64(fast) == _base64(baseline) + + def test_let_control_structure_is_lowered_natively(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "gt(let(def(x,int(1)),def(y,add(var(x),int(1))),add(var(x),var(y))),int(2))", + ) + expr = "let (x = 1, y = ${x} + 1) then (${x} + ${y}) > 2" + baseline = parser_mod.AELParser.parse(expr) + fast = try_parse_ael_rust(expr) + assert fast is not None + assert _base64(fast) == _base64(baseline) + + def test_unsupported_rust_node_falls_back_to_python(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + lambda text: "unknown(path($.listBin.[0]))", + ) + expr = "$.listBin.[0] == 1" + baseline = parser_mod.AELParser.parse(expr) + fast = parser_mod.parse_ael(expr) + assert _base64(fast) == _base64(baseline) + + def test_try_parse_returns_none_when_backend_missing(self, monkeypatch) -> None: + monkeypatch.setattr( + "aerospike_sdk.ael._rust_fastpath._parse_subset_rust", + None, + ) + assert try_parse_ael_rust("$.age > 18") is None + + @pytest.mark.skipif(not rust_fastpath_available(), reason="Rust fast path extension not built") + def test_valid_repo_expression_corpus_matches_python_parser(self) -> None: + py_parser = parser_mod.AELParser() + unsupported: list[str] = [] + mismatches: list[str] = [] + valid_count = 0 + + for expr in extract_repo_expressions(_repo_root(), None): + try: + baseline = py_parser.parse(expr) + except Exception: + continue + + valid_count += 1 + fast = try_parse_ael_rust(expr) + if fast is None: + unsupported.append(expr) + continue + if _base64(fast) != _base64(baseline): + mismatches.append(expr) + + assert valid_count > 0 + 
assert not unsupported, f"Rust fast path unsupported valid expressions: {unsupported[:10]}" + assert not mismatches, f"Rust fast path mismatched expressions: {mismatches[:10]}" + + @pytest.mark.skipif(not rust_fastpath_available(), reason="Rust fast path extension not built") + def test_literal_parse_ael_calls_with_values_match_python_parser(self) -> None: + py_parser = parser_mod.AELParser() + unsupported: list[str] = [] + mismatches: list[str] = [] + valid_count = 0 + + for source, expr, args in _literal_parse_call_cases(): + placeholder_values = parser_mod.PlaceholderValues(*args) if args else None + try: + baseline = py_parser.parse(expr, placeholder_values) + except Exception: + continue + + valid_count += 1 + fast = try_parse_ael_rust(expr, parser_mod.PlaceholderValues(*args) if args else None) + if fast is None: + unsupported.append(source) + continue + if _base64(fast) != _base64(baseline): + mismatches.append(source) + + assert valid_count > 0 + assert not unsupported, f"Rust fast path unsupported literal parse_ael call sites: {unsupported[:10]}" + assert not mismatches, f"Rust fast path mismatched literal parse_ael call sites: {mismatches[:10]}" diff --git a/tools/ael_rust_benchmark.py b/tools/ael_rust_benchmark.py new file mode 100644 index 0000000..b5baa92 --- /dev/null +++ b/tools/ael_rust_benchmark.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import argparse +import json +import re +import statistics +import time +from pathlib import Path + +from aerospike_sdk.ael._rust_fastpath import rust_fastpath_available, try_parse_ael_rust +from aerospike_sdk.ael.exceptions import AelParseException +from aerospike_sdk.ael.parser import AELParser + + +def extract_repo_expressions(repo_root: Path, limit: int | None = None) -> list[str]: + patterns: list[str] = [] + files = list((repo_root / "tests").glob("**/*.py")) + list((repo_root / "aerospike_sdk").glob("**/*.py")) + for path in files: + try: + text = path.read_text() + except Exception: + 
continue + for match in re.finditer(r"\.where\((['\"])(.*?)\1\)", text, re.S): + patterns.append(match.group(2)) + for match in re.finditer(r"parse_ael(?:_with_index)?\((['\"])(.*?)\1", text, re.S): + patterns.append(match.group(2)) + seen: set[str] = set() + unique: list[str] = [] + for expr in patterns: + expr = expr.replace("\\n", " ").strip() + if expr not in seen: + seen.add(expr) + unique.append(expr) + if limit is not None and len(unique) >= limit: + break + return unique + + +def benchmark(func, expressions: list[str], rounds: int) -> dict[str, float]: + samples: list[float] = [] + for _ in range(rounds): + start = time.perf_counter() + for expr in expressions: + func(expr) + samples.append(time.perf_counter() - start) + mean_s = statistics.mean(samples) + return { + "mean_s": mean_s, + "expr_per_s": len(expressions) / mean_s if mean_s else 0.0, + "us_per_expr": (mean_s / len(expressions)) * 1_000_000 if expressions else 0.0, + } + + +def _try_rust(expr: str): + try: + return try_parse_ael_rust(expr) + except AelParseException: + return None + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "Benchmark Python AEL parsing against the optional Rust fast path. 
" + "Build the Rust module first with: " + "maturin develop -m rust/ael_rust_ext/Cargo.toml" + ) + ) + parser.add_argument("--limit", type=int, default=500) + parser.add_argument("--rounds", type=int, default=5) + parser.add_argument("--out", type=Path, default=None) + args = parser.parse_args() + + repo_root = Path(__file__).resolve().parents[1] + expressions = extract_repo_expressions(repo_root, args.limit) + py_parser = AELParser() + + supported: list[str] = [] + mismatches: list[str] = [] + skipped: list[str] = [] + invalid_python: list[str] = [] + for expr in expressions: + rust_expr = _try_rust(expr) + if rust_expr is None: + skipped.append(expr) + continue + try: + py_expr = py_parser.parse(expr) + except Exception: + invalid_python.append(expr) + continue + if rust_expr.base64() != py_expr.base64(): + mismatches.append(expr) + continue + supported.append(expr) + + payload = { + "rust_available": rust_fastpath_available(), + "corpus_size": len(expressions), + "supported": len(supported), + "skipped": len(skipped), + "invalid_python": len(invalid_python), + "mismatches": len(mismatches), + "skipped_examples": skipped[:10], + "invalid_python_examples": invalid_python[:10], + "mismatch_examples": mismatches[:10], + } + + if supported: + payload["python"] = benchmark(py_parser.parse, supported, args.rounds) + payload["rust_fastpath"] = benchmark( + lambda expr: _try_rust(expr), + supported, + args.rounds, + ) + payload["speedup_x"] = ( + payload["python"]["mean_s"] / payload["rust_fastpath"]["mean_s"] + if payload["rust_fastpath"]["mean_s"] + else 0.0 + ) + + text = json.dumps(payload, indent=2) + if args.out: + args.out.write_text(text) + print(text) + + +if __name__ == "__main__": + main()