From ad3760b2cbefdd750556050ec58f9261db7bfdbd Mon Sep 17 00:00:00 2001 From: stringhandler Date: Wed, 1 Jul 2026 13:38:12 +0200 Subject: [PATCH] feat: add enum support with match desugaring Add enum declarations and enum match expressions to SimplicityHL. Enums are registered as u8 type aliases; each variant has a u8 discriminant Enum match expressions desugar to binary bool-match chains that dispatch on discriminant equality, ensuring invalid discriminants fail via panic rather than silently executing any arm. --- examples/last_will.inherit.wit | 10 +- examples/last_will.simf | 22 +- src/ast.rs | 580 ++++++++++++++++++++++++++++++++- src/driver/mod.rs | 5 +- src/driver/resolve_order.rs | 4 +- src/lexer.rs | 42 ++- src/lib.rs | 249 +++++++++++++- src/named.rs | 62 ++++ src/parse.rs | 562 +++++++++++++++++++++++++------- src/value.rs | 28 ++ src/witness.rs | 60 +++- test-data/last_will.json | 2 +- tests/core_tracker.rs | 31 ++ 13 files changed, 1504 insertions(+), 153 deletions(-) diff --git a/examples/last_will.inherit.wit b/examples/last_will.inherit.wit index 16752030..88b46ab7 100644 --- a/examples/last_will.inherit.wit +++ b/examples/last_will.inherit.wit @@ -1,6 +1,10 @@ { - "INHERIT_OR_NOT": { - "value": "Left(0x755201bb62b0a8b8d18fd12fc02951ea3998ba42bfc6664daaf8a0d2298cad43cdc21358c7c82f37654275dc2fea8c858adbe97bac92828b498a5a237004db6f)", - "type": "Either>" + "ACTION": { + "value": "1", + "type": "u8" + }, + "INHERITOR_SIG": { + "value": "0x755201bb62b0a8b8d18fd12fc02951ea3998ba42bfc6664daaf8a0d2298cad43cdc21358c7c82f37654275dc2fea8c858adbe97bac92828b498a5a237004db6f", + "type": "Signature" } } diff --git a/examples/last_will.simf b/examples/last_will.simf index 9790a1cf..aab2bf73 100644 --- a/examples/last_will.simf +++ b/examples/last_will.simf @@ -40,12 +40,22 @@ fn refresh_spend(hot_sig: Signature) { recursive_covenant(); } +enum Action { + Inherit=1, + ColdSpend =2, + HotSpend =3, +} + fn main() { - match witness::INHERIT_OR_NOT { - Left(inheritor_sig: Signature) => inherit_spend(inheritor_sig), - Right(cold_or_hot: Either) => match cold_or_hot { - Left(cold_sig: Signature) => cold_spend(cold_sig), - Right(hot_sig: Signature) => refresh_spend(hot_sig), - }, + match witness::ACTION { + Action::Inherit => { + let inheritor_sig: Signature = witness::INHERITOR_SIG; + inherit_spend(inheritor_sig)} , + Action::ColdSpend => { + let cold_sig: Signature = witness::COLD_SIG; + cold_spend(cold_sig) }, + Action::HotSpend => { + let hot_sig: Signature = witness::HOT_SIG; + refresh_spend(hot_sig) }, } } diff --git a/src/ast.rs b/src/ast.rs index f7e64549..7e5faf5d 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,5 +1,5 @@ use std::collections::hash_map::Entry; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::num::NonZeroUsize; use std::sync::Arc; @@ -18,7 +18,7 @@ use crate::str::{AliasName, FunctionName, Identifier, ModuleName, SymbolName, Wi use crate::types::{ AliasedType, ResolvedType, StructuralType, TypeConstructible, TypeDeconstructible, UIntType, }; -use crate::value::{UIntValue, Value}; +use crate::value::{UIntValue, Value, ValueConstructible}; use crate::witness::{Parameters, WitnessTypes}; use crate::{impl_eq_hash, parse}; @@ -553,11 +553,41 @@ impl_jet_hinter!(CoreJetHinter, Core); #[derive(Clone, Debug, Eq, PartialEq, Default)] struct ModuleScope { aliases: HashMap, + /// Enum definitions declared in this module, keyed by enum name. + enums: HashMap, functions: HashMap, /// Nested inling `mod` blocks, each becoming a child scope. submodules: HashMap, } +/// A single enum variant after analysis: its name and u8 discriminant, without source span. +#[derive(Clone, Debug, Eq, PartialEq)] +struct ResolvedEnumVariant { + name: Identifier, + discriminant: u8, +} + +/// The resolved definition of an enum as stored in [`Scope`]: +/// a list of [`ResolvedEnumVariant`]s in declaration order. +#[derive(Clone, Debug, Eq, PartialEq)] +struct EnumBinding { + variants: Arc<[ResolvedEnumVariant]>, +} + +impl EnumBinding { + fn new(variants: Arc<[ResolvedEnumVariant]>) -> Self { + Self { variants } + } + + fn variants(&self) -> &[ResolvedEnumVariant] { + &self.variants + } + + fn contains_variant(&self, name: &Identifier) -> bool { + self.variants.iter().any(|v| &v.name == name) + } +} + /// Scope for generating the abstract syntax tree. /// /// The scope is used for: @@ -580,6 +610,8 @@ struct Scope { is_main: bool, call_tracker: CallTracker, jet_hinter: Box, + /// Monotonic counter backing [`Scope::fresh_identifier`]. + fresh_counter: usize, } impl Default for Scope { @@ -602,9 +634,20 @@ impl Scope { is_main: false, call_tracker: CallTracker::default(), jet_hinter, + fresh_counter: 0, } } + /// Generate a compiler-internal identifier that cannot collide with a user variable. + /// + /// The returned name begins with a digit, which the lexer never produces for an + /// identifier, and embeds a monotonic counter so repeated calls are unique. + fn fresh_identifier(&mut self, tag: &str) -> Identifier { + let n = self.fresh_counter; + self.fresh_counter += 1; + Identifier::from_str_unchecked(&format!("{n}_{tag}")) + } + pub fn is_outside_function(&self) -> bool { self.variables.is_empty() } @@ -935,6 +978,32 @@ impl Scope { Ok(()) } + pub fn insert_enum( + &mut self, + name: AliasName, + visibility: Visibility, + variants: Arc<[ResolvedEnumVariant]>, + ) -> Result<(), Error> { + if self.current_module().enums.contains_key(&name) + || self.current_module().aliases.contains_key(&name) + { + return Err(Error::RedefinedAlias { name }); + } + // An enum is also a `u8` type alias, so its name resolves as a type. + let resolved = self.resolve(&AliasedType::from(UIntType::U8))?; + self.current_module_mut() + .aliases + .insert(name.clone(), (resolved, visibility)); + self.current_module_mut() + .enums + .insert(name, EnumBinding::new(variants)); + Ok(()) + } + + pub fn get_enum(&self, name: &AliasName) -> Option<&EnumBinding> { + self.current_module().enums.get(name) + } + /// Insert a parameter into the global map. /// /// ## Errors @@ -1134,6 +1203,53 @@ impl AbstractSyntaxTree for Item { Ok(Self::Module(analyzed_children)) } parse::Item::Ignored => Ok(Self::Ignored), + parse::Item::EnumDeclaration(decl) => { + let n = decl.variants().len(); + if n < 2 { + return Err(Error::Grammar { + msg: format!("enum '{}' must have at least 2 variants", decl.name()), + }) + .with_span(decl); + } + let mut sorted: Vec<&parse::EnumVariant> = decl.variants().iter().collect(); + sorted.sort_by_key(|v| v.discriminant()); + for w in sorted.windows(2) { + if w[0].discriminant() == w[1].discriminant() { + return Err(Error::Grammar { + msg: format!( + "enum '{}' has duplicate discriminant {}", + decl.name(), + w[0].discriminant() + ), + }) + .with_span(decl); + } + } + let mut seen_names = HashSet::new(); + for v in decl.variants() { + if !seen_names.insert(v.name()) { + return Err(Error::Grammar { + msg: format!( + "enum '{}' has duplicate variant name '{}'", + decl.name(), + v.name() + ), + }) + .with_span(decl); + } + } + let variants: Arc<[ResolvedEnumVariant]> = sorted + .iter() + .map(|v| ResolvedEnumVariant { + name: v.name().clone(), + discriminant: v.discriminant(), + }) + .collect(); + scope + .insert_enum(decl.name().clone(), decl.visibility().clone(), variants) + .with_span(decl)?; + Ok(Self::TypeAlias) + } } } } @@ -1447,6 +1563,75 @@ impl AbstractSyntaxTree for SingleExpression { parse::SingleExpressionInner::Match(match_) => { Match::analyze(match_, ty, scope).map(SingleExpressionInner::Match)? } + parse::SingleExpressionInner::EnumMatch(enum_match) => { + let arms = enum_match.arms(); + let span = *enum_match.span(); + if arms.is_empty() { + return Err(Error::Grammar { + msg: "enum match has no arms".to_string(), + }) + .with_span(span); + } + let enum_name = match arms[0].pattern() { + MatchPattern::EnumVariant(name, _) => name.clone(), + _ => unreachable!("EnumMatch arms have EnumVariant patterns"), + }; + let binding = scope + .get_enum(&enum_name) + .ok_or_else(|| Error::UndefinedAlias { + name: enum_name.clone(), + }) + .with_span(span)?; + let mut arm_map: HashMap<&Identifier, &parse::Expression> = HashMap::new(); + for arm in arms { + let MatchPattern::EnumVariant(arm_enum_name, variant) = arm.pattern() else { + unreachable!("EnumMatch arms have EnumVariant patterns") + }; + if arm_enum_name != &enum_name { + return Err(Error::Grammar { + msg: format!( + "all match arms must use the same enum; expected '{}', found '{}'", + enum_name, arm_enum_name + ), + }) + .with_span(span); + } + if !binding.contains_variant(variant) { + return Err(Error::Grammar { + msg: format!( + "variant '{}' is not defined in enum '{}'", + variant, enum_name + ), + }) + .with_span(span); + } + if arm_map.insert(variant, arm.expression()).is_some() { + return Err(Error::Grammar { + msg: format!("duplicate arm for variant '{}'", variant), + }) + .with_span(span); + } + } + if arm_map.len() != binding.variants().len() { + return Err(Error::Grammar { + msg: format!( + "enum match on '{}' must cover all {} variants", + enum_name, + binding.variants().len() + ), + }) + .with_span(span); + } + let ordered_arms: Vec<(&parse::Expression, u8)> = binding + .variants() + .iter() + .map(|v| (arm_map[&v.name], v.discriminant)) + .collect(); + let u8_ty = ResolvedType::from(UIntType::U8); + let scrutinee = + Expression::analyze(enum_match.scrutinee(), &u8_ty, scope).map(Arc::new)?; + desugar_enum_arms_u8(&ordered_arms, scrutinee, ty, scope, span)? + } }; Ok(Self { @@ -1808,6 +1993,162 @@ impl AbstractSyntaxTree for Match { } } +/// Build an `Expression` wrapping a single inner expression of the given type. +/// +/// Both the `SingleExpression` and its enclosing `Expression` carry the same `ty`/`span`, +/// so this helper keeps the two in sync and removes the repetitive nested-struct +/// boilerplate in the enum-match desugaring below. +fn single_expr(inner: SingleExpressionInner, ty: &ResolvedType, span: Span) -> Expression { + Expression { + inner: ExpressionInner::Single(SingleExpression { + inner, + ty: ty.clone(), + span, + }), + ty: ty.clone(), + span, + } +} + +/// Desugar an N-arm enum match (u8 discriminant) into an equality-comparison chain. +fn desugar_enum_arms_u8( + arms: &[(&parse::Expression, u8)], + scrutinee: Arc, + expected_ty: &ResolvedType, + scope: &mut Scope, + span: Span, +) -> Result { + debug_assert!(arms.len() >= 2); + + let u8_ty = ResolvedType::from(UIntType::U8); + + // Resolve the equality jet from the active target's jet set rather than hardcoding + // Elements, so the desugaring is correct for every jet family (Core, Elements, ...). + let eq8_jet = scope + .jet_hinter + .parse_jet("eq_8") + .ok_or(Error::Internal { + msg: "target jet set does not provide the `eq_8` jet required to desugar an enum match" + .to_string(), + }) + .with_span(span)?; + let eq8_call = CallName::Jet(eq8_jet); + + // Bind the scrutinee to a compiler-generated variable so it is evaluated once and + // referenced N times in the comparison chain, avoiding witness-reuse errors. The name + // cannot clash with a user variable (see [`Scope::fresh_identifier`]). + let disc_ident = scope.fresh_identifier("enum_disc"); + + scope.enter_block(); + scope.insert_variable(disc_ident.clone(), u8_ty.clone()); + + let analyzed_arms = arms + .iter() + .map(|(e, disc)| { + scope.enter_block(); + let result = + Expression::analyze(e, expected_ty, scope).map(|expr| (Arc::new(expr), *disc)); + scope.exit_block(); + result + }) + .collect::, _>>(); + // Balance the `enter_block` above on every path, including the error path, before `?`. + scope.exit_block(); + let analyzed_arms = analyzed_arms?; + + let chain = build_u8_chain( + &disc_ident, + &analyzed_arms, + &eq8_call, + expected_ty, + &u8_ty, + span, + ); + + // Wrap in a block: `{ let : u8 = scrutinee; }`. + let chain_expr = Arc::new(single_expr(chain, expected_ty, span)); + let assign_stmt = Statement::Assignment(Assignment { + pattern: Pattern::Identifier(disc_ident), + expression: (*scrutinee).clone(), + span, + }); + Ok(SingleExpressionInner::Expression(Arc::new(Expression { + inner: ExpressionInner::Block(Arc::from([assign_stmt]), Some(chain_expr)), + ty: expected_ty.clone(), + span, + }))) +} + +/// Build a nested boolean-`Match` chain that dispatches on the u8 discriminant. +/// +/// Every variant, including the last, is guarded by an equality comparison. A `panic!()` +/// on the final `false` branch ensures that any undeclared discriminant value fails the +/// script rather than silently executing the last arm: +/// +/// `if eq(disc, d[0]) { arms[0] } else if eq(disc, d[1]) { arms[1] } ... else { panic!() }` +fn build_u8_chain( + disc_ident: &Identifier, + arms: &[(Arc, u8)], + eq8_call: &CallName, + expected_ty: &ResolvedType, + u8_ty: &ResolvedType, + span: Span, +) -> SingleExpressionInner { + debug_assert!(!arms.is_empty()); // guaranteed by the `>= 2` check in the caller + + let (arm_expr, discriminant) = &arms[0]; + + let bool_ty = ResolvedType::boolean(); + let disc_var = single_expr( + SingleExpressionInner::Variable(disc_ident.clone()), + u8_ty, + span, + ); + let const_expr = single_expr( + SingleExpressionInner::Constant(Value::u8(*discriminant)), + u8_ty, + span, + ); + let eq8_expr = Arc::new(single_expr( + SingleExpressionInner::Call(Call { + name: eq8_call.clone(), + args: Arc::from([disc_var, const_expr]), + span, + }), + &bool_ty, + span, + )); + + let false_branch = if arms.len() == 1 { + // Last arm: an undeclared discriminant must not silently execute any arm. + Arc::new(single_expr( + SingleExpressionInner::Call(Call { + name: CallName::Panic, + args: Arc::from([]), + span, + }), + expected_ty, + span, + )) + } else { + let rest_inner = build_u8_chain(disc_ident, &arms[1..], eq8_call, expected_ty, u8_ty, span); + Arc::new(single_expr(rest_inner, expected_ty, span)) + }; + + SingleExpressionInner::Match(Match { + scrutinee: eq8_expr, + left: MatchArm { + pattern: MatchPattern::False, + expression: false_branch, + }, + right: MatchArm { + pattern: MatchPattern::True, + expression: arm_expr.clone(), + }, + span, + }) +} + impl AsRef for Assignment { fn as_ref(&self) -> &Span { &self.span @@ -1838,6 +2179,241 @@ impl AsRef for Match { } } +#[cfg(test)] +mod enum_tests { + use super::{ElementsJetHinter, Program}; + use crate::driver::tests::setup_graph; + use crate::error::ErrorCollector; + + fn analyze(src: &str) -> Result<(), String> { + let (graph, _ids, _dir) = setup_graph(vec![("main.simf", src)]); + let mut handler = ErrorCollector::new(); + let driver_prog = graph + .linearize_and_build(&mut handler) + .unwrap() + .expect("driver build should succeed"); + Program::analyze(&driver_prog, Box::new(ElementsJetHinter::new())) + .map(|_| ()) + .map_err(|e| e.to_string()) + } + + #[test] + fn enum_declaration_registers_type_alias() { + let result = analyze( + "enum Color { Red = 1, Green = 2 } + fn main() { let _x: Color = witness::C; }", + ); + assert!( + result.is_ok(), + "enum should register as type alias: {result:?}" + ); + } + + #[test] + fn enum_match_on_function_return() { + let result = analyze( + "enum Dir { Left = 1, Right = 2 } + fn wrap(d: Dir) -> Dir { d } + fn main() { + match wrap(witness::D) { + Dir::Left => assert!(jet::eq_32(0, 0)), + Dir::Right => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "enum match on function return should analyze: {result:?}" + ); + } + + #[test] + fn enum_match_2_variants_desugars() { + let result = analyze( + "enum Coin { Heads = 1, Tails = 2 } + fn main() { + match witness::C { + Coin::Heads => assert!(jet::eq_32(0, 0)), + Coin::Tails => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "2-variant enum match should analyze: {result:?}" + ); + } + + #[test] + fn enum_match_3_variants_desugars() { + let result = analyze( + "enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + Path::C => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "3-variant enum match should analyze: {result:?}" + ); + } + + #[test] + fn enum_match_arms_sorted_by_discriminant() { + // Arms in reverse discriminant order should still compile correctly. + let result = analyze( + "enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::P { + Path::C => assert!(jet::eq_32(0, 0)), + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!( + result.is_ok(), + "arms in any order should compile: {result:?}" + ); + } + + #[test] + fn enum_too_few_variants_is_error() { + let result = analyze("enum Bad { Only = 1 } fn main() {}"); + assert!(result.is_err(), "single-variant enum should error"); + assert!( + result.unwrap_err().contains("at least 2 variants"), + "expected 'at least 2 variants' in error" + ); + } + + #[test] + fn enum_duplicate_discriminant_is_error() { + let result = analyze("enum Bad { A = 1, B = 1 } fn main() {}"); + assert!(result.is_err(), "duplicate discriminant should error"); + assert!( + result.unwrap_err().contains("duplicate discriminant"), + "expected 'duplicate discriminant' in error" + ); + } + + #[test] + fn enum_duplicate_variant_name_is_error() { + let result = analyze("enum Bad { A = 1, A = 2 } fn main() {}"); + assert!(result.is_err(), "duplicate variant name should error"); + assert!( + result.unwrap_err().contains("duplicate variant name"), + "expected 'duplicate variant name' in error" + ); + } + + #[test] + fn enum_duplicate_name_is_error() { + // Duplicate detection happens during semantic analysis (`Program::analyze`), + // not during flattening, so go through the `analyze` helper like the sibling + // duplicate-variant/discriminant tests. + let result = analyze( + "enum Color { Red = 1, Green = 2 } + enum Color { Blue = 1, Yellow = 2 } + fn main() {}", + ); + assert!( + result.is_err(), + "duplicate enum name should cause build failure" + ); + } + + #[test] + fn enum_match_missing_arm_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "missing arm should error"); + assert!( + result.unwrap_err().contains("must cover all"), + "expected 'must cover all' in error" + ); + } + + #[test] + fn enum_match_unknown_variant_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::X => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "unknown variant should error"); + assert!( + result.unwrap_err().contains("not defined in enum"), + "expected 'not defined in enum' in error" + ); + } + + #[test] + fn enum_match_duplicate_arm_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Path::A => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "duplicate arm should error"); + assert!( + result.unwrap_err().contains("duplicate arm"), + "expected 'duplicate arm' in error" + ); + } + + #[test] + fn enum_match_mixed_enum_names_is_error() { + let result = analyze( + "enum Path { A = 1, B = 2 } + enum Other { A = 1, B = 2 } + fn main() { + match witness::P { + Path::A => assert!(jet::eq_32(0, 0)), + Other::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "mixed enum names should error"); + assert!( + result.unwrap_err().contains("same enum"), + "expected 'same enum' in error" + ); + } + + #[test] + fn enum_match_undefined_enum_is_error() { + let result = analyze( + "fn main() { + match witness::P { + Unknown::A => assert!(jet::eq_32(0, 0)), + Unknown::B => assert!(jet::eq_32(0, 0)), + } + }", + ); + assert!(result.is_err(), "undefined enum should error"); + } +} + #[cfg(test)] mod scope_resolution_tests { use super::{ElementsJetHinter, Program}; diff --git a/src/driver/mod.rs b/src/driver/mod.rs index 4273f8d4..6b6572fb 100644 --- a/src/driver/mod.rs +++ b/src/driver/mod.rs @@ -388,7 +388,10 @@ impl<'a> ImportContext<'a> { } // These items carry no import information at this stage and can be safely skipped. - parse::Item::TypeAlias(_) | parse::Item::Function(_) | parse::Item::Ignored => {} + parse::Item::TypeAlias(_) + | parse::Item::Function(_) + | parse::Item::EnumDeclaration(_) + | parse::Item::Ignored => {} } } diff --git a/src/driver/resolve_order.rs b/src/driver/resolve_order.rs index ef7e41b9..94990208 100644 --- a/src/driver/resolve_order.rs +++ b/src/driver/resolve_order.rs @@ -77,7 +77,9 @@ impl DependencyGraph { &items, ))) } - parse::Item::TypeAlias(_) | parse::Item::Function(_) => Some(item.clone()), + parse::Item::TypeAlias(_) + | parse::Item::Function(_) + | parse::Item::EnumDeclaration(_) => Some(item.clone()), parse::Item::Ignored => None, } } diff --git a/src/lexer.rs b/src/lexer.rs index b98ef3ea..5c251019 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -22,6 +22,7 @@ pub enum Token<'src> { Mod, Const, Match, + Enum, Crate, // Control symbols @@ -81,6 +82,7 @@ impl<'src> fmt::Display for Token<'src> { Token::Mod => write!(f, "mod"), Token::Const => write!(f, "const"), Token::Match => write!(f, "match"), + Token::Enum => write!(f, "enum"), Token::Crate => write!(f, "{}", CRATE_STR), Token::Arrow => write!(f, "->"), @@ -157,6 +159,7 @@ pub fn lexer<'src>( "mod" => Token::Mod, "const" => Token::Const, "match" => Token::Match, + "enum" => Token::Enum, CRATE_STR => Token::Crate, "true" => Token::Bool(true), "false" => Token::Bool(false), @@ -262,7 +265,8 @@ pub fn lex(file_id: usize, input: &str) -> (Option>, Vec Result { self.satisfy_with_env(witness_values, None) } - /// Satisfy the SimplicityHL program with the given `witness_values`. - /// If `env` is `None`, the program is not pruned, otherwise it is pruned with the given environment. + /// Satisfy the program, pruning with `env` if it is `Some`. + /// + /// Like [`satisfy`](Self::satisfy), all declared witnesses must be present. Equivalent to + /// [`satisfy_with_env_fill_missing`](Self::satisfy_with_env_fill_missing) with + /// `fill_missing = false`. /// /// ## Errors /// - /// - Witness values have a different type than declared in the SimplicityHL program. - /// - There are missing witness values. + /// - A witness value has the wrong type, or is missing. pub fn satisfy_with_env( &self, witness_values: WitnessValues, env: Option<&ElementsEnv>>, + ) -> Result { + self.satisfy_with_env_fill_missing(witness_values, env, false) + } + + /// Satisfy the program, pruning with `env` if it is `Some`. + /// + /// If `fill_missing` is `true`, witnesses absent from `witness_values` are zero-filled, so a + /// multi-branch program can be satisfied with only the executed branch's witnesses; when + /// pruning, a zero-filled witness that survives on a live branch is still rejected. If + /// `false`, a missing witness is an error. + /// + /// ## Errors + /// + /// - A witness value has the wrong type. + /// - A witness is missing (`fill_missing = false`), or a zero-filled witness survives pruning + /// on the executed branch (`fill_missing = true`). + pub fn satisfy_with_env_fill_missing( + &self, + witness_values: WitnessValues, + env: Option<&ElementsEnv>>, + fill_missing: bool, ) -> Result { witness_values .is_consistent(&self.witness_types) .map_err(|e| e.to_string())?; - let mut simplicity_redeem = named::populate_witnesses(&self.simplicity, witness_values)?; - if let Some(env) = env { - simplicity_redeem = simplicity_redeem.prune(env).map_err(|e| e.to_string())?; - } + let (witness_values, zero_filled) = if fill_missing { + witness_values.fill_missing(&self.witness_types) + } else { + (witness_values, std::collections::HashSet::new()) + }; + let simplicity_redeem = named::populate_witnesses(&self.simplicity, witness_values)?; + let simplicity_redeem = if let Some(env) = env { + let pruned = simplicity_redeem.prune(env).map_err(|e| e.to_string())?; + if !zero_filled.is_empty() { + named::check_surviving_witnesses(&self.simplicity, &pruned, &zero_filled)?; + } + pruned + } else { + simplicity_redeem + }; Ok(SatisfiedProgram { simplicity: simplicity_redeem, debug_symbols: self.debug_symbols.clone(), @@ -774,10 +810,17 @@ pub(crate) mod tests { self, witness_values: WitnessValues, ) -> TestCase { - let program = match self.program.satisfy(witness_values) { - Ok(x) => x, - Err(error) => panic!("{error}"), - }; + // Test programs may only supply the witnesses for the branch that executes + // (e.g. `last_will`), so the harness zero-fills the rest. The strict, + // fail-on-missing path is exercised directly in `satisfy`-based unit tests. + let program = + match self + .program + .satisfy_with_env_fill_missing(witness_values, None, true) + { + Ok(x) => x, + Err(error) => panic!("{error}"), + }; TestCase { program, lock_time: self.lock_time, @@ -1140,6 +1183,182 @@ pub(crate) mod tests { .assert_run_success(); } + #[test] + fn enum_match_3_variants() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" + enum Path { A = 0, B = 2, C = 5 } + fn main() { + match witness::PATH { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + Path::C => assert!(jet::eq_32(0, 0)), + } + } + "#; + // Select variant C via its u8 discriminant. + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("PATH"), Value::u8(5)); + TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .assert_run_success(); + } + + #[test] + fn enum_match_invalid_discriminant_fails() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" + enum Path { A = 1, B = 2, C = 3 } + fn main() { + match witness::PATH { + Path::A => assert!(jet::eq_32(0, 0)), + Path::B => assert!(jet::eq_32(0, 0)), + Path::C => assert!(jet::eq_32(0, 0)), + } + } + "#; + // Discriminant 0 is not declared in the enum; the script must fail. + for bad in [0u8, 4, 99, 255] { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("PATH"), Value::u8(bad)); + let result = TestCase::program_text(Cow::Borrowed(src)) + .with_witness_values(WitnessValues::from(map)) + .run(); + assert!( + result.is_err(), + "discriminant {bad} is not declared; execution should fail but succeeded" + ); + } + } + + #[test] + fn missing_witness_on_live_branch_errors() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" +enum Branch { A = 1, B = 2 } +fn main() { + match witness::SELECTOR { + Branch::A => assert!(jet::is_zero_32(witness::A)), + Branch::B => assert!(jet::is_zero_32(witness::B)), + } +} +"#; + let env = crate::dummy_env::dummy(); + + // SELECTOR = 1 (Branch::A) → branch A taken; B is missing but pruned → satisfy OK + { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(1)); + map.insert(WitnessName::from_str_unchecked("A"), Value::u32(0)); + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + compiled + .satisfy_with_env_fill_missing(WitnessValues::from(map), Some(&env), true) + .expect("B is on a pruned branch; satisfy should succeed"); + } + + // SELECTOR = 2 (Branch::B) → branch B taken; A is missing but pruned → satisfy OK + { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(2)); + map.insert(WitnessName::from_str_unchecked("B"), Value::u32(0)); + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + compiled + .satisfy_with_env_fill_missing(WitnessValues::from(map), Some(&env), true) + .expect("A is on a pruned branch; satisfy should succeed"); + } + + // SELECTOR = 2 (Branch::B) → branch B taken; B is missing and live → satisfy errors + { + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(2)); + // B is intentionally not provided + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + let err = compiled + .satisfy_with_env_fill_missing(WitnessValues::from(map), Some(&env), true) + .expect_err("B is on the executed branch and missing; satisfy should fail"); + assert!( + err.contains('B'), + "error message should mention witness B, got: {err}" + ); + } + } + + #[test] + fn fill_missing_false_rejects_missing_witness() { + use crate::str::WitnessName; + use crate::value::ValueConstructible; + use std::collections::HashMap; + + let src = r#" +enum Branch { A = 1, B = 2 } +fn main() { + match witness::SELECTOR { + Branch::A => assert!(jet::is_zero_32(witness::A)), + Branch::B => assert!(jet::is_zero_32(witness::B)), + } +} +"#; + let compiled = CompiledProgram::new( + src, + Arguments::default(), + false, + Box::new(ElementsJetHinter::new()), + ) + .unwrap(); + + // Only SELECTOR and A are provided; B is omitted. With `fill_missing = false` and no + // pruning env, an omitted witness must be an error rather than being silently + // zero-filled. + let mut map: HashMap = HashMap::new(); + map.insert(WitnessName::from_str_unchecked("SELECTOR"), Value::u8(1)); + map.insert(WitnessName::from_str_unchecked("A"), Value::u32(0)); + + let err = compiled + .satisfy_with_env_fill_missing(WitnessValues::from(map.clone()), None, false) + .expect_err("missing witness B must be rejected when fill_missing = false"); + assert!( + err.contains('B'), + "error should mention the missing witness B, got: {err}" + ); + + // The public `satisfy` is strict (fill_missing = false), so it rejects the same input. + compiled + .satisfy(WitnessValues::from(map.clone())) + .expect_err("public satisfy must reject a missing witness"); + + // Same inputs with `fill_missing = true` succeed (B is zero-filled). + compiled + .satisfy_with_env_fill_missing(WitnessValues::from(map), None, true) + .expect("fill_missing = true should zero-fill the omitted witness B"); + } + #[test] #[cfg(feature = "serde")] fn hodl_vault() { diff --git a/src/named.rs b/src/named.rs index 9de1b6e7..c4ad36d3 100644 --- a/src/named.rs +++ b/src/named.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::sync::Arc; use simplicity::dag::{InternalSharing, PostOrderIterItem}; @@ -243,6 +244,67 @@ pub fn populate_witnesses( node.convert::(&mut populator) } +/// Walk the `commit` tree and the `pruned` redeem tree in parallel, checking that +/// no zero-filled witness (tracked in `zero_filled`) appears on a non-pruned branch. +/// +/// Pruned branches are indicated by `Fail` nodes in the pruned tree. When a `Case` +/// node is pruned to `AssertL` or `AssertR`, only the surviving child is recursed into. +pub fn check_surviving_witnesses( + commit: &CommitNode, + pruned: &Arc, + zero_filled: &HashSet, +) -> Result<(), String> { + match (commit.inner(), pruned.inner()) { + // Pruned branch or unreachable fail node — no witnesses to check + (_, Inner::Fail(_)) | (Inner::Fail(_), _) => Ok(()), + // Witness node on a live branch — error if it was zero-filled + (Inner::Witness(name), Inner::Witness(_)) => { + if zero_filled.contains(name) { + Err(format!( + "Witness `{name}` is used on the executed branch but has no assigned value" + )) + } else { + Ok(()) + } + } + // Leaf nodes with no witness children + (Inner::Iden, _) | (Inner::Unit, _) | (Inner::Jet(_), _) | (Inner::Word(_), _) => Ok(()), + // Single-child nodes — recurse into the child + (Inner::InjL(cc), Inner::InjL(cp)) + | (Inner::InjR(cc), Inner::InjR(cp)) + | (Inner::Take(cc), Inner::Take(cp)) + | (Inner::Drop(cc), Inner::Drop(cp)) => check_surviving_witnesses(cc, cp, zero_filled), + // Assert nodes — one live child, one CMR; recurse into the live child + (Inner::AssertL(cc, _), Inner::AssertL(cp, _)) + | (Inner::AssertR(_, cc), Inner::AssertR(_, cp)) => { + check_surviving_witnesses(cc, cp, zero_filled) + } + // Two-child nodes — recurse into both + (Inner::Comp(cl, cr), Inner::Comp(pl, pr)) | (Inner::Pair(cl, cr), Inner::Pair(pl, pr)) => { + check_surviving_witnesses(cl, pl, zero_filled)?; + check_surviving_witnesses(cr, pr, zero_filled) + } + // Case: both branches live + (Inner::Case(cl, cr), Inner::Case(pl, pr)) => { + check_surviving_witnesses(cl, pl, zero_filled)?; + check_surviving_witnesses(cr, pr, zero_filled) + } + // Case pruned to AssertL: only left branch survived + (Inner::Case(cl, _), Inner::AssertL(pl, _)) => { + check_surviving_witnesses(cl, pl, zero_filled) + } + // Case pruned to AssertR: only right branch survived + (Inner::Case(_, cr), Inner::AssertR(_, pr)) => { + check_surviving_witnesses(cr, pr, zero_filled) + } + // Disconnect — not used in SimplicityHL; handle defensively + (Inner::Disconnect(cc, _), Inner::Disconnect(cp, _)) => { + check_surviving_witnesses(cc, cp, zero_filled) + } + _ => unreachable!("unexpected structural mismatch between commit and pruned trees"), + } +} + // This awkward construction is required by rust-simplicity to implement WitnessConstructible // for Node>. See // https://docs.rs/simplicity-lang/latest/simplicity/node/trait.WitnessConstructible.html#foreign-impls diff --git a/src/parse.rs b/src/parse.rs index 57fe3a60..cff1b1e6 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -68,6 +68,8 @@ pub enum Item { /// An import declaration (e.g., `use math::add`) that brings another /// [`Item`] into the current scope. Use(UseDecl), + /// An enum declaration. + EnumDeclaration(EnumDeclaration), /// A module containing a collection of nested [`Item`]. Module(Module), /// A placeholder used exclusively for error recovery during parsing. @@ -82,6 +84,7 @@ impl_require_feature!(Item { TypeAlias(alias), Function(function), Use(use_decl), + EnumDeclaration(_), Module(module), Ignored, }); @@ -460,6 +463,91 @@ impl_eq_hash!(TypeAlias; name, ty); impl_require_feature!(TypeAlias { recurse: ty; }); +/// A single variant in an enum declaration. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct EnumVariant { + name: Identifier, + discriminant: u8, + span: Span, +} + +impl EnumVariant { + pub fn name(&self) -> &Identifier { + &self.name + } + + pub fn discriminant(&self) -> u8 { + self.discriminant + } +} + +impl AsRef for EnumVariant { + fn as_ref(&self) -> &Span { + &self.span + } +} + +/// An enum declaration. +#[derive(Clone, Debug)] +pub struct EnumDeclaration { + file_id: usize, + visibility: Visibility, + name: AliasName, + variants: Arc<[EnumVariant]>, + span: Span, +} + +impl EnumDeclaration { + pub fn file_id(&self) -> usize { + self.file_id + } + + pub fn set_file_id(&mut self, file_id: usize) { + self.file_id = file_id; + } + + pub fn visibility(&self) -> &Visibility { + &self.visibility + } + + pub fn name(&self) -> &AliasName { + &self.name + } + + pub fn variants(&self) -> &[EnumVariant] { + &self.variants + } +} + +impl_eq_hash!(EnumDeclaration; name, variants); + +impl AsRef for EnumDeclaration { + fn as_ref(&self) -> &Span { + &self.span + } +} + +#[cfg(feature = "arbitrary")] +impl<'a> arbitrary::Arbitrary<'a> for EnumDeclaration { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let file_id = u.int_in_range(0..=3)?; + let visibility = Visibility::arbitrary(u)?; + let name = AliasName::arbitrary(u)?; + let len = u.int_in_range(2..=8)?; + let variants = (0..len) + .map(|_| EnumVariant::arbitrary(u)) + .collect::>>()?; + Ok(Self { + file_id, + visibility, + name, + variants, + span: Span::DUMMY, + }) + } +} + /// An expression is something that returns a value. #[derive(Clone, Debug)] pub struct Expression { @@ -572,6 +660,8 @@ pub enum SingleExpressionInner { Expression(Arc), /// Match expression over a sum type Match(Match), + /// Match expression over a named enum type + EnumMatch(EnumMatch), /// Tuple wrapper expression Tuple(Arc<[Expression]>), /// Array wrapper expression @@ -596,6 +686,7 @@ impl_require_feature!(SingleExpressionInner { Call(call), Expression(expr), Match(match_), + EnumMatch(enum_match), Tuple(exprs), Array(exprs), List(exprs), @@ -648,6 +739,32 @@ impl_eq_hash!(Match; scrutinee, left, right); impl_require_feature!(Match {recurse: scrutinee, left, right; }); +/// Match expression over a named enum type (N arms, N ≥ 2). +#[derive(Clone, Debug)] +pub struct EnumMatch { + scrutinee: Arc, + arms: Arc<[MatchArm]>, + span: Span, +} + +impl_require_feature!(EnumMatch {recurse: scrutinee, arms; }); + +impl EnumMatch { + pub fn scrutinee(&self) -> &Expression { + &self.scrutinee + } + + pub fn arms(&self) -> &[MatchArm] { + &self.arms + } + + pub fn span(&self) -> &Span { + &self.span + } +} + +impl_eq_hash!(EnumMatch; scrutinee, arms); + /// Arm of a match expression. #[derive(Clone, Debug, Eq, PartialEq, Hash)] pub struct MatchArm { @@ -685,6 +802,8 @@ pub enum MatchPattern { False, /// Match true value (no binding). True, + /// Match a named enum variant (no payload binding). + EnumVariant(AliasName, Identifier), } impl_require_feature!(MatchPattern { @@ -695,6 +814,7 @@ impl_require_feature!(MatchPattern { Some(pattern, ty), False, True, + EnumVariant(_, _), }); impl MatchPattern { @@ -704,7 +824,10 @@ impl MatchPattern { MatchPattern::Left(i, _) | MatchPattern::Right(i, _) | MatchPattern::Some(i, _) => { Some(i) } - MatchPattern::None | MatchPattern::False | MatchPattern::True => None, + MatchPattern::None + | MatchPattern::False + | MatchPattern::True + | MatchPattern::EnumVariant(..) => None, } } @@ -714,7 +837,10 @@ impl MatchPattern { MatchPattern::Left(i, ty) | MatchPattern::Right(i, ty) | MatchPattern::Some(i, ty) => { Some((i, ty)) } - MatchPattern::None | MatchPattern::False | MatchPattern::True => None, + MatchPattern::None + | MatchPattern::False + | MatchPattern::True + | MatchPattern::EnumVariant(..) => None, } } } @@ -783,6 +909,7 @@ impl fmt::Display for Item { Self::TypeAlias(alias) => write!(f, "{alias}"), Self::Function(function) => write!(f, "{function}"), Self::Use(use_declaration) => write!(f, "{use_declaration}"), + Self::EnumDeclaration(decl) => write!(f, "{decl}"), Self::Module(module) => write!(f, "{module}"), Self::Ignored => Ok(()), } @@ -907,6 +1034,7 @@ pub enum ExprTree<'a> { Single(&'a SingleExpression), Call(&'a Call), Match(&'a Match), + EnumMatch(&'a EnumMatch), } impl TreeLike for ExprTree<'_> { @@ -947,6 +1075,7 @@ impl TreeLike for ExprTree<'_> { | S::Expression(l) => Tree::Unary(Self::Expression(l)), S::Call(call) => Tree::Unary(Self::Call(call)), S::Match(match_) => Tree::Unary(Self::Match(match_)), + S::EnumMatch(enum_match) => Tree::Unary(Self::EnumMatch(enum_match)), S::Tuple(elements) | S::Array(elements) | S::List(elements) => { Tree::Nary(elements.iter().map(Self::Expression).collect()) } @@ -957,6 +1086,16 @@ impl TreeLike for ExprTree<'_> { Self::Expression(match_.left().expression()), Self::Expression(match_.right().expression()), ])), + Self::EnumMatch(enum_match) => Tree::Nary( + std::iter::once(Self::Expression(enum_match.scrutinee())) + .chain( + enum_match + .arms() + .iter() + .map(|arm| Self::Expression(arm.expression())), + ) + .collect(), + ), } } } @@ -1022,7 +1161,7 @@ impl fmt::Display for ExprTree<'_> { write!(f, ")")?; } }, - S::Call(..) | S::Match(..) => {} + S::Call(..) | S::Match(..) | S::EnumMatch(..) => {} S::Tuple(tuple) => { if data.n_children_yielded == 0 { write!(f, "(")?; @@ -1073,6 +1212,18 @@ impl fmt::Display for ExprTree<'_> { write!(f, ",\n}}")?; } }, + Self::EnumMatch(enum_match) => { + let n = data.n_children_yielded; + if n == 0 { + write!(f, "match ")?; + } else if n == 1 { + write!(f, "{{\n{} => ", enum_match.arms()[0].pattern())?; + } else if n <= enum_match.arms().len() { + write!(f, ",\n{} => ", enum_match.arms()[n - 1].pattern())?; + } else { + write!(f, ",\n}}")?; + } + } } } @@ -1145,7 +1296,24 @@ impl fmt::Display for MatchPattern { MatchPattern::Some(i, ty) => write!(f, "Some({i}: {ty})"), MatchPattern::False => write!(f, "false"), MatchPattern::True => write!(f, "true"), + MatchPattern::EnumVariant(enum_name, variant) => write!(f, "{enum_name}::{variant}"), + } + } +} + +impl fmt::Display for EnumDeclaration { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}enum {} {{", self.visibility(), self.name())?; + for variant in self.variants() { + write!(f, " {} = {},", variant.name(), variant.discriminant())?; } + write!(f, " }}") + } +} + +impl fmt::Display for EnumMatch { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", ExprTree::EnumMatch(self)) } } @@ -1500,9 +1668,16 @@ impl ChumskyParse for Item { let use_parser = UseDecl::parser().map(Item::Use); // Lazy item here + let enum_parser = EnumDeclaration::parser().map(Item::EnumDeclaration); let mod_parser = Module::parser_with_items(item).map(Item::Module); - choice((func_parser, use_parser, type_parser, mod_parser)) + choice(( + func_parser, + use_parser, + type_parser, + enum_parser, + mod_parser, + )) }) } } @@ -1916,6 +2091,62 @@ impl ChumskyParse for TypeAlias { } } +impl ChumskyParse for EnumDeclaration { + fn parser<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone + where + I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, + { + let visibility = just(Token::Pub) + .to(Visibility::Public) + .or_not() + .map(Option::unwrap_or_default); + + let discriminant = just(Token::Eq) + .ignore_then(select! { Token::DecLiteral(d) => d }) + .try_map(|d, span| { + d.as_inner().parse::().map_err(|_| { + RichError::new( + Error::Grammar { + msg: format!( + "enum discriminant '{}' is out of range (must be 0-255)", + d.as_inner() + ), + }, + span, + ) + }) + }); + + let variant = + Identifier::parser() + .then(discriminant) + .map_with(|(name, discriminant), e| EnumVariant { + name, + discriminant, + span: e.span(), + }); + + let variants = variant + .separated_by(just(Token::Comma)) + .allow_trailing() + .collect::>() + .delimited_by(just(Token::LBrace), just(Token::RBrace)) + .map(Arc::from); + + visibility + .then_ignore(just(Token::Enum)) + .then(AliasName::parser()) + .then(variants) + .map_with(|((visibility, name), variants), e| Self { + file_id: MAIN_MODULE, + visibility, + name, + variants, + span: e.span(), + }) + } +} + impl ChumskyParse for Expression { fn parser<'tokens, 'src: 'tokens, I>() -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone where @@ -2037,7 +2268,7 @@ impl SingleExpression { let call = Call::parser(expr.clone()).map(SingleExpressionInner::Call); - let match_expr = Match::parser(expr.clone()).map(SingleExpressionInner::Match); + let match_expr = match_expr_parser(expr.clone()); let variable = Identifier::parser().map(SingleExpressionInner::Variable); @@ -2092,129 +2323,120 @@ impl ChumskyParse for MatchPattern { } } -impl MatchArm { - fn parser<'tokens, 'src: 'tokens, I, E>( - expr: E, - ) -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone - where - I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, - E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, - { - MatchPattern::parser() - .then_ignore(just(Token::FatArrow)) - .then(expr.map(Arc::new)) - .then(just(Token::Comma).or_not()) - .validate(|((pattern, expression), comma), e, emitter| { - let is_block = matches!(expression.as_ref().inner, ExpressionInner::Block(_, _)); - - if !is_block && comma.is_none() { - emitter.emit( - Error::Grammar { - msg: "Missing ',' after a match arm that isn't block expression" - .to_string(), - } - .with_span(e.span()), - ); - } - - Self { - pattern, - expression, - } - }) - } -} - -impl Match { - fn parser<'tokens, 'src: 'tokens, I, E>( - expr: E, - ) -> impl Parser<'tokens, I, Self, ParseError<'src>> + Clone - where - I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, - E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, - { - let scrutinee = expr.clone().map(Arc::new); - - let arm_recovery = any() - .filter(|t| !matches!(t, Token::Comma | Token::RBrace)) - .ignored() - .or(nested_delimiters( - Token::LBrace, - Token::RBrace, - [ - (Token::LParen, Token::RParen), - (Token::LBracket, Token::RBracket), - ], - |_| (), - ) - .ignored()) - .repeated() - .map_with(|(), _| None); - - let arm_parser = MatchArm::parser(expr.clone()) - .map(Some) - .recover_with(via_parser(arm_recovery.clone())); - - let arms = delimited_with_recovery( - arm_parser.clone().then(arm_parser.clone()), - Token::LBrace, - Token::RBrace, - |_| (None, None), - ); +/// Parser for `match` expressions. +/// +/// Handles both binary match (exactly 2 arms: Left/Right, None/Some, false/true) and enum match +/// (2+ arms using `EnumName::Variant` patterns). Dispatches to [`Match`] or [`EnumMatch`] based +/// on the patterns found. +fn match_expr_parser<'tokens, 'src: 'tokens, I, E>( + expr: E, +) -> impl Parser<'tokens, I, SingleExpressionInner, ParseError<'src>> + Clone +where + I: ValueInput<'tokens, Token = Token<'src>, Span = Span>, + E: Parser<'tokens, I, Expression, ParseError<'src>> + Clone + 'tokens, +{ + let scrutinee = expr.clone().map(Arc::new); - just(Token::Match) - .ignore_then(scrutinee) - .then(arms) - .validate(|(scrutinee, arms), e, emit| match arms { - (Some(first), Some(second)) => { - let (left, right) = match (&first.pattern, &second.pattern) { - (MatchPattern::Left(..), MatchPattern::Right(..)) => (first, second), - (MatchPattern::Right(..), MatchPattern::Left(..)) => (second, first), - - (MatchPattern::None, MatchPattern::Some(..)) => (first, second), - (MatchPattern::Some(..), MatchPattern::None) => (second, first), - - (MatchPattern::False, MatchPattern::True) => (first, second), - (MatchPattern::True, MatchPattern::False) => (second, first), - - (p1, p2) => { - emit.emit( - Error::IncompatibleMatchArms { - first: p1.clone(), - second: p2.clone(), - } - .with_span(e.span()), - ); - (first, second) - } - }; + // Enum variant pattern: `EnumName::VariantName`. + // Binary keywords are excluded so choice() works without backtracking: + // when the ident is Left/Right/Some/None the select! guard fails without consuming the token. + let enum_variant_pattern = + select! { Token::Ident(name) if name != "Left" && name != "Right" && name != "Some" && name != "None" => AliasName::from_str_unchecked(name) } + .then_ignore(just(Token::DoubleColon)) + .then(select! { Token::Ident(v) => Identifier::from_str_unchecked(v) }) + .map(|(enum_name, variant)| MatchPattern::EnumVariant(enum_name, variant)); + + let combined_pattern = choice((enum_variant_pattern, MatchPattern::parser())); + + // No recover_with here: repeated() stops naturally when arm_parser fails. + // Outer delimited_with_recovery handles the block-level recovery. + let arm_parser = combined_pattern + .then_ignore(just(Token::FatArrow)) + .then(expr.clone().map(Arc::new)) + .then(just(Token::Comma).or_not()) + .validate(|((pattern, expression), comma), e, emitter| { + let is_block = matches!(expression.as_ref().inner, ExpressionInner::Block(_, _)); + if !is_block && comma.is_none() { + emitter.emit( + Error::Grammar { + msg: "Missing ',' after a match arm that isn't block expression" + .to_string(), + } + .with_span(e.span()), + ); + } + MatchArm { + pattern, + expression, + } + }); + + let arms = delimited_with_recovery( + arm_parser.repeated().collect::>(), + Token::LBrace, + Token::RBrace, + |_| vec![], + ); + + just(Token::Match) + .ignore_then(scrutinee) + .then(arms) + .validate(|(scrutinee, arms), e, emit| { + let all_enum = arms + .iter() + .all(|a| matches!(a.pattern, MatchPattern::EnumVariant(..))); + + if all_enum && arms.len() >= 2 { + return SingleExpressionInner::EnumMatch(EnumMatch { + scrutinee, + arms: Arc::from(arms), + span: e.span(), + }); + } - Self { - scrutinee, - left, - right, - span: e.span(), + // Binary match: exactly 2 non-enum arms. + let fallback_arm = MatchArm { + expression: Arc::new(Expression::empty(Span::DUMMY)), + pattern: MatchPattern::False, + }; + let (first, second) = if arms.len() == 2 { + let mut it = arms.into_iter(); + (it.next().unwrap(), it.next().unwrap()) + } else { + emit.emit( + Error::Grammar { + msg: "binary match requires exactly 2 arms".to_string(), } - } - _ => { - let match_arm_fallback = MatchArm { - expression: Arc::new(Expression::empty(Span::DUMMY)), - pattern: MatchPattern::False, - }; + .with_span(e.span()), + ); + (fallback_arm.clone(), fallback_arm) + }; - let (left, right) = ( - arms.0.unwrap_or(match_arm_fallback.clone()), - arms.1.unwrap_or(match_arm_fallback.clone()), + let (left, right) = match (&first.pattern, &second.pattern) { + (MatchPattern::Left(..), MatchPattern::Right(..)) => (first, second), + (MatchPattern::Right(..), MatchPattern::Left(..)) => (second, first), + (MatchPattern::None, MatchPattern::Some(..)) => (first, second), + (MatchPattern::Some(..), MatchPattern::None) => (second, first), + (MatchPattern::False, MatchPattern::True) => (first, second), + (MatchPattern::True, MatchPattern::False) => (second, first), + (p1, p2) => { + emit.emit( + Error::IncompatibleMatchArms { + first: p1.clone(), + second: p2.clone(), + } + .with_span(e.span()), ); - Self { - scrutinee, - left, - right, - span: e.span(), - } + (first, second) } + }; + SingleExpressionInner::Match(Match { + scrutinee, + left, + right, + span: e.span(), }) - } + }) } impl Module { @@ -2785,4 +3007,102 @@ fn main() { assert_eq!(program.to_string(), format!("{input}\n")); } } + + fn parse_item(input: &str) -> Item { + let program = parse::Program::parse_from_str(input).expect("parsing should succeed"); + program.items().first().expect("expected one item").clone() + } + + #[test] + fn test_enum_declaration_basic() { + let item = parse_item("enum Path { Inherit = 1, ColdSpend = 2, RefreshSpend = 3, }"); + let Item::EnumDeclaration(decl) = item else { + panic!("expected EnumDeclaration, got {item:?}"); + }; + assert_eq!(decl.name().as_inner(), "Path"); + assert_eq!(decl.variants().len(), 3); + assert_eq!(decl.variants()[0].name().as_inner(), "Inherit"); + assert_eq!(decl.variants()[0].discriminant(), 1); + assert_eq!(decl.variants()[1].name().as_inner(), "ColdSpend"); + assert_eq!(decl.variants()[1].discriminant(), 2); + assert_eq!(decl.variants()[2].name().as_inner(), "RefreshSpend"); + assert_eq!(decl.variants()[2].discriminant(), 3); + } + + #[test] + fn test_enum_declaration_pub() { + let item = parse_item("pub enum Color { Red = 0, Green = 1, Blue = 2, }"); + let Item::EnumDeclaration(decl) = item else { + panic!("expected EnumDeclaration"); + }; + assert_eq!(decl.visibility(), &Visibility::Public); + assert_eq!(decl.name().as_inner(), "Color"); + } + + #[test] + fn test_enum_declaration_display_round_trip() { + let input = "enum Path { Inherit = 1, ColdSpend = 2, RefreshSpend = 3, }"; + let item = parse_item(input); + let Item::EnumDeclaration(decl) = item else { + panic!("expected EnumDeclaration"); + }; + assert_eq!( + decl.to_string(), + "enum Path { Inherit = 1, ColdSpend = 2, RefreshSpend = 3, }" + ); + } + + #[test] + fn test_enum_match_parses() { + let input = "fn main() { match witness::PATH { Path::Inherit => 0, Path::ColdSpend => 1, Path::RefreshSpend => 2, } }"; + let source = SourceFile::anonymous(Arc::from(input)); + let mut errors = ErrorCollector::new(); + let program = Program::parse_from_str_with_errors( + MAIN_MODULE, + source, + &UnstableFeatures::all(), + &mut errors, + ); + assert!(program.is_some(), "should parse without errors"); + assert!( + !errors.has_errors(), + "unexpected errors: {}", + ErrorCollector::to_string(&errors) + ); + } + + #[test] + fn test_enum_match_produces_enum_match_node() { + let input = + "fn main() { match witness::PATH { Path::Inherit => 0, Path::ColdSpend => 1, } }"; + let program = parse::Program::parse_from_str(input).expect("parsing should succeed"); + // Walk the tree looking for an EnumMatch node + let has_enum_match = program.items().iter().any(|item| { + if let Item::Function(f) = item { + format!("{f}").contains("Path::Inherit") + } else { + false + } + }); + assert!(has_enum_match, "expected EnumMatch in the parse tree"); + } + + #[test] + fn test_binary_match_still_works_after_enum_parser_change() { + let input = "fn main() { let x: bool = true; match x { true => 1, false => 0, } }"; + let source = SourceFile::anonymous(Arc::from(input)); + let mut errors = ErrorCollector::new(); + let program = Program::parse_from_str_with_errors( + MAIN_MODULE, + source, + &UnstableFeatures::all(), + &mut errors, + ); + assert!(program.is_some(), "binary match should still parse"); + assert!( + !errors.has_errors(), + "unexpected errors: {}", + ErrorCollector::to_string(&errors) + ); + } } diff --git a/src/value.rs b/src/value.rs index 1ccb38bd..47df6a4a 100644 --- a/src/value.rs +++ b/src/value.rs @@ -648,6 +648,34 @@ impl Value { }; Ok(ret) } + + /// Create a zero value of the given type. + /// + /// For integers, this is 0. For sum types, this is `Left(zero)`. For options, this is `None`. + /// For tuples and arrays, each element is zero. For lists, this is the empty list. + pub fn zero(ty: &ResolvedType) -> Self { + match ty.as_inner() { + TypeInner::Boolean => Self::from(false), + TypeInner::UInt(uint_ty) => match uint_ty { + UIntType::U1 => Self::u1(0), + UIntType::U2 => Self::u2(0), + UIntType::U4 => Self::u4(0), + UIntType::U8 => Self::u8(0), + UIntType::U16 => Self::u16(0), + UIntType::U32 => Self::u32(0), + UIntType::U64 => Self::u64(0), + UIntType::U128 => Self::u128(0), + UIntType::U256 => Self::u256(U256::from_byte_array([0u8; 32])), + }, + TypeInner::Either(left, right) => Self::left(Self::zero(left), (**right).clone()), + TypeInner::Option(inner) => Self::none((**inner).clone()), + TypeInner::Tuple(elements) => Self::tuple(elements.iter().map(|e| Self::zero(e))), + TypeInner::Array(el_ty, size) => { + Self::array((0..*size).map(|_| Self::zero(el_ty)), (**el_ty).clone()) + } + TypeInner::List(el_ty, bound) => Self::list([], (**el_ty).clone(), *bound), + } + } } impl Value { diff --git a/src/witness.rs b/src/witness.rs index ae8b0581..3d0178ba 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::sync::Arc; @@ -128,6 +128,23 @@ impl WitnessValues { Ok(()) } + + /// Return a copy of these witness values with zero values inserted for any witness declared + /// in `types` that has no assigned value. Witnesses already present are unchanged. + /// + /// This is used before populating Simplicity witness nodes: all nodes must be filled, even + /// those on branches that will be pruned away and never executed. + pub fn fill_missing(&self, types: &WitnessTypes) -> (Self, HashSet) { + let mut map: HashMap = (*self.0).clone(); + let mut zero_filled = HashSet::new(); + for (name, ty) in types.iter() { + if !map.contains_key(name) { + map.insert(name.shallow_clone(), Value::zero(ty)); + zero_filled.insert(name.shallow_clone()); + } + } + (Self::from(map), zero_filled) + } } impl ParseFromStr for ResolvedType { @@ -216,7 +233,7 @@ mod tests { use crate::ast::ElementsJetHinter; use crate::parse::ParseFromStr; use crate::value::ValueConstructible; - use crate::{ast, parse, CompiledProgram, SatisfiedProgram}; + use crate::{ast, parse, CompiledProgram, ResolvedType, SatisfiedProgram}; #[test] fn witness_reuse() { @@ -282,6 +299,45 @@ fn main() { } } + #[test] + fn fill_missing_zero_fills_and_tracks_missing_witnesses() { + let ty = ResolvedType::parse_from_str("u32").unwrap(); + let witness_types = WitnessTypes::from(HashMap::from([ + (WitnessName::from_str_unchecked("A"), ty.clone()), + (WitnessName::from_str_unchecked("B"), ty.clone()), + (WitnessName::from_str_unchecked("C"), ty.clone()), + ])); + + // A is explicitly provided with value zero (same value fill_missing would insert). + // B and C are not provided at all. + let provided = WitnessValues::from(HashMap::from([( + WitnessName::from_str_unchecked("A"), + Value::u32(0), + )])); + + let (filled, zero_filled) = provided.fill_missing(&witness_types); + + // Explicitly-provided witnesses must NOT be tracked as zero-filled, + // even when their value happens to be zero. + assert!( + !zero_filled.contains(&WitnessName::from_str_unchecked("A")), + "A was explicitly provided; must not appear in zero_filled" + ); + // Missing witnesses must be tracked so check_surviving_witnesses can error. + assert!( + zero_filled.contains(&WitnessName::from_str_unchecked("B")), + "B was not provided; must appear in zero_filled" + ); + assert!( + zero_filled.contains(&WitnessName::from_str_unchecked("C")), + "C was not provided; must appear in zero_filled" + ); + // All three must now have values in the filled map. + assert!(filled.get(&WitnessName::from_str_unchecked("A")).is_some()); + assert!(filled.get(&WitnessName::from_str_unchecked("B")).is_some()); + assert!(filled.get(&WitnessName::from_str_unchecked("C")).is_some()); + } + #[test] fn witness_to_string() { let witness = WitnessValues::from(HashMap::from([ diff --git a/test-data/last_will.json b/test-data/last_will.json index 3dace1a8..e54d5914 100644 --- a/test-data/last_will.json +++ b/test-data/last_will.json @@ -1,4 +1,4 @@ { - "program": "5wnQKEGJsWVABAmKSEGCrynMGLpUF69BbvwQFoAuY+y1ngQJfqSPabfWRZ9K3F2jdRYYBitLzfMz987l3WKtAxSudDhYOBTf5tlucUbKz5QK2LfAvMA1kChBh+DHCpJAk4cziqISK6EzABFXCwYvClhPYGFQusJfripGQssOAVt34AhgGJAoSQbgJxuBig/FJwqFobGHNddy8HoTqejIHGcv8bcleUZT57KmW1Vp7LXaMUR4qMQ4YBiE3n41BAOBgcOJFAGQOJwuLAGkHHAHHpBiBQbkHacYEf5RB7X1tMEVAbpXAfNhcd45LjO88p6usCblccJ7lByDCchhcRcOA4GJxgBwcGIGlafkwigGSWMMQSTRPhfidUim1MchFg2+ZsIYB8RO84Db5ByMCcj0kCT4YnM/BVazZBMsdgY/lS0WYcYfsNRJVmhtHQmf/PVrNEOe4wYBisgAAAAIGkKhambcTmCIv9QHGkTTAYXN78l9PRKwkaP7L+QgG2hgCZadW734oMAxC4AwcwLvfahR91ofRxdIEhoraXiTCljMruIwlAG9G26fy7ABhgGL4cEObxOhhI4BnviN4uejwVYdGCizvg8pDe+f7r9U2pQklHLAwDhITlckiyAAAAAAUKvhqObQQ6CxT1VyVCCLZUfrJhqhp/qbNkpewATHlLgTDwBZJgwDELSQkUM6IZzkLPP/t8aZ/NfTm5pw5IQ0J/duRiaFMIlp35sJi5uVAYBgObvJWcC0jz5LzKXn0/Nn/OQnPezJTiq+w46I+xAB5sdEwwDWQKEkH1m5aCgWDNug5tpVnanSODCL5dAHG0YZYXL7/GhLKIDD2ZYRxW4obTlfDAMQtRvoQT9zeJ/JSg1zVnOQ+dSDqBncb+M9zPuT55FUoWyEHDxeBVkAAAAAg4AFwB8NhzSoNzoMc1Ae/7Z55CGHj0gOG/ZVBjb5kbDqY2kAJh07WGAYhcGISKEGB4nAwHMDLqyIAl/t8mPC4rRGEfM1lVH7CAxNfptKDrOKJGwAHAYBrC4iNL6lcZ9HTU3CRlRyCoGVZzuEYuSi/jp604WRZrD9L3pYeQ4FCcTgcuhcLAcvAcWAchQuRgDmABzNhOZwDmDFyvAcwoOVQHLcHleBy6B5fgc34A==", + "program": "56XQKEGJsAEECwZtoIQKD6U2AECBYANKBQfQmwAwQLABwFAoSoAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABDoFCDE2n5MIoBkljDEEk0T4X4nVIptTHIRYNvmbCGAfETvOA2+QQJigw+gUJIEnDmcVRCRXQmYAIq4WDF4UsJ7AwqF1hL9cVIyFlhwCtu/AEMAxIFCSD7zcBFB+AigWTYw5rruXg9CdT0ZA4zl/jbkryjKfPZUy2qtPZa7RiiPFRiHDAMQm0/IQEA4CBwwkgSfDE5n4KrWbIJljsDH8qWizDjD9hqJKs0No6Ez/56tZohz3GDAMVkAAAABA0hULUzbicwRF/qA40iaYDC5vfkvp6JWEjR/ZfyEA20MATLTq3e/FBgGIXAGDmBd77UKPutD6OLpAkNFbS8SYUsZldxGEoA3o23T+XYAMMAxfDghzeJ0MJHAM98RvFz0eCrDowUWd8HlIb3z/dfqm1KEko5YGAcJCcpEkWQAAAAAChV8NRzaCHQWKequSoQRbKj9ZMNUNP9TZslL2ACY8pcCYeALJMGAYhaSEihnRDOchZ5/9vjTP5r6c3NOHJCGhP7tyMTQphEtO/NhMXNyMDAMBzd5KzgWkefJeZS8+n5s/5yE572ZKcVX2HHRH2IAPNjomGAayBQkg+s3JgUCwZt0HNtKs7U6RwYRfLoA42jDLC5ff40JZRAYezLCOK3FDacr4YBiFqN9CCfubxP5KUGuas5yHzqQdQM7jfxnuZ9yfPIqlC2Qg4eLwKsgAAAAEHAAuAPhsOaVBudBjmoD3/bPPIQw8ekBw37KoMbfMjYdTG0gBMOnawwDELgxCRQgwPE4GA5gZdWRAEv9vkx4XFaIwj5msqo/YQGJr9NpQdZxRI2AA4DANYXERpfUrjPo6am4SMqOQVAyrOdwjFyUX8dPWnCyLNYfpe9LDyHAoTicDlCLhYDlGDiwDkKFyMAcpwcwATmBA5VC5UAOVgOYoDmPA5kgeZYDmaOJzOm5lrTjAj/KIPa+tpgioDdK4D5sLjvHJcZ3nlPV1gTcrjhPcoOZYJzMi5a8yAHLkFAzA0g7AOa84nNkbmjsWVABzRhOaZJCDBV5TmDF0qC9egt34IC0AXMfZazwIEv1JHtNvrIs+lbi7RuosMAxWl5vmZ++dy7rFWgYpXOhwsHApv82y3OKNlZ8oFbFvgXmAayBQbmvPzbHCoDmVSKAOZoGkLcA5nQcKA4cBxADxGBztgc8w=", "witness": null } diff --git a/tests/core_tracker.rs b/tests/core_tracker.rs index 6c9b1608..a5feb307 100644 --- a/tests/core_tracker.rs +++ b/tests/core_tracker.rs @@ -3,7 +3,9 @@ use std::rc::Rc; use simplicityhl::ast::CoreJetHinter; use simplicityhl::simplicity::jet::CoreEnv; +use simplicityhl::str::WitnessName; use simplicityhl::tracker::DefaultTracker; +use simplicityhl::value::{Value, ValueConstructible}; use simplicityhl::{Arguments, TemplateProgram, WitnessValues}; const CORE_PROGRAM: &str = r#"fn main() { @@ -38,6 +40,35 @@ fn core_program_should_not_panic_with_core_tracker() { .unwrap(); } +#[test] +fn core_enum_match_compiles_for_core_target() { + // Regression: the enum-match desugaring must resolve its equality jet from the active + // jet set. When compiling for the Core target, hardcoding an Elements jet would produce + // a jet/target mismatch at commit time. Compiling all the way to a satisfied, pruned + // program proves the desugared `eq_8` is a Core jet. + const SRC: &str = r#" + enum Coin { Heads = 1, Tails = 2 } + fn main() { + match witness::C { + Coin::Heads => assert!(jet::eq_32(0, 0)), + Coin::Tails => assert!(jet::eq_32(0, 0)), + } + } + "#; + + let mut map = std::collections::HashMap::new(); + map.insert(WitnessName::from_str_unchecked("C"), Value::u8(1)); + + let satisfied = TemplateProgram::new(SRC, Box::new(CoreJetHinter::new())) + .unwrap() + .instantiate(Arguments::default(), true) + .unwrap() + .satisfy(WitnessValues::from(map)) + .unwrap(); + + satisfied.redeem().prune(&CoreEnv::new()).unwrap(); +} + #[test] fn core_program_traces_core_jets() { let satisfied = satisfied_core_program();