diff --git a/Ix/IxVM/Convert.lean b/Ix/IxVM/Convert.lean index 4917ec0f..b9cbd972 100644 --- a/Ix/IxVM/Convert.lean +++ b/Ix/IxVM/Convert.lean @@ -68,7 +68,9 @@ def convert := ⟦ match load(idxs) { ListNode.Nil => store(ListNode.Nil), ListNode.Cons(idx, rest) => - let u = load(list_lookup_u64(univs, idx)); + -- universe indices are small; walk with a field index (cheap per-step + -- field sub) instead of `list_lookup_u64`'s per-step U64 predecessor. + let u = load(list_lookup(univs, flatten_u64(idx))); store(ListNode.Cons(store(convert_univ(u)), convert_univ_idxs(rest, univs))), } } @@ -145,7 +147,9 @@ def convert := ⟦ ) -> KExpr { match load(e) { Expr.Srt(univ_idx) => - let u = load(list_lookup_u64(univs, univ_idx)); + -- field-indexed walk (see `convert_univ_idxs`): avoids the per-step + -- U64 predecessor of `list_lookup_u64` on this hot universe lookup. + let u = load(list_lookup(univs, flatten_u64(univ_idx))); store(KExprNode.Srt(store(convert_univ(u)))), Expr.Var(idx) => @@ -199,7 +203,8 @@ def convert := ⟦ convert_expr(body, sharing, ref_idxs, recur_idxs, lit_blobs, univs))), Expr.Share(idx) => - convert_expr(list_lookup_u64(sharing, idx), sharing, ref_idxs, recur_idxs, lit_blobs, univs), + let ListNode.Cons(e, _) = load(list_drop(sharing, flatten_u64(idx))); + convert_expr(e, sharing, ref_idxs, recur_idxs, lit_blobs, univs), } } diff --git a/Ix/IxVM/Ingress.lean b/Ix/IxVM/Ingress.lean index 6630cc33..5fee5918 100644 --- a/Ix/IxVM/Ingress.lean +++ b/Ix/IxVM/Ingress.lean @@ -53,17 +53,23 @@ def ingress := ⟦ bytes } - -- Compare two 32-byte addresses for equality - fn address_eq(a: Addr, b: Addr) -> G { - let [a0, a1, a2, a3, a4, a5, a6, a7, + -- Compare two 32-byte addresses for equality. + -- + -- Cold path: limb 0 already matched, compare the remaining 31 limbs. + -- Factored into its own function so it forms a separate circuit whose height + -- is only the (rare) limb-0-match rows; Aiur charges a function's full width + -- on every one of its rows, so a nested match in `address_eq` would not save + -- anything — the split must be a function boundary. + fn address_eq_tail(a: Addr, b: Addr) -> G { + let [_, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, a30, a31] = load(a); - let [b0, b1, b2, b3, b4, b5, b6, b7, + let [_, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20, b21, b22, b23, b24, b25, b26, b27, b28, b29, b30, b31] = load(b); - match [to_field(a0) - to_field(b0), to_field(a1) - to_field(b1), + match [to_field(a1) - to_field(b1), to_field(a2) - to_field(b2), to_field(a3) - to_field(b3), to_field(a4) - to_field(b4), to_field(a5) - to_field(b5), to_field(a6) - to_field(b6), to_field(a7) - to_field(b7), @@ -80,7 +86,20 @@ def ingress := ⟦ to_field(a28) - to_field(b28), to_field(a29) - to_field(b29), to_field(a30) - to_field(b30), to_field(a31) - to_field(b31)] { [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 1, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 1, + _ => 0, + } + } + + -- Limb-0 prefilter: a single differing limb proves inequality, and almost + -- every comparison (the primitive-dispatch gauntlet in whnf) mismatches at + -- limb 0. Hot rows reject here at narrow width; only limb-0 matches pay the + -- wide `address_eq_tail` compare. Identical result to a full 32-limb compare. + fn address_eq(a: Addr, b: Addr) -> G { + let av = load(a); + let bv = load(b); + match to_field(av[0]) - to_field(bv[0]) { + 0 => address_eq_tail(a, b), _ => 0, } } @@ -792,7 +811,9 @@ def ingress := ⟦ -- Deref Expr.Share via the constant's sharing list. fn deref_share(e: Expr, sharing: List‹&Expr›) -> Expr { match e { - Expr.Share(idx) => deref_share(load(list_lookup_u64(sharing, idx)), sharing), + Expr.Share(idx) => + let ListNode.Cons(e, _) = load(list_drop(sharing, flatten_u64(idx))); + deref_share(load(e), sharing), _ => e, } } diff --git a/Ix/IxVM/IxonDeserialize.lean b/Ix/IxVM/IxonDeserialize.lean index 59aa0b7d..4a7cf34b 100644 --- a/Ix/IxVM/IxonDeserialize.lean +++ b/Ix/IxVM/IxonDeserialize.lean @@ -101,15 +101,17 @@ def ixonDeserialize := ⟦ -- Expression deserialization -- ============================================================================ - -- App telescope: read count args, wrapping func in App nodes - fn get_app_telescope(func: Expr, stream: ByteStream, count: U64) -> (Expr, ByteStream) { + -- App telescope: read count args, wrapping func in App nodes. `func` is passed + -- by pointer (loaded only at the base case) so the recursion doesn't carry the + -- wide `Expr` union by value on every row. + fn get_app_telescope(func: &Expr, stream: ByteStream, count: U64) -> (Expr, ByteStream) { let is_zero = u64_is_zero(count); match is_zero { - 1 => (func, stream), + 1 => (load(func), stream), 0 => let (arg, s) = get_expr(stream); - let app = Expr.App(store(func), store(arg)); - get_app_telescope(app, s, relaxed_u64_pred(count)), + let app = Expr.App(func, store(arg)); + get_app_telescope(store(app), s, relaxed_u64_pred(count)), } } @@ -180,7 +182,7 @@ def ixonDeserialize := ⟦ -- App: Tag4(0x7, count) + func + args... 0x7 => let (func, s2) = get_expr(s); - get_app_telescope(func, s2, size), + get_app_telescope(store(func), s2, size), -- Lam: Tag4(0x8, count) + types... + body 0x8 => get_lam_telescope(s, size), @@ -189,17 +191,23 @@ def ixonDeserialize := ⟦ 0x9 => get_all_telescope(s, size), -- Let: Tag4(0xA, non_dep) + expr(ty) + expr(val) + expr(body) - 0xA => - let (ty, s2) = get_expr(s); - let (val, s3) = get_expr(s2); - let (body, s4) = get_expr(s3); - (Expr.Let(size, store(ty), store(val), store(body)), s4), + 0xA => get_expr_let(s, size), -- Share: Tag4(0xB, idx) 0xB => (Expr.Share(size), s), } } + -- Let arm of get_expr, split out: three recursive `get_expr` calls make it the + -- widest (and a rare) arm, so inlined it taxes every get_expr row. + fn get_expr_let(s: ByteStream, size: U64) -> (Expr, ByteStream) { + let (ty, s2) = get_expr(s); + let (val, s3) = get_expr(s2); + let (body, s4) = get_expr(s3); + (Expr.Let(size, store(ty), store(val), store(body)), s4) + } + + -- ============================================================================ -- Universe deserialization -- ============================================================================ diff --git a/Ix/IxVM/Kernel/Claim.lean b/Ix/IxVM/Kernel/Claim.lean index 27c9a0ca..7122431a 100644 --- a/Ix/IxVM/Kernel/Claim.lean +++ b/Ix/IxVM/Kernel/Claim.lean @@ -804,6 +804,43 @@ def claim := ⟦ } } + -- Pack the first 4 address bytes (LE) into a u32 key for the skip-set rbtree. + -- + -- Capped at 4 bytes because `RBTreeMap` orders keys with `u32_less_than`, a + -- 32-bit comparator gadget — a wider key (a single `G` could hold 7 bytes in + -- Goldilocks) would overflow it and corrupt the tree ordering. A 32-bit key + -- space means key collisions are possible (~N²/2^33 over N leaves), but they + -- are harmless: a collision makes `addr_in_set`'s confirming `address_eq` + -- fail, yielding a false negative (a frontier const gets rechecked) never a + -- false positive (a wrong skip). See `build_skip_set`. + fn addr_key(a: Addr) -> G { + let arr = load(a); + to_field(arr[0]) + + to_field(arr[1]) * 256 + + to_field(arr[2]) * 65536 + + to_field(arr[3]) * 16777216 + } + + -- Build an O(log N) membership set from the assumption-leaf list, keyed on + -- `addr_key`. Key collisions overwrite — sound because the only consequence is + -- a missed skip (a frontier const gets rechecked, never wrongly trusted); the + -- confirming `address_eq` in `addr_in_set` rules out false positives. + fn build_skip_set(leaves: List‹Addr›, acc: RBTreeMap‹Addr›) -> RBTreeMap‹Addr› { + match load(leaves) { + ListNode.Nil => acc, + ListNode.Cons(a, rest) => + build_skip_set(rest, rbtree_map_insert(addr_key(a), a, acc)), + } + } + + -- Membership via one rbtree lookup (cheap u32 key compares) + one confirming + -- full `address_eq`, replacing the O(N) linear `addr_in_list` scan that + -- dominated address_eq cost on sharded checks. + fn addr_in_set(target: Addr, skip_set: RBTreeMap‹Addr›) -> G { + let found = rbtree_map_lookup_or_default(addr_key(target), skip_set, store([0u8; 32])); + address_eq(found, target) + } + -- ============================================================================ -- check_all variant that skips positions whose addr is in the -- supplied assumption-leaf list. @@ -813,24 +850,27 @@ def claim := ⟦ addrs: List‹Addr›, asm_leaves: List‹Addr›) { let _ = check_canonical_block_sort(top); - check_all_skipping_iter(consts, top, addrs, asm_leaves, 0) + -- Build the skip-set once (O(N log N)) instead of an O(N) linear scan per + -- checked const. + let skip_set = build_skip_set(asm_leaves, RBTreeMap.Nil); + check_all_skipping_iter(consts, top, addrs, skip_set, 0) } fn check_all_skipping_iter(consts: List‹&KConstantInfo›, top: List‹&KConstantInfo›, addrs: List‹Addr›, - asm_leaves: List‹Addr›, + skip_set: RBTreeMap‹Addr›, pos: G) { match load(consts) { ListNode.Nil => (), ListNode.Cons(&ci, rest) => let addr = list_lookup(addrs, pos); - match addr_in_list(addr, asm_leaves) { + match addr_in_set(addr, skip_set) { 1 => - check_all_skipping_iter(rest, top, addrs, asm_leaves, pos + 1), + check_all_skipping_iter(rest, top, addrs, skip_set, pos + 1), _ => let _ = check_const(ci, pos, top, addrs); - check_all_skipping_iter(rest, top, addrs, asm_leaves, pos + 1), + check_all_skipping_iter(rest, top, addrs, skip_set, pos + 1), }, } } diff --git a/Ix/IxVM/Kernel/DefEq.lean b/Ix/IxVM/Kernel/DefEq.lean index c5dde3b6..f8678931 100644 --- a/Ix/IxVM/Kernel/DefEq.lean +++ b/Ix/IxVM/Kernel/DefEq.lean @@ -146,7 +146,7 @@ def defEq := ⟦ -- 1 iff ty is `Const(I, _) args` for non-rec 1-ctor 0-field inductive. fn is_unit_like_type(ty: KExpr, top: List‹&KConstantInfo›) -> G { - match collect_spine_simple(ty) { + match collect_spine(ty) { (head, _) => match load(head) { KExprNode.Const(idx, _) => @@ -218,36 +218,66 @@ def defEq := ⟦ } } - -- Mirror: src/ix/kernel/def_eq.rs:930-948 fn nat_succ_of. - -- `Lit(n)` n>0 → (1, Lit(n-1)). `App(Const(Nat.succ), arg)` → (1, arg). - -- Else (0, _). - fn nat_succ_of(e: KExpr, addrs: List‹Addr›) -> (G, KExpr) { + + -- Decompose a WHNF'd Nat into `base + offset` where `base` is the + -- non-offset core and `offset` a KLimbs literal. Recognizes: + -- Lit n -> (matched, 0-base, n) + -- Nat.succ e -> base/offset of e, offset+1 + -- Nat.add e (Lit m) -> base/offset of e, offset+m + -- `matched=1` iff `e` is offset-shaped (succ/add/lit). The few succ layers + -- whnf exposes are peeled, but `Nat.add base (Lit m)` is read in O(1) — so a + -- `succ^k(x)` chain (which whnf leaves as `succ(Nat.add x (Lit k-1))`) + -- decomposes to `(x, k)` in O(1) instead of k unary steps. + fn nat_offset_of(e: KExpr, addrs: List‹Addr›) -> (G, KExpr, KLimbs) { match load(e) { KExprNode.Lit(lit) => match lit { - KLiteral.Nat(limbs) => - match klimbs_is_zero(limbs) { - 1 => (0, store(KExprNode.BVar(0))), - 0 => (1, mk_nat_lit(klimbs_dec(limbs))), - }, - _ => (0, store(KExprNode.BVar(0))), + KLiteral.Nat(n) => (1, mk_nat_lit(store(ListNode.Nil)), n), + _ => (0, e, store(ListNode.Nil)), }, KExprNode.App(f, a) => match load(f) { KExprNode.Const(idx, _) => match address_eq(list_lookup(addrs, idx), nat_succ_addr()) { - 1 => (1, a), - 0 => (0, store(KExprNode.BVar(0))), + 1 => + match nat_offset_of(a, addrs) { + (_, base, o) => (1, base, klimbs_succ(o)), + }, + 0 => (0, e, store(ListNode.Nil)), }, - _ => (0, store(KExprNode.BVar(0))), + KExprNode.App(g, x) => + match load(g) { + KExprNode.Const(idx, _) => + match address_eq(list_lookup(addrs, idx), nat_add_addr()) { + 1 => + match load(a) { + KExprNode.Lit(alit) => + match alit { + KLiteral.Nat(m) => + match nat_offset_of(x, addrs) { + (1, base, o) => (1, base, klimbs_add(o, m)), + (0, _, _) => (1, x, m), + }, + _ => (0, e, store(ListNode.Nil)), + }, + _ => (0, e, store(ListNode.Nil)), + }, + 0 => (0, e, store(ListNode.Nil)), + }, + _ => (0, e, store(ListNode.Nil)), + }, + _ => (0, e, store(ListNode.Nil)), }, - _ => (0, store(KExprNode.BVar(0))), + _ => (0, e, store(ListNode.Nil)), } } - -- Mirror: src/ix/kernel/def_eq.rs:953-995 is_def_eq_nat / try_def_eq_offset. - -- Returns (matched, eq). `matched=1` iff both sides are nat-shaped (both - -- zero, both succ-headed, or both literals); `eq` is the verdict. + -- Mirror: src/ix/kernel/def_eq.rs:953-995 is_def_eq_nat / try_def_eq_offset, + -- generalized to offset form. Returns (matched, eq). Conservative: only + -- decides when both sides are offset-shaped with EQUAL offsets (then the + -- verdict is `base_a ≟ base_b`, sound because `+k` is injective); differing + -- offsets or non-offset shapes fall back (matched=0) to the generic path. + -- Collapses `succ^k(x) ≟ succ^k(x)` from k unary steps to one klimbs compare. fn try_def_eq_nat(a: KExpr, b: KExpr, types: List‹KExpr›, top: List‹&KConstantInfo›, addrs: List‹Addr›) -> (G, G) { @@ -256,13 +286,19 @@ def defEq := ⟦ match za * zb { 1 => (1, 1), 0 => - match nat_succ_of(a, addrs) { - (1, ap) => - match nat_succ_of(b, addrs) { - (1, bp) => (1, k_is_def_eq(ap, bp, types, top, addrs)), - _ => (0, 0), + match nat_offset_of(a, addrs) { + (ma, ba, oa) => + match nat_offset_of(b, addrs) { + (mb, bb, ob) => + match ma * mb { + 0 => (0, 0), + _ => + match klimbs_eq(oa, ob) { + 0 => (0, 0), + 1 => (1, k_is_def_eq(ba, bb, types, top, addrs)), + }, + }, }, - _ => (0, 0), }, } } diff --git a/Ix/IxVM/Kernel/Inductive.lean b/Ix/IxVM/Kernel/Inductive.lean index b3f89177..50b39f84 100644 --- a/Ix/IxVM/Kernel/Inductive.lean +++ b/Ix/IxVM/Kernel/Inductive.lean @@ -32,7 +32,7 @@ def inductive_check := ⟦ n_params: G, n_indices: G, n_fields: G, ind_idx: G, ind_num_lvls: G) { let body = peel_n_foralls(ctor_ty, n_params + n_fields); - let pair = collect_spine_simple(body); + let pair = collect_spine(body); match pair { (head, args) => match load(head) { @@ -59,21 +59,6 @@ def inductive_check := ⟦ } } - -- Walk a left-associative App chain, return (head, args-in-application-order). - -- Inlined to avoid Whnf import here (module ordering — Inductive precedes Whnf - -- in the dependency-first build order); identical to whnf.lean's collect_spine. - fn collect_spine_simple_go(e: KExpr, acc: List‹KExpr›) -> (KExpr, List‹KExpr›) { - match load(e) { - KExprNode.App(f, a) => - collect_spine_simple_go(f, store(ListNode.Cons(a, acc))), - _ => (e, acc), - } - } - - fn collect_spine_simple(e: KExpr) -> (KExpr, List‹KExpr›) { - collect_spine_simple_go(e, store(ListNode.Nil)) - } - -- Each `lvls[i]` must be `Param(expected_start + i)` for i in 0..count. fn assert_lvls_are_params(lvls: List‹&KLevel›, count: G, idx: G) { match count { @@ -256,7 +241,7 @@ def inductive_check := ⟦ let types2 = store(ListNode.Cons(inner_dom, types)); check_positivity_aug(inner_body, block_idxs, types2, top, addrs), _ => - match collect_spine_simple(dom_w) { + match collect_spine(dom_w) { (head, args) => match load(head) { KExprNode.Const(idx, _) => @@ -576,7 +561,7 @@ def inductive_check := ⟦ data_bvars: List‹G›) -> G { match n_fields - field_idx { 0 => - match collect_spine_simple(ty) { + match collect_spine(ty) { (_, args) => all_bvars_in_args(data_bvars, args), }, _ => @@ -853,7 +838,7 @@ def inductive_check := ⟦ let n_ihs = list_length(rec_indices); let n_binders = n_fields + n_ihs; let depth_now = minor_saved + n_binders; - let ret_pair = collect_spine_simple(ret_ty); + let ret_pair = collect_spine(ret_ty); match ret_pair { (_ret_head, ret_args) => -- Drop n_own_params from ret to expose indices. @@ -1000,7 +985,7 @@ def inductive_check := ⟦ (doms, body) => let inner_types = list_concat(list_reverse(doms), types); let body_w = whnf(body, inner_types, top, addrs); - match collect_spine_simple(body_w) { + match collect_spine(body_w) { (head, _) => match load(head) { KExprNode.Const(idx, _) => find_member_local_idx(block_member_idxs, idx, 0), @@ -1454,7 +1439,7 @@ def inductive_check := ⟦ ((inner_depth - 1) - n_params), 0); let with_minors = build_apply_minors(with_motives, n_minors, (((inner_depth - 1) - n_params) - n_motives), 0); - match collect_spine_simple(inner_body) { + match collect_spine(inner_body) { (_dh, dargs) => let idx_args = list_drop(dargs, target_n_params); let with_idx = apply_spine(with_minors, idx_args); @@ -1512,7 +1497,7 @@ def inductive_check := ⟦ let after_skip = peel_n_foralls(ty, skip); match load(after_skip) { KExprNode.Forall(major_ty, _) => - match collect_spine_simple(major_ty) { + match collect_spine(major_ty) { (head, _) => match load(head) { KExprNode.Const(idx, _) => idx, @@ -2014,7 +1999,7 @@ def inductive_check := ⟦ -> List‹(G, List‹KExpr›, List‹&KLevel›)› { match peel_leading_foralls(dom) { (_doms, body) => - match collect_spine_simple(body) { + match collect_spine(body) { (head, args) => match load(head) { KExprNode.Const(idx, occ_us) => @@ -2304,7 +2289,7 @@ def inductive_check := ⟦ let inner_depth = depth + n_xs; let motive_bvar = (inner_depth - 1) - (motive_base + mem_idx); let field_bvar = (inner_depth - 1) - (minor_saved + field_idx); - match collect_spine_simple(inner_body) { + match collect_spine(inner_body) { (_h, dom_args) => let idx_args = list_drop(dom_args, target_n_params); let motive_ref = store(KExprNode.BVar(motive_bvar)); diff --git a/Ix/IxVM/Kernel/Primitive.lean b/Ix/IxVM/Kernel/Primitive.lean index ac5c7397..1e7bffec 100644 --- a/Ix/IxVM/Kernel/Primitive.lean +++ b/Ix/IxVM/Kernel/Primitive.lean @@ -1511,7 +1511,7 @@ def primitive := ⟦ -- Returns (1, width_e, n_e) if `e` is `BitVec.ofNat W N` or -- `OfNat.ofNat (BitVec W) N _inst`. Else (0, _, _). fn bitvec_of_nat_args(e: KExpr, addrs: List‹Addr›) -> (G, KExpr, KExpr) { - match collect_spine_simple(e) { + match collect_spine(e) { (head, args) => match load(head) { KExprNode.Const(idx, _) => @@ -1534,7 +1534,7 @@ def primitive := ⟦ 1 => (0, store(KExprNode.BVar(0)), store(KExprNode.BVar(0))), 0 => let ty_arg = list_lookup(args, 0); - match collect_spine_simple(ty_arg) { + match collect_spine(ty_arg) { (ty_head, ty_args) => match load(ty_head) { KExprNode.Const(ty_idx, _) => @@ -1639,7 +1639,7 @@ def primitive := ⟦ 1 => (0, store(KExprNode.BVar(0))), 0 => let prop = list_lookup(spine, 0); - match collect_spine_simple(prop) { + match collect_spine(prop) { (lt_head, lt_args) => match load(lt_head) { KExprNode.Const(lt_idx, _) => @@ -1651,7 +1651,7 @@ def primitive := ⟦ 1 => (0, store(KExprNode.BVar(0))), 0 => let ty_arg = list_lookup(lt_args, 0); - match collect_spine_simple(ty_arg) { + match collect_spine(ty_arg) { (ty_head, ty_args) => match load(ty_head) { KExprNode.Const(ty_idx, _) => @@ -1749,7 +1749,7 @@ def primitive := ⟦ 1 => (0, store(KExprNode.BVar(0))), 0 => let val_arg = list_lookup(spine, 2); - match collect_spine_simple(val_arg) { + match collect_spine(val_arg) { (head, _) => match load(head) { KExprNode.Const(idx, _) => @@ -1771,7 +1771,7 @@ def primitive := ⟦ 1 => (0, store(KExprNode.BVar(0))), 0 => let ty_arg = list_lookup(spine, 0); - match collect_spine_simple(ty_arg) { + match collect_spine(ty_arg) { (head, _) => match load(head) { KExprNode.Const(idx, _) => diff --git a/Ix/IxVM/Kernel/Whnf.lean b/Ix/IxVM/Kernel/Whnf.lean index a80c5ddb..b2671d8e 100644 --- a/Ix/IxVM/Kernel/Whnf.lean +++ b/Ix/IxVM/Kernel/Whnf.lean @@ -46,21 +46,6 @@ ingress (slot mapping in `Primitive.lean`). -/ def whnf := ⟦ - -- ============================================================================ - -- Spine collection - -- ============================================================================ - fn collect_spine_go(e: KExpr, acc: List‹KExpr›) -> (KExpr, List‹KExpr›) { - match load(e) { - KExprNode.App(f, a) => - collect_spine_go(f, store(ListNode.Cons(a, acc))), - _ => (e, acc), - } - } - - fn collect_spine(e: KExpr) -> (KExpr, List‹KExpr›) { - collect_spine_go(e, store(ListNode.Nil)) - } - fn apply_spine(head: KExpr, spine: List‹KExpr›) -> KExpr { match load(spine) { ListNode.Nil => head, @@ -107,66 +92,174 @@ def whnf := ⟦ KExprNode.Lam(ty, body) => whnf_apply_beta(spine, head, types, top, addrs), KExprNode.Const(idx, lvls) => - let head_addr = list_lookup(addrs, idx); - let ci = load(list_lookup(top, idx)); - -- Recr / Quot heads can never match a primitive address (Nat ops, - -- Str ops, BitVec, native, decidable, proj-def all live as Ctor or - -- Defn). Skip the primitive dispatch chain for those. - match ci { - KConstantInfo.Rec(num_lvls, _, num_params, num_indices, num_motives, num_minors, rules, k_flag, _, _) => - let iota = try_iota(lvls, spine, num_lvls, num_params, num_indices, num_motives, num_minors, rules, k_flag, types, top, addrs); - match iota { - (1, reduced2) => whnf(reduced2, types, top, addrs), - (0, _) => apply_spine(head, spine), - }, - KConstantInfo.Quot(_, _, kind) => - let qiota = try_quot_iota(kind, spine, types, top, addrs); - match qiota { - (1, reduced_q) => whnf(reduced_q, types, top, addrs), - (0, _) => apply_spine(head, spine), + -- Const-head reduction (delta / iota / quot / primitive dispatch) is the + -- widest arm by far. Factored into `whnf_const_head` so `whnf_with_spine` + -- stays narrow for the ~76% of reduction steps that are App/Lam/Proj — + -- Aiur charges a function's full width on every row, so the wide dispatch + -- only taxes the Const-head rows in its own circuit. + whnf_const_head(idx, lvls, head, spine, types, top, addrs), + KExprNode.Let(_, val, body) => + let next = expr_inst1(body, val, 0); + whnf_with_spine(next, spine, types, top, addrs), + KExprNode.Proj(tidx, fidx, inner) => + -- Proj reduction (whnf the scrutinee, fin-val rewrite, ctor field pull) + -- is the next-widest arm. Factored out for the same reason as Const. + whnf_proj_head(tidx, fidx, inner, spine, types, top, addrs), + _ => apply_spine(head, spine), + } + } + + -- Proj-head WHNF dispatch, split out of `whnf_with_spine` (see its Proj arm). + fn whnf_proj_head(tidx: G, fidx: G, inner: KExpr, spine: List‹KExpr›, + types: List‹KExpr›, top: List‹&KConstantInfo›, addrs: List‹Addr›) -> KExpr { + let inner_whnf = whnf(inner, types, top, addrs); + let inner_pair = collect_spine(inner_whnf); + match inner_pair { + (inner_head, inner_args) => + -- Mirror: whnf.rs:1441-1500 try_reduce_fin_val_decidable_rec. + -- Pushes Fin.val inside Decidable.rec minors; allows iota. + let fvd_pair = try_reduce_fin_val_decidable_rec(tidx, fidx, inner_head, inner_args, addrs); + match fvd_pair { + (1, rewritten) => whnf_with_spine(rewritten, spine, types, top, addrs), + (0, _) => + match load(inner_head) { + KExprNode.Const(cidx, _) => + let cci = load(list_lookup(top, cidx)); + match cci { + KConstantInfo.Ctor(_, _, _, _, nparams, _, _) => + let field = list_lookup_or_nil(inner_args, nparams + fidx); + whnf_with_spine(field, spine, types, top, addrs), + _ => + let stuck = store(KExprNode.Proj(tidx, fidx, inner_whnf)); + apply_spine(stuck, spine), + }, + _ => + let stuck = store(KExprNode.Proj(tidx, fidx, inner_whnf)); + apply_spine(stuck, spine), }, + }, + } + } + + -- If `head` is a Nat primitive (`Nat.add` / `Nat.div` / `Nat.mod`) applied to + -- exactly (non-literal base, literal second arg), return (1, the same op in + -- canonical form) so whnf leaves it STUCK instead of delta-unfolding it. This + -- stops `Nat.add x n` from materializing succ^n(x), and `Nat.div`/`Nat.mod x n` + -- (n ≥ 2) from expanding the division algorithm — both are irreducible for a + -- symbolic base, so the compact form IS the normal form. `Nat.shiftRight x k` + -- unfolds to k nested `Nat.div _ 2`, which now stay stuck. Thresholds: `add` + -- keeps any nonzero n; `div`/`mod` keep n ≥ 2 (so `x/1 = x`, `x/0 = 0` still + -- reduce). All magnitudes stay KLimbs. `(0, _)` = "not this shape". + fn try_nat_offset_stuck(head: KExpr, spine: List‹KExpr›, types: List‹KExpr›, + top: List‹&KConstantInfo›, addrs: List‹Addr›) -> (G, KExpr) { + match load(head) { + KExprNode.Const(idx, _) => + let ca = list_lookup(addrs, idx); + let is_add = address_eq(ca, nat_add_addr()); + let is_divmod = address_eq(ca, nat_div_addr()) + address_eq(ca, nat_mod_addr()); + match is_add + is_divmod { + 0 => (0, store(KExprNode.BVar(0))), _ => - let nat_pair = try_nat_dispatch(head_addr, spine, types, top, addrs); - match nat_pair { - (1, reduced) => whnf(reduced, types, top, addrs), + match list_length(spine) { + 2 => + let a0_w = whnf(list_lookup(spine, 0), types, top, addrs); + let a1_w = whnf(list_lookup(spine, 1), types, top, addrs); + match try_extract_nat(a1_w, addrs) { + (0, _) => (0, store(KExprNode.BVar(0))), + (1, n) => + -- reject n=0 (all ops) and n=1 (div/mod only). + let bad = klimbs_is_zero(n) + is_divmod * klimbs_is_zero(klimbs_dec(n)); + match bad { + 0 => + match try_extract_nat(a0_w, addrs) { + (1, _) => (0, store(KExprNode.BVar(0))), + (0, _) => + (1, store(KExprNode.App( + store(KExprNode.App(head, a0_w)), + mk_nat_lit(n)))), + }, + _ => (0, store(KExprNode.BVar(0))), + }, + }, + _ => (0, store(KExprNode.BVar(0))), + }, + }, + _ => (0, store(KExprNode.BVar(0))), + } + } + + -- Const-head WHNF dispatch, split out of `whnf_with_spine` (see its Const arm). + -- `head` is the original `Const(idx, lvls)` KExpr, passed for the stuck + -- `apply_spine(head, spine)` fallbacks. + fn whnf_const_head(idx: G, lvls: List‹&KLevel›, head: KExpr, spine: List‹KExpr›, + types: List‹KExpr›, top: List‹&KConstantInfo›, addrs: List‹Addr›) -> KExpr { + let head_addr = list_lookup(addrs, idx); + let ci = load(list_lookup(top, idx)); + -- Recr / Quot heads can never match a primitive address (Nat ops, + -- Str ops, BitVec, native, decidable, proj-def all live as Ctor or + -- Defn). Skip the primitive dispatch chain for those. + match ci { + KConstantInfo.Rec(num_lvls, _, num_params, num_indices, num_motives, num_minors, rules, k_flag, _, _) => + let iota = try_iota(lvls, spine, num_lvls, num_params, num_indices, num_motives, num_minors, rules, k_flag, types, top, addrs); + match iota { + (1, reduced2) => whnf(reduced2, types, top, addrs), + (0, _) => apply_spine(head, spine), + }, + KConstantInfo.Quot(_, _, kind) => + let qiota = try_quot_iota(kind, spine, types, top, addrs); + match qiota { + (1, reduced_q) => whnf(reduced_q, types, top, addrs), + (0, _) => apply_spine(head, spine), + }, + _ => + let nat_pair = try_nat_dispatch(head_addr, spine, types, top, addrs); + match nat_pair { + (1, reduced) => whnf(reduced, types, top, addrs), + (0, _) => + let str_pair = try_str_dispatch(head_addr, spine, addrs); + match str_pair { + (1, reduced_s) => whnf(reduced_s, types, top, addrs), (0, _) => - let str_pair = try_str_dispatch(head_addr, spine, addrs); - match str_pair { - (1, reduced_s) => whnf(reduced_s, types, top, addrs), + let bv_pair = try_bitvec_dispatch(head_addr, spine, types, top, addrs); + match bv_pair { + (1, reduced_b) => whnf(reduced_b, types, top, addrs), (0, _) => - let bv_pair = try_bitvec_dispatch(head_addr, spine, types, top, addrs); - match bv_pair { - (1, reduced_b) => whnf(reduced_b, types, top, addrs), + let nat_pair2 = try_reduce_native(head_addr, spine, types, top, addrs); + match nat_pair2 { + (1, reduced_n) => whnf(reduced_n, types, top, addrs), (0, _) => - let nat_pair2 = try_reduce_native(head_addr, spine, types, top, addrs); - match nat_pair2 { - (1, reduced_n) => whnf(reduced_n, types, top, addrs), + let dec_pair = try_reduce_decidable(head_addr, idx, lvls, spine, types, top, addrs); + match dec_pair { + (1, reduced_d) => whnf(reduced_d, types, top, addrs), (0, _) => - let dec_pair = try_reduce_decidable(head_addr, idx, lvls, spine, types, top, addrs); - match dec_pair { - (1, reduced_d) => whnf(reduced_d, types, top, addrs), + let proj_def_pair = try_reduce_projection_definition(idx, spine, top); + match proj_def_pair { + (1, reduced_pd) => whnf(reduced_pd, types, top, addrs), (0, _) => - let proj_def_pair = try_reduce_projection_definition(idx, spine, top); - match proj_def_pair { - (1, reduced_pd) => whnf(reduced_pd, types, top, addrs), - (0, _) => - -- Mirror src/ix/kernel/whnf.rs:756-774 - -- (`delta_unfold_one`): unfold any Defn - -- regardless of `ReducibilityHints`. The - -- hint is consulted by lazy-delta's - -- `delta_rank` for def-eq priority, not - -- as a gate on plain whnf delta. Without - -- unfolding here, ctor field types - -- written via opaque defs (e.g. - -- `constType (n α) (n α)`) stay stuck - -- and `check_positivity_aug` misclassifies. - match ci { - KConstantInfo.Defn(_, _, value, _, _) => + -- Mirror src/ix/kernel/whnf.rs:756-774 + -- (`delta_unfold_one`): unfold any Defn + -- regardless of `ReducibilityHints`. The + -- hint is consulted by lazy-delta's + -- `delta_rank` for def-eq priority, not + -- as a gate on plain whnf delta. Without + -- unfolding here, ctor field types + -- written via opaque defs (e.g. + -- `constType (n α) (n α)`) stay stuck + -- and `check_positivity_aug` misclassifies. + match ci { + KConstantInfo.Defn(_, _, value, _, _) => + -- Keep `Nat.add base (Lit n)` (symbolic base) + -- stuck as a compact offset instead of + -- delta-unfolding into a succ^n tower. Pairs + -- with offset-aware def-eq. + match try_nat_offset_stuck(head, spine, types, top, addrs) { + (1, stuck) => stuck, + (0, _) => let body = expr_inst_levels(value, lvls); whnf_with_spine(body, spine, types, top, addrs), - KConstantInfo.Thm(_, _, _) => apply_spine(head, spine), - _ => apply_spine(head, spine), }, + KConstantInfo.Thm(_, _, _) => apply_spine(head, spine), + _ => apply_spine(head, spine), }, }, }, @@ -174,38 +267,6 @@ def whnf := ⟦ }, }, }, - KExprNode.Let(_, val, body) => - let next = expr_inst1(body, val, 0); - whnf_with_spine(next, spine, types, top, addrs), - KExprNode.Proj(tidx, fidx, inner) => - let inner_whnf = whnf(inner, types, top, addrs); - let inner_pair = collect_spine(inner_whnf); - match inner_pair { - (inner_head, inner_args) => - -- Mirror: whnf.rs:1441-1500 try_reduce_fin_val_decidable_rec. - -- Pushes Fin.val inside Decidable.rec minors; allows iota. - let fvd_pair = try_reduce_fin_val_decidable_rec(tidx, fidx, inner_head, inner_args, addrs); - match fvd_pair { - (1, rewritten) => whnf_with_spine(rewritten, spine, types, top, addrs), - (0, _) => - match load(inner_head) { - KExprNode.Const(cidx, _) => - let cci = load(list_lookup(top, cidx)); - match cci { - KConstantInfo.Ctor(_, _, _, _, nparams, _, _) => - let field = list_lookup_or_nil(inner_args, nparams + fidx); - whnf_with_spine(field, spine, types, top, addrs), - _ => - let stuck = store(KExprNode.Proj(tidx, fidx, inner_whnf)); - apply_spine(stuck, spine), - }, - _ => - let stuck = store(KExprNode.Proj(tidx, fidx, inner_whnf)); - apply_spine(stuck, spine), - }, - }, - }, - _ => apply_spine(head, spine), } } @@ -451,10 +512,29 @@ def whnf := ⟦ 1 => let raw_base = list_lookup(spine, base_idx); let base_w = whnf(raw_base, types, top, addrs); + let post = list_drop(spine, major_idx + 1); match try_extract_nat(base_w, addrs) { - (0, _) => (0, store(KExprNode.BVar(0))), + -- Symbolic base: collapse `Nat.rec base (succ-step) (Lit n)` + -- to the offset primitive `Nat.add base (Lit n)` rather than + -- materializing succ^n(base) via n iota steps. This keeps the + -- value in the same compact `base + n` form a literal already + -- has, so def-eq converges instead of descending n unary + -- succ layers (the UTF-8 `x + 0xC0` pathology). + (0, _) => + match klimbs_is_zero(n_klimbs) { + 1 => (1, apply_spine(base_w, post)), + 0 => + match find_addr_idx_safe(nat_add_addr(), addrs, 0) { + (0, _) => (0, store(KExprNode.BVar(0))), + (1, add_idx) => + let add_const = store(KExprNode.Const(add_idx, store(ListNode.Nil))); + let off = store(KExprNode.App( + store(KExprNode.App(add_const, base_w)), + mk_nat_lit(n_klimbs))); + (1, apply_spine(off, post)), + }, + }, (1, b_klimbs) => - let post = list_drop(spine, major_idx + 1); (1, apply_spine(mk_nat_lit(klimbs_add(b_klimbs, n_klimbs)), post)), }, }, diff --git a/Ix/IxVM/KernelTypes.lean b/Ix/IxVM/KernelTypes.lean index 950a8f2c..b4a57b34 100644 --- a/Ix/IxVM/KernelTypes.lean +++ b/Ix/IxVM/KernelTypes.lean @@ -50,6 +50,24 @@ def kernelTypes := ⟦ type KExpr = &KExprNode + -- Collect an application spine: peel `App(f, a)` layers, returning the head + -- and the args in application order. The single shared definition for every + -- kernel caller (whnf, primitive, inductive_check, def_eq) — defined here in + -- `kernelTypes` so it precedes them all in the merge. + -- + -- Non-tail (no accumulator): keyed on `e` alone, so Aiur memoization dedups + -- shared sub-spines across reductions. An accumulator would thread `acc` + -- through the memo key and block that sharing (tail recursion buys nothing in + -- Aiur — stack depth is free). `list_snoc` keeps order and is itself memoized. + fn collect_spine(e: KExpr) -> (KExpr, List‹KExpr›) { + match load(e) { + KExprNode.App(f, a) => + let (head, args) = collect_spine(f); + (head, list_snoc(args, a)), + _ => (e, store(ListNode.Nil)), + } + } + -- ============================================================================ -- Values (NbE semantic domain) -- ============================================================================