User-defined functions in rpart

2 messages · Hugh Chipman, Torsten Hothorn

#
This question concerns rpart's facility for user-defined functions that
accomplish splitting.

I was interested in modifying the code so that in each terminal node,
a linear regression is fit to the data.

From the inputs allowed in the user-defined functions, it seems that
this may not be possible, since they have the form:

function(y, wt, parms)  (in the case of the "evaluation" function)
function(y, wt, x, parms, continuous)  (split function)
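
For context, rpart takes these functions as a list passed to its method argument, together with an init function. Below is a rough sketch of that mechanism, modeled on the anova example in rpart's vignette on user-written split functions. It assumes ordinary mean-based splitting rather than linear models, and numeric predictors only; my_init, my_eval and my_split are placeholder names, and mtcars stands in for real data.

```r
library(rpart)

# init: adjusts y for any offset and returns the pieces rpart needs for printing.
my_init <- function(y, offset, parms, wt) {
  if (length(offset)) y <- y - offset
  sfun <- function(yval, dev, wt, ylevel, digits) {
    paste0("  mean=", format(signif(yval, digits)),
           ", MSE=", format(signif(dev / wt, digits)))
  }
  list(y = c(y), parms = NULL, numresp = 1, numy = 1, summary = sfun)
}

# eval: node label (weighted mean) and node deviance (weighted RSS).
my_eval <- function(y, wt, parms) {
  wmean <- sum(y * wt) / sum(wt)
  list(label = wmean, deviance = sum(wt * (y - wmean)^2))
}

# split: for a continuous x, y and x arrive sorted by x; return one
# goodness value and one direction per candidate cutpoint (n - 1 of each).
my_split <- function(y, wt, x, parms, continuous) {
  if (!continuous) stop("this sketch handles numeric predictors only")
  n <- length(y)
  y <- y - sum(y * wt) / sum(wt)          # center y
  temp <- cumsum(y * wt)[-n]
  left.wt <- cumsum(wt)[-n]
  right.wt <- sum(wt) - left.wt
  lmean <- temp / left.wt
  rmean <- -temp / right.wt
  list(goodness = (left.wt * lmean^2 + right.wt * rmean^2) / sum(wt * y^2),
       direction = sign(lmean))
}

fit <- rpart(mpg ~ wt + hp + disp, data = mtcars,
             method = list(eval = my_eval, split = my_split, init = my_init))
```

Note that my_split only ever sees y, wt and the single predictor x, which is exactly the limitation described here.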

The problem is that there seems to be no facility to include an X
matrix (in the split function, x is a vector corresponding to one
predictor).  Without that, fitting a linear model in the terminal node
would not be possible.

Is this a correct assessment, or am I missing something?
Has anyone tried to modify rpart to fit linear models in nodes?


--
#
If you just want to have a linear model in each terminal node instead of
the mean of the observations in this leaf (which does not require an
altered splitting rule), you can do something like (BostonHousing as example):

R> library(mlbench)
R> library(rpart)
### a stump only
R> tree <- rpart(medv ~ ., data=BostonHousing, cp=0.2)
R> tree
n= 506

node), split, n, deviance, yval
      * denotes terminal node

1) root 506 42716.300 22.53281
  2) rm< 6.941 430 17317.320 19.93372 *
  3) rm>=6.941 76  6059.419 37.23816 *
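### fit a linear model for the observations in each leaf
### (tree$where gives each observation's row in tree$frame; for this
### stump, rows 2 and 3 are the two leaves)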
R> tnodeleft <- lm(medv ~ ., data=BostonHousing, subset=(tree$where == 2))
R> tnoderight <- lm(medv ~ ., data=BostonHousing, subset=(tree$where == 3))
R> coef(tnodeleft)
  (Intercept)          crim            zn         indus         chas1
 5.291185e+01 -1.430523e-01  3.879871e-02  2.818030e-02  3.329020e+00
          nox            rm           age           dis           rad
-1.686631e+01 -3.671042e-01 -3.791325e-04 -1.128992e+00  2.954484e-01
          tax       ptratio             b         lstat
-1.071171e-02 -5.472690e-01  6.393379e-03 -5.578978e-01
R> coef(tnoderight)
  (Intercept)          crim            zn         indus         chas1
  8.224189122  -0.078029764   0.001545735   0.616441038  -0.928230180
          nox            rm           age           dis           rad
-17.080614140   6.372744363  -0.060863991  -1.032928738   0.451670727
          tax       ptratio             b         lstat
 -0.034438740  -1.647800912   0.093841671  -1.172865798
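
For a tree with more than two leaves, the same per-leaf fit could be written as a loop over the values of tree$where. This is a minimal sketch, assuming the built-in mtcars data instead of BostonHousing so that it runs without mlbench; the names leaf_id and leaf_fits are made up for illustration.

```r
library(rpart)

# Grow a small regression tree; minsplit is lowered only because mtcars is tiny.
tree <- rpart(mpg ~ ., data = mtcars, minsplit = 16)

# tree$where holds, for each observation, the row of tree$frame of its leaf.
leaf_id <- tree$where

# One linear model per leaf, on a deliberately small formula so that
# every leaf has enough observations for the fit.
leaf_fits <- lapply(split(mtcars, leaf_id), function(d) lm(mpg ~ wt, data = d))

# Coefficients per leaf, one column per leaf.
sapply(leaf_fits, coef)
```

The list names are the frame rows of the leaves, so tree$frame[as.integer(names(leaf_fits)), ] shows the node each model belongs to.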
If you want to base your splitting criterion on linear models in EACH
node, you may try something like this.
I'm not sure whether you need to pass the design matrix to the split function:
let Y and X denote the response and the design matrix in the calling
environment, then an ugly hack like:

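temp2 <- function(y, wt, x, parms, continuous) {
  thisindx <- Y %in% y  ### determine which subset of the learning
                        ### sample is currently under consideration
                        ### (unreliable if Y contains tied values)

  thisX <- X[thisindx, , drop = FALSE] ### only those in the current node

  ### get the position of the x to split in
  ### (take the first match if several columns qualify):
  myx <- which(apply(thisX, 2, function(a) all(a %in% x)))[1]

  ### lm() needs a data frame; take the response from Y so that its
  ### row order matches thisX (y itself may arrive sorted by x)
  thisdat <- data.frame(y = Y[thisindx], thisX)

  ### and now search for the best split in thisX[, myx]
  for (cut in head(sort(unique(x)), -1)) {  ### for each cutpoint cut ...
    lm(y ~ ., data = thisdat, subset = thisX[, myx] <= cut)
    lm(y ~ ., data = thisdat, subset = thisX[, myx] > cut)
  }
  ### ... then return list(goodness = ..., direction = ...) to rpart
}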

... may work
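
The cutpoint search itself can be tried outside rpart first. Here is a minimal sketch, assuming a data frame with a numeric split variable and scoring each cutpoint by the drop in residual sum of squares when separate linear models are fit on each side; best_lm_split and its minbucket argument are hypothetical and not part of rpart.

```r
# Hypothetical helper (not part of rpart): for one numeric predictor, try each
# candidate cutpoint and score it by the reduction in residual sum of squares
# obtained from fitting separate linear models on the two sides.
best_lm_split <- function(data, response, splitvar, minbucket = 10) {
  form <- as.formula(paste(response, "~ ."))
  rss <- function(d) sum(residuals(lm(form, data = d))^2)

  xs <- sort(unique(data[[splitvar]]))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2   # midpoints between distinct values
  parent <- rss(data)

  gain <- vapply(cuts, function(cut) {
    left  <- data[data[[splitvar]] <= cut, , drop = FALSE]
    right <- data[data[[splitvar]] >  cut, , drop = FALSE]
    # skip cutpoints that leave too few observations on either side
    if (nrow(left) < minbucket || nrow(right) < minbucket) return(NA_real_)
    parent - rss(left) - rss(right)
  }, numeric(1))

  if (all(is.na(gain))) return(NULL)          # no admissible cutpoint
  best <- which.max(gain)                     # which.max ignores NA entries
  list(cut = cuts[best], gain = gain[best])
}

# Usage on built-in data: split on wt, with models mpg ~ wt + hp on each side.
res <- best_lm_split(mtcars[, c("mpg", "wt", "hp")], "mpg", "wt", minbucket = 8)
```

Refitting two models at every cutpoint is slow for large nodes, which is one reason to treat this as an experiment rather than a drop-in splitting rule.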

just an idea,

Torsten