augment returns a rulelist with an additional column named augmented_stats, which holds summary statistics computed from the rulelist's validation_data attribute (set beforehand with set_validation_data).

# S3 method for rulelist
augment(x, ...)

Arguments

x

A rulelist

...

(expressions) To be sent to tidytable::summarise for custom aggregations. See examples.

Value

A rulelist with a new dataframe-column named augmented_stats.

Details

The dataframe-column augmented_stats will contain the following columns, depending on the estimation_type:

  • For regression: support, IQR, RMSE

  • For classification: support, confidence, lift

along with any custom aggregations supplied through the ... argument.

See also

Examples

# Examples for augment ------------------------------------------------------
library("magrittr")

# C5 ----
att = modeldata::attrition
set.seed(100)
train_index = sample(c(TRUE, FALSE), nrow(att), replace = TRUE)

model_c5 = C50::C5.0(Attrition ~., data = att[train_index, ], rules = TRUE)
tidy_c5  =
  model_c5 %>%
  tidy() %>%
  set_validation_data(att[!train_index, ], "Attrition")

tidy_c5
#> ---- Rulelist --------------------------------
#> ▶ Keys: trial_nbr
#> ▶ Number of distinct keys: 1
#> ▶ Number of rules: 23
#> ▶ Model type: C5
#> ▶ Estimation type: classification
#> ▶ Is validation data set: TRUE
#> 
#> 
#>    rule_nbr trial_nbr LHS                         RHS   support confidence  lift
#>       <int>     <int> <chr>                       <fct>   <int>      <dbl> <dbl>
#>  1        1         1 ( Age > 30 ) & ( DistanceF… No         69      0.986   1.2
#>  2        2         1 ( DistanceFromHome <= 12 )… No        149      0.960   1.1
#>  3        3         1 ( Department == 'Research_… No        211      0.953   1.1
#>  4        4         1 ( Age > 30 ) & ( DistanceF… No        249      0.948   1.1
#>  5        5         1 ( JobInvolvement %in% c('M… No        353      0.944   1.1
#>  6        6         1 ( OverTime == 'No' ) & ( S… No        263      0.943   1.1
#>  7        7         1 ( Education %in% c('Master… No        101      0.942   1.1
#>  8        8         1 ( OverTime == 'No' ) & ( R… No         95      0.938   1.1
#>  9        9         1 ( BusinessTravel %in% c('N… No        352      0.915   1.1
#> 10       10         1 ( Education %in% c('Below_… No        265      0.910   1.1
#> # ℹ 13 more rows
#> ----------------------------------------------

augment(tidy_c5) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 23
#> Columns: 10
#> $ rule_nbr                    <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,…
#> $ trial_nbr                   <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
#> $ LHS                         <chr> "( Age > 30 ) & ( DistanceFromHome <= 12 )…
#> $ RHS                         <fct> No, No, No, No, No, No, No, No, No, No, Ye…
#> $ support                     <int> 69, 149, 211, 249, 353, 263, 101, 95, 352,…
#> $ confidence                  <dbl> 0.9859155, 0.9603000, 0.9531000, 0.9482000…
#> $ lift                        <dbl> 1.2, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.…
#> $ augmented_stats__support    <dbl> 77, 122, 245, 282, 376, 305, 84, 111, 390,…
#> $ augmented_stats__confidence <dbl> 0.9220779, 0.9098361, 0.9346939, 0.9113475…
#> $ augmented_stats__lift       <dbl> 9.3667749, 1.0091812, 1.0367533, 1.0108577…

# augment with custom aggregator
augment(tidy_c5,output_counts = list(table(Attrition))) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 23
#> Columns: 11
#> $ rule_nbr                       <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, …
#> $ trial_nbr                      <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
#> $ LHS                            <chr> "( Age > 30 ) & ( DistanceFromHome <= 1…
#> $ RHS                            <fct> No, No, No, No, No, No, No, No, No, No,…
#> $ support                        <int> 69, 149, 211, 249, 353, 263, 101, 95, 3…
#> $ confidence                     <dbl> 0.9859155, 0.9603000, 0.9531000, 0.9482…
#> $ lift                           <dbl> 1.2, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1,…
#> $ augmented_stats__support       <dbl> 77, 122, 245, 282, 376, 305, 84, 111, 3…
#> $ augmented_stats__confidence    <dbl> 0.9220779, 0.9098361, 0.9346939, 0.9113…
#> $ augmented_stats__output_counts <list> <<table[2]>>, <<table[2]>>, <<table[2]…
#> $ augmented_stats__lift          <dbl> 9.3667749, 1.0091812, 1.0367533, 1.0108…

# rpart ----
set.seed(100)
train_index = sample(c(TRUE, FALSE), nrow(iris), replace = TRUE)

model_class_rpart = rpart::rpart(Species ~ ., data = iris[train_index, ])
tidy_class_rpart  = tidy(model_class_rpart) %>%
  set_validation_data(iris[!train_index, ], "Species")
tidy_class_rpart
#> ---- Rulelist --------------------------------
#> ▶ Keys: NULL
#> ▶ Number of rules: 3
#> ▶ Model type: rpart
#> ▶ Estimation type: classification
#> ▶ Is validation data set: TRUE
#> 
#> 
#>   rule_nbr LHS                                    RHS   support confidence  lift
#>      <int> <chr>                                  <fct>   <int>      <dbl> <dbl>
#> 1        1 ( Petal.Width < 1.7 ) & ( Petal.Lengt… vers…      28      0.967  2.53
#> 2        2 ( Petal.Width < 1.7 ) & ( Petal.Lengt… seto…      21      0.957  3.46
#> 3        3 ( Petal.Width >= 1.7 )                 virg…      27      0.931  2.72
#> ----------------------------------------------

model_regr_rpart = rpart::rpart(Sepal.Length ~ ., data = iris[train_index, ])
tidy_regr_rpart  = tidy(model_regr_rpart) %>%
  set_validation_data(iris[!train_index, ], "Sepal.Length")
tidy_regr_rpart
#> ---- Rulelist --------------------------------
#> ▶ Keys: NULL
#> ▶ Number of rules: 5
#> ▶ Model type: rpart
#> ▶ Estimation type: regression
#> ▶ Is validation data set: TRUE
#> 
#> 
#>   rule_nbr LHS                                                       RHS support
#>      <int> <chr>                                                   <dbl>   <int>
#> 1        1 ( Petal.Length >= 4.25 ) & ( Petal.Length < 5.85 ) & (…  6.15      23
#> 2        2 ( Petal.Length < 4.25 ) & ( Petal.Length < 1.65 )        4.99      17
#> 3        3 ( Petal.Length < 4.25 ) & ( Petal.Length >= 1.65 )       5.6       16
#> 4        4 ( Petal.Length >= 4.25 ) & ( Petal.Length < 5.85 ) & (…  6.53      12
#> 5        5 ( Petal.Length >= 4.25 ) & ( Petal.Length >= 5.85 )      7.34       8
#> ----------------------------------------------

# augment (classification case)
augment(tidy_class_rpart) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 3
#> Columns: 9
#> $ rule_nbr                    <int> 1, 2, 3
#> $ LHS                         <chr> "( Petal.Width < 1.7 ) & ( Petal.Length >=…
#> $ RHS                         <fct> versicolor, setosa, virginica
#> $ support                     <int> 28, 21, 27
#> $ confidence                  <dbl> 0.9666667, 0.9565217, 0.9310345
#> $ lift                        <dbl> 2.533333, 3.461698, 2.721485
#> $ augmented_stats__support    <dbl> 24, 29, 21
#> $ augmented_stats__confidence <dbl> 0.8333333, 1.0000000, 0.9523810
#> $ augmented_stats__lift       <dbl> 2.936508, 2.551724, 3.356009

# augment (regression case)
augment(tidy_regr_rpart) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 5
#> Columns: 7
#> $ rule_nbr                 <int> 1, 2, 3, 4, 5
#> $ LHS                      <chr> "( Petal.Length >= 4.25 ) & ( Petal.Length < …
#> $ RHS                      <dbl> 6.147826, 4.988235, 5.600000, 6.533333, 7.337…
#> $ support                  <int> 23, 17, 16, 12, 8
#> $ augmented_stats__support <dbl> 20, 27, 13, 9, 5
#> $ augmented_stats__IQR     <dbl> 0.60, 0.45, 0.60, 0.50, 0.40
#> $ augmented_stats__RMSE    <dbl> 0.5085832, 0.3548713, 0.4497863, 0.3294215, 0…

# party ----
pen = palmerpenguins::penguins %>%
  tidytable::drop_na(bill_length_mm)
set.seed(100)
train_index = sample(c(TRUE, FALSE), nrow(pen), replace = TRUE)

model_class_party = partykit::ctree(species ~ ., data = pen[train_index, ])
tidy_class_party  = tidy(model_class_party) %>%
  set_validation_data(pen[!train_index, ], "species")
tidy_class_party
#> ---- Rulelist --------------------------------
#> ▶ Keys: NULL
#> ▶ Number of rules: 5
#> ▶ Model type: constparty
#> ▶ Estimation type: classification
#> ▶ Is validation data set: TRUE
#> 
#> 
#>   rule_nbr LHS                   RHS   support confidence  lift terminal_node_id
#>      <int> <chr>                 <fct>   <dbl>      <dbl> <dbl> <chr>           
#> 1        1 ( flipper_length_mm … Gent…      50      1      3.02 8               
#> 2        2 ( flipper_length_mm … Adel…      73      0.986  2.08 3               
#> 3        3 ( flipper_length_mm … Chin…      17      0.941  4.86 5               
#> 4        4 ( flipper_length_mm … Chin…       7      0.714  3.69 9               
#> 5        5 ( flipper_length_mm … Chin…      13      0.692  3.57 6               
#> ----------------------------------------------

model_regr_party =
  partykit::ctree(bill_length_mm ~ ., data = pen[train_index, ])
tidy_regr_party  = tidy(model_regr_party) %>%
  set_validation_data(pen[!train_index, ], "bill_length_mm")
tidy_regr_party
#> ---- Rulelist --------------------------------
#> ▶ Keys: NULL
#> ▶ Number of rules: 5
#> ▶ Model type: constparty
#> ▶ Estimation type: regression
#> ▶ Is validation data set: TRUE
#> 
#> 
#>   rule_nbr LHS                          RHS support   IQR  RMSE terminal_node_id
#>      <int> <chr>                      <dbl>   <dbl> <dbl> <dbl> <chr>           
#> 1        1 ( species %in% c('Chinstr…  51.6      13  2     1.78 9               
#> 2        2 ( species %in% c('Adelie'…  37.3      41  2.80  1.83 3               
#> 3        3 ( species %in% c('Adelie'…  40.4      35  2.55  2.08 4               
#> 4        4 ( species %in% c('Chinstr…  49.0      29  2     2.12 8               
#> 5        5 ( species %in% c('Chinstr…  45.9      42  2.88  2.50 6               
#> ----------------------------------------------

# augment (classification case)
augment(tidy_class_party) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 5
#> Columns: 10
#> $ rule_nbr                    <int> 1, 2, 3, 4, 5
#> $ LHS                         <chr> "( flipper_length_mm > 205 ) & ( bill_dept…
#> $ RHS                         <fct> Gentoo, Adelie, Chinstrap, Chinstrap, Chin…
#> $ support                     <dbl> 50, 73, 17, 7, 13
#> $ confidence                  <dbl> 1.0000000, 0.9863014, 0.9411765, 0.7142857…
#> $ lift                        <dbl> 3.018868, 2.076424, 4.857685, 3.686636, 3.…
#> $ terminal_node_id            <chr> "8", "3", "5", "9", "6"
#> $ augmented_stats__support    <dbl> 69, 66, 19, 4, 24
#> $ augmented_stats__confidence <dbl> 1.0000000, 1.0000000, 0.8947368, 0.2500000…
#> $ augmented_stats__lift       <dbl> 2.6000000, 2.4266667, 2.1712281, 0.6066667…

# augment (regression case)
augment(tidy_regr_party) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 5
#> Columns: 10
#> $ rule_nbr                 <int> 1, 2, 3, 4, 5
#> $ LHS                      <chr> "( species %in% c('Chinstrap', 'Gentoo') ) & …
#> $ RHS                      <dbl> 51.59231, 37.29512, 40.37143, 49.04828, 45.87…
#> $ support                  <dbl> 13, 41, 35, 29, 42
#> $ IQR                      <dbl> 2.000, 2.800, 2.550, 2.000, 2.875
#> $ RMSE                     <dbl> 1.778704, 1.827694, 2.080875, 2.124183, 2.499…
#> $ terminal_node_id         <chr> "9", "3", "4", "8", "6"
#> $ augmented_stats__support <dbl> 16, 32, 39, 39, 51
#> $ augmented_stats__IQR     <dbl> 1.525, 2.900, 2.750, 3.100, 1.500
#> $ augmented_stats__RMSE    <dbl> 1.336547, 2.232729, 2.420163, 3.048512, 2.499…

# cubist ----
att         = modeldata::attrition
set.seed(100)
train_index = sample(c(TRUE, FALSE), nrow(att), replace = TRUE)
cols_att    = setdiff(colnames(att), c("MonthlyIncome", "Attrition"))

model_cubist = Cubist::cubist(x = att[train_index, cols_att],
                              y = att[train_index, "MonthlyIncome"]
                              )

tidy_cubist = tidy(model_cubist) %>%
  set_validation_data(att[!train_index, ], "MonthlyIncome")
tidy_cubist
#> ---- Rulelist --------------------------------
#> ▶ Keys: committee
#> ▶ Number of distinct keys: 1
#> ▶ Number of rules: 4
#> ▶ Model type: cubist
#> ▶ Estimation type: regression
#> ▶ Is validation data set: TRUE
#> 
#> 
#>   rule_nbr committee LHS                  RHS   support   mean   min   max error
#>      <int>     <int> <chr>                <chr>   <int>  <dbl> <dbl> <dbl> <dbl>
#> 1        1         1 ( JobLevel > 1 ) & … (455…      33  4436.  2272  5301  392.
#> 2        2         1 ( JobLevel <= 1 )    (110…     251  2789   1081  4968  563.
#> 3        3         1 ( JobRole %in% c('M… (299…      89 16714  11031 19999  761 
#> 4        4         1 ( JobLevel > 1 ) & … (-14…     334  6843.  2306 13973 1036.
#> ----------------------------------------------

augment(tidy_cubist) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 4
#> Columns: 12
#> $ rule_nbr                 <int> 1, 2, 3, 4
#> $ committee                <int> 1, 1, 1, 1
#> $ LHS                      <chr> "( JobLevel > 1 ) & ( TotalWorkingYears <= 5 …
#> $ RHS                      <chr> "(4559)", "(1108) + (874 * JobLevel) + (48 * …
#> $ support                  <int> 33, 251, 89, 334
#> $ mean                     <dbl> 4435.6, 2789.0, 16714.0, 6842.9
#> $ min                      <dbl> 2272, 1081, 11031, 2306
#> $ max                      <dbl> 5301, 4968, 19999, 13973
#> $ error                    <dbl> 391.7, 562.7, 761.0, 1035.9
#> $ augmented_stats__support <dbl> 24, 292, 93, 354
#> $ augmented_stats__IQR     <dbl> 439.25, 906.00, 4950.00, 3685.50
#> $ augmented_stats__RMSE    <dbl> 283.9729, 754.9632, 865.8137, 1446.1057