From 13b3c6536b323d00dc02968a0bae1e081442bc89 Mon Sep 17 00:00:00 2001 From: perrydv Date: Tue, 9 Jun 2026 11:50:24 -0700 Subject: [PATCH] support ETaccessorBase as a type with its own symbol --- nCompiler/R/compile_aaa_operatorLists.R | 14 +++++ nCompiler/R/compile_generateCpp.R | 12 ++++ nCompiler/R/compile_labelAbstractTypes.R | 9 +++ nCompiler/R/compile_simpleTransformations.R | 9 +++ nCompiler/R/cppDefs_variables.R | 7 ++- nCompiler/R/symbolTable.R | 7 +-- nCompiler/R/typeDeclarations.R | 5 +- .../post_Rcpp/ETaccessor_post_Rcpp.h | 13 ++++ .../specificOp_tests/test-ETaccess-DSL.R | 59 +++++++++++++++++++ 9 files changed, 127 insertions(+), 8 deletions(-) diff --git a/nCompiler/R/compile_aaa_operatorLists.R b/nCompiler/R/compile_aaa_operatorLists.R index 86cc80ee..088ea49e 100644 --- a/nCompiler/R/compile_aaa_operatorLists.R +++ b/nCompiler/R/compile_aaa_operatorLists.R @@ -111,6 +111,20 @@ getOperatorDef <- function(op, field = NULL, subfield = NULL) { getOperatorField(opDef, field, subfield) } +assignOperatorDef( + 'ETaccess', + list( + matchDef = function(obj, copy=FALSE) {}, + compileArgs = "copy", + simpleTransformations = list( + handler = 'EvalCompileArgs'), + labelAbstractTypes = list( + handler = 'ETaccess'), + cppOutput = list( + handler = 'ETaccess') + ) +) + assignOperatorDef( 'nCppVec', list( diff --git a/nCompiler/R/compile_generateCpp.R b/nCompiler/R/compile_generateCpp.R index b691c97f..9bdd5b54 100644 --- a/nCompiler/R/compile_generateCpp.R +++ b/nCompiler/R/compile_generateCpp.R @@ -207,6 +207,18 @@ inGenCppEnv( } ) +inGenCppEnv( + ETaccess <- function(code, symTab) { + copy <- isTRUE(code$aux$compileArgs$copy) + copy_piece <- if(copy) '' else '' + paste0('ETaccessPtr', copy_piece, '(', + paste0(unlist(lapply(code$args, + compile_generateCpp, + symTab, + asArg = TRUE)), collapse=","), ")") + } +) + inGenCppEnv( Switch <- function(code, symTab) { IDs <- code$aux$compileArgs$IDs diff --git a/nCompiler/R/compile_labelAbstractTypes.R b/nCompiler/R/compile_labelAbstractTypes.R index 9a721e87..9e62a7df 100644 --- a/nCompiler/R/compile_labelAbstractTypes.R +++ b/nCompiler/R/compile_labelAbstractTypes.R @@ -265,6 +265,15 @@ inLabelAbstractTypesEnv( } ) +inLabelAbstractTypesEnv( + ETaccess <- function(code, symTab, auxEnv, handlingInfo) { + inserts <- recurse_labelAbstractTypes(code, symTab, auxEnv, + handlingInfo) + code$type <- symbolETaccBase$new(name = '') # should never be looked at because ETaccess has no return type + if(length(inserts) == 0) NULL else inserts + } +) + nCompiler:::inLabelAbstractTypesEnv( DoubleBracket <- function(code, symTab, auxEnv, handlingInfo) { # specializations from generic will have already been handled diff --git a/nCompiler/R/compile_simpleTransformations.R b/nCompiler/R/compile_simpleTransformations.R index 171bbb8e..9c6100d1 100644 --- a/nCompiler/R/compile_simpleTransformations.R +++ b/nCompiler/R/compile_simpleTransformations.R @@ -156,4 +156,13 @@ inSimpleTransformationsEnv( code$aux$compileArgs$IDs <- IDs if(code$caller$name != "{") stop("nSwitch can not be used within an expression. It does not return anything.") } +) + +inSimpleTransformationsEnv( + EvalCompileArgs <- function(code, symTab, auxEnv, info) { + for(argname in names(code$aux$compileArgs)) { + evaled_arg <- eval(code$aux$compileArgs[[argname]], envir = auxEnv$where) + code$aux$compileArgs[[argname]] <- evaled_arg + } + } ) \ No newline at end of file diff --git a/nCompiler/R/cppDefs_variables.R b/nCompiler/R/cppDefs_variables.R index 45f7f698..3db21e4d 100644 --- a/nCompiler/R/cppDefs_variables.R +++ b/nCompiler/R/cppDefs_variables.R @@ -26,7 +26,7 @@ cppVarClass <- R6::R6Class( if(length(printName) > 0) printName <- paste0(printName, collapse = ', ') cleanWhite(paste(self$baseType, - self$ptrs, + ptrs, if(isTRUE(self$ref)) '&' else @@ -195,10 +195,11 @@ cppNcppVec <- function(name = character(), templateArgs = list(elementVar)) } -cppETaccBase <- function(name = character()) { +cppETaccBase <- function(name = character(), ...) { cppVarFullClass$new(name = name, baseType = "std::unique_ptr", - templateArgs = list("ETaccessorBase")) + templateArgs = list("ETaccessorBase"), + ...) } cppEigenTensorRef <- function(name = character(), diff --git a/nCompiler/R/symbolTable.R b/nCompiler/R/symbolTable.R index 7691f878..a8bf1b86 100644 --- a/nCompiler/R/symbolTable.R +++ b/nCompiler/R/symbolTable.R @@ -499,10 +499,9 @@ symbolETaccBase <- R6::R6Class( inherit = symbolBase, portable = TRUE, public = list( - initialize = function(name, isArg = FALSE) { - self$name <- name + initialize = function(...) { + super$initialize(...) self$type <- "ETaccessorBase" - self$isArg <- isArg }, print = function() { writeLines(paste0(self$name, ': symbolETaccBase (ETaccessorBase) ')) @@ -511,7 +510,7 @@ symbolETaccBase <- R6::R6Class( paste0("ETaccessorBase") }, genCppVar = function() { - cppETaccBase(name = self$name) + cppETaccBase(name = self$name, ref = self$isArg) } ) ) diff --git a/nCompiler/R/typeDeclarations.R b/nCompiler/R/typeDeclarations.R index 1d2dc53a..16ca0666 100644 --- a/nCompiler/R/typeDeclarations.R +++ b/nCompiler/R/typeDeclarations.R @@ -366,7 +366,7 @@ typeDeclarationList <- list( ## symbolRcppType$new(RcppType = "Rcpp::Named", ...) ## }, RcppDataFrame = function(...) { - symbolRcppType$new(RcppType = "Rcpp::DataFrame") + symbolRcppType$new(RcppType = "Rcpp::DataFrame", ...) }, RcppLogicalMatrix = function(...) { symbolRcppType$new(RcppType = "Rcpp::LogicalMatrix", ...) @@ -414,6 +414,9 @@ typeDeclarationList <- list( elementSym <- type2symbol({{ttype}}, where = parent.frame()) symbolNcppVec$new(elementSym = elementSym) }, + ETaccessor = function(...) { + symbolETaccBase$new(...) + }, ## determine type from an evaluated object typeDeclarationFromObject = function(x) { if(inherits(x, 'symbolBasic')) diff --git a/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/ETaccessor_post_Rcpp.h b/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/ETaccessor_post_Rcpp.h index 075b96a6..63b92c4c 100644 --- a/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/ETaccessor_post_Rcpp.h +++ b/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/ETaccessor_post_Rcpp.h @@ -2,6 +2,7 @@ #define NCOMPILER_ETACCESSOR_POST_RCPP_H_ #include +#include #include #include #include @@ -394,6 +395,18 @@ ETaccess(const T &x) { return ETaccessor(x); } +template +std::enable_if_t> +ETaccessPtr(T &x) { + return std::make_unique>(x); +} + +template +std::enable_if_t> +ETaccessPtr(const T &x) { + return std::make_unique>(x); +} + // end ETaccess #endif // NCOMPILER_ETACCESSOR_POST_RCPP_H_ diff --git a/nCompiler/tests/testthat/specificOp_tests/test-ETaccess-DSL.R b/nCompiler/tests/testthat/specificOp_tests/test-ETaccess-DSL.R index 582031e7..5bea0146 100644 --- a/nCompiler/tests/testthat/specificOp_tests/test-ETaccess-DSL.R +++ b/nCompiler/tests/testthat/specificOp_tests/test-ETaccess-DSL.R @@ -1,6 +1,65 @@ library(nCompiler) library(testthat) +test_that("ETaccessor type works", { + nc <- nClass( + Cpublic = list( + s = 'numericScalar', + v = 'numericVector', + m = 'numericMatrix', + get_s = nFunction( + function() { + ans <- ETaccess(s) + return(ans) + returnType('ETaccessor') + } + ), + get_inner = nFunction( + function(vn = 'string') { + ans <- self[[vn]] + return(ans) + returnType('ETaccessor') + } + ), + use = nFunction( + function(acc = 'ETaccessor') { + return(as(acc, "numericMatrix")) + returnType("numericMatrix") + } + ), + get = nFunction( + function(i = 'integerScalar', vn = 'string') { + nSwitch(i, 1:4, + eta <- get_s(), + eta <- get_inner(vn), + eta <- self[[vn]], + { + eta <- self[[vn]] + res <- use(eta) + } + ) + if(i < 4) + res <- as(eta, "numericMatrix") + return(res) + returnType("numericMatrix") + } + ) + ), + compileInfo=list(interfaceMembers = c("s","v","m", "get")) + ) + + cnc <- nCompile(nc) + obj <- cnc$new() + obj$s <- 1.2 + obj$v <- c(2.3, 3.4) + obj$m <- matrix(5:10, nrow = 3) + expect_equal(obj$get(1, "not_used"), matrix(1.2)) + expect_equal(obj$get(2, "v"), matrix(obj$v)) + expect_equal(obj$get(3, "m"), obj$m) + expect_equal(obj$get(4, "v"), matrix(obj$v)) + rm(obj); gc() +}) + test_that("obj[['x']] works like obj$x", { nc <- nClass( Cpublic = list(