Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,85 @@ impl<'tcx> Analyzer<'tcx> {
ensure_annot
}

/// Collects every `#[thrust::refinement_path(..)]` path statement in the
/// function body, returning each `(type position, formula_fn DefId)`.
fn extract_refinement_paths(
&self,
local_def_id: LocalDefId,
) -> Vec<(rty::TypePosition, DefId)> {
let mut out = Vec::new();
let Some(body) = self.tcx.hir_maybe_body_owned_by(local_def_id) else {
return out;
};
let rustc_hir::ExprKind::Block(block, _) = body.value.kind else {
return out;
};
let attr_path = analyze::annot::refinement_path_path();
let typeck = self.tcx.typeck(local_def_id);
for stmt in block.stmts {
let Some(attr) = self
.tcx
.hir_attrs(stmt.hir_id)
.iter()
.find(|attr| attr.path_matches(&attr_path))
else {
continue;
};
let ts = analyze::annot::extract_annot_tokens(attr.clone());
let position = analyze::annot::parse_type_position(&ts);

let rustc_hir::StmtKind::Semi(expr) = stmt.kind else {
self.tcx.dcx().span_err(
stmt.span,
"annotated path is expected to be a semi statement",
);
continue;
};
let rustc_hir::ExprKind::Path(qpath) = expr.kind else {
self.tcx.dcx().span_err(
expr.span,
"annotated path is expected to be a path expression",
);
continue;
};
let rustc_hir::def::Res::Def(_, def_id) = typeck.qpath_res(&qpath, expr.hir_id) else {
self.tcx.dcx().span_err(
expr.span,
"annotated path is expected to refer to a definition",
);
continue;
};
out.push((position, def_id));
}
out
}

/// Resolves every `#[thrust::refinement_path(..)]` annotation into a
/// positioned refinement, by translating the referenced formula function.
pub fn extract_refinement_annots(
&self,
local_def_id: LocalDefId,
generic_args: mir_ty::GenericArgsRef<'tcx>,
) -> Vec<(rty::TypePosition, rty::Refinement<rty::FunctionParamIdx>)> {
let mut out = Vec::new();
for (position, def_id) in self.extract_refinement_paths(local_def_id) {
let Some(formula_def_id) = def_id.as_local() else {
panic!(
"refinement_path annotation is expected to refer to a local def, but found: {:?}",
def_id
);
};
let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else {
panic!(
"refinement_path annotation {:?} is not a formula function",
formula_def_id
);
};
out.push((position, formula_fn.to_refinement()));
}
out
}

/// Whether the given `def_id` corresponds to a method of one of the `Fn` traits.
fn is_fn_trait_method(&self, def_id: DefId) -> bool {
self.tcx
Expand Down
57 changes: 57 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ pub fn ensures_path_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("ensures_path")]
}

pub fn refinement_path_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("refinement_path")]
}

pub fn model_ty_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Expand Down Expand Up @@ -207,6 +211,59 @@ pub fn extract_annot_tokens(attr: Attribute) -> TokenStream {
d.tokens
}

/// Parses a [`rty::TypePosition`] from the tokens of a
/// `#[thrust::refinement_path(..)]` attribute.
///
/// Tokens are comma-separated [`rty::TypePositionStep`]s, each encoded as
/// `result` (→ `Return`), `$i` (→ `Param(i)`), or a bare integer `i` (→
/// `TypeArg(i)`).
pub fn parse_type_position(ts: &TokenStream) -> rty::TypePosition {
use rustc_ast::token::{LitKind, TokenKind};
use rustc_ast::tokenstream::TokenTree;

let parse_int = |lit: &rustc_ast::token::Lit| -> usize {
assert_eq!(
lit.kind,
LitKind::Integer,
"expected an integer in type position"
);
lit.symbol
.as_str()
.parse()
.expect("invalid integer in type position")
};

let mut steps = Vec::new();
let mut iter = ts.iter();
while let Some(tt) = iter.next() {
let TokenTree::Token(t, _) = tt else {
panic!("unexpected token tree in type position");
};
match &t.kind {
TokenKind::Comma => {}
TokenKind::Ident(sym, _) if sym.as_str() == "result" => {
steps.push(rty::TypePositionStep::Return);
}
TokenKind::Dollar => {
let i = match iter.next() {
Some(TokenTree::Token(t, _)) => match &t.kind {
TokenKind::Literal(lit) => parse_int(lit),
_ => panic!("expected integer after `$` in type position: {:?}", t),
},
_ => panic!("expected integer after `$` in type position"),
};
steps.push(rty::TypePositionStep::Param(rty::FunctionParamIdx::from(i)));
}
TokenKind::Literal(lit) => {
steps.push(rty::TypePositionStep::TypeArg(parse_int(lit)));
}
_ => panic!("unexpected token in type position: {:?}", t),
}
}

rty::TypePosition::new(steps)
}

pub fn split_param(ts: &TokenStream) -> (Ident, TokenStream) {
use rustc_ast::token::TokenKind;
use rustc_ast::tokenstream::TokenTree;
Expand Down
28 changes: 28 additions & 0 deletions src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ impl<'tcx> FormulaFn<'tcx> {
AnnotFormula::Formula(self.formula.clone())
}

/// Lowers an `ensures` formula function into a postcondition annotation.
///
/// Relies on the layout produced by `thrust_macros::ensures`: parameter `0`
/// is the function's return value (bound to [`rty::RefinedTypeVar::Value`])
/// and parameters `1..n` are the enclosing function's parameters in order
/// (mapped to [`rty::RefinedTypeVar::Free`]).
pub fn to_ensure_annot(&self) -> AnnotFormula<rty::RefinedTypeVar<rty::FunctionParamIdx>> {
AnnotFormula::Formula(self.formula.clone().map_var(|v| {
if v.as_usize() == 0 {
Expand All @@ -53,6 +59,28 @@ impl<'tcx> FormulaFn<'tcx> {
}
}))
}

/// Lowers a refinement-type formula function (generated by the `param` /
/// `ret` / `sig` macros) into a [`rty::Refinement`] on the enclosing
/// function's parameters.
///
/// Relies on the layout produced by those macros: parameter `0` is the
/// refinement's value binder (bound to [`rty::RefinedTypeVar::Value`] at
/// the type position where the refinement is installed) and parameters
/// `1..n` are the enclosing function's parameters in order (mapped to
/// [`rty::RefinedTypeVar::Free`]).
pub fn to_refinement(&self) -> rty::Refinement<rty::FunctionParamIdx> {
self.formula
.clone()
.map_var(|v| {
if v.as_usize() == 0 {
rty::RefinedTypeVar::Value
} else {
rty::RefinedTypeVar::Free(rty::FunctionParamIdx::from(v.as_usize() - 1))
}
})
.into()
}
}

#[derive(Debug, Clone, Copy)]
Expand Down
7 changes: 7 additions & 0 deletions src/analyze/local_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
assert!(require_annot.is_none() || param_annots.is_empty());
assert!(ensure_annot.is_none() || ret_annot.is_none());

let refinement_annots = self
.ctx
.extract_refinement_annots(self.local_def_id, self.generic_args);

let trait_item_ty = self.trait_item_ty();
let is_fully_annotated = self.is_fully_annotated();

Expand All @@ -431,6 +435,9 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
if let Some(ret_rty) = ret_annot {
builder.ret_rty(ret_rty);
}
for (position, refinement) in refinement_annots {
builder.refinement_at(&position, refinement);
}

if is_fully_annotated {
rty::RefinedType::unrefined(builder.build().into())
Expand Down
43 changes: 43 additions & 0 deletions src/refine/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,49 @@ impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R> {
self.ret_rty = Some(rty);
self
}

/// Records a refinement to install at a [`rty::TypePosition`].
///
/// The first step must be [`rty::TypePositionStep::Param`] or
/// [`rty::TypePositionStep::Return`]; the remaining steps are forwarded to
/// [`rty::RefinedType::install_refinement_at`].
pub fn refinement_at(
&mut self,
position: &rty::TypePosition,
refinement: rty::Refinement<rty::FunctionParamIdx>,
) -> &mut Self {
let (first, rest) = match position.steps().split_first() {
Some(pair) => pair,
None => panic!("type position applied to a function type must not be empty"),
};
match first {
rty::TypePositionStep::Param(idx) => {
if !self.param_rtys.contains_key(idx) {
let ty = self.inner.build(self.param_tys[idx.index()].ty).vacuous();
self.param_rtys
.insert(*idx, rty::RefinedType::unrefined(ty));
}
self.param_rtys
.get_mut(idx)
.unwrap()
.install_refinement_at(rest, refinement);
}
rty::TypePositionStep::Return => {
if self.ret_rty.is_none() {
let ty = self.inner.build(self.ret_ty).vacuous();
self.ret_rty = Some(rty::RefinedType::unrefined(ty));
}
self.ret_rty
.as_mut()
.unwrap()
.install_refinement_at(rest, refinement);
}
rty::TypePositionStep::TypeArg(_) => {
panic!("type position applied to a function type must start with a param or result step, not a type argument");
}
}
self
}
}

impl<'tcx, 'a, R> FunctionTemplateTypeBuilder<'tcx, 'a, R>
Expand Down
Loading