`augment` outputs a rulelist with an additional column named `augmented_stats`, based on summary statistics calculated using the attribute `validation_data`.
# S3 method for rulelist
augment(x, ...)
A rulelist
(expressions) To be sent to tidytable::summarise for custom aggregations. See examples.
A rulelist with a new dataframe-column named `augmented_stats`.
The dataframe-column `augmented_stats` will have these columns, corresponding to the `estimation_type`:
- For regression: `support`, `IQR`, `RMSE`
- For classification: `support`, `confidence`, `lift`
along with custom aggregations.
# Examples for augment ------------------------------------------------------
library("magrittr")
# C5 ----
# Fit a C5.0 rule model on a random subset of the attrition rows and attach
# the remaining rows as validation data.
att <- modeldata::attrition
set.seed(100)
train_index <- sample(c(TRUE, FALSE), size = nrow(att), replace = TRUE)
model_c5 <- C50::C5.0(Attrition ~ ., data = att[train_index, ], rules = TRUE)
tidy_c5 <- model_c5 %>%
  tidy() %>%
  set_validation_data(att[!train_index, ], "Attrition")
tidy_c5
#> ---- Rulelist --------------------------------
#> ▶ Keys: trial_nbr
#> ▶ Number of distinct keys: 1
#> ▶ Number of rules: 23
#> ▶ Model type: C5
#> ▶ Estimation type: classification
#> ▶ Is validation data set: TRUE
#>
#>
#> rule_nbr trial_nbr LHS RHS support confidence lift
#> <int> <int> <chr> <fct> <int> <dbl> <dbl>
#> 1 1 1 ( Age > 30 ) & ( DistanceF… No 69 0.986 1.2
#> 2 2 1 ( DistanceFromHome <= 12 )… No 149 0.960 1.1
#> 3 3 1 ( Department == 'Research_… No 211 0.953 1.1
#> 4 4 1 ( Age > 30 ) & ( DistanceF… No 249 0.948 1.1
#> 5 5 1 ( JobInvolvement %in% c('M… No 353 0.944 1.1
#> 6 6 1 ( OverTime == 'No' ) & ( S… No 263 0.943 1.1
#> 7 7 1 ( Education %in% c('Master… No 101 0.942 1.1
#> 8 8 1 ( OverTime == 'No' ) & ( R… No 95 0.938 1.1
#> 9 9 1 ( BusinessTravel %in% c('N… No 352 0.915 1.1
#> 10 10 1 ( Education %in% c('Below_… No 265 0.910 1.1
#> # ℹ 13 more rows
#> ----------------------------------------------
# Default augment: unnest the augmented_stats dataframe-column so its columns
# show up next to the rulelist's own columns.
augment(tidy_c5) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 23
#> Columns: 10
#> $ rule_nbr <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,…
#> $ trial_nbr <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
#> $ LHS <chr> "( Age > 30 ) & ( DistanceFromHome <= 12 )…
#> $ RHS <fct> No, No, No, No, No, No, No, No, No, No, Ye…
#> $ support <int> 69, 149, 211, 249, 353, 263, 101, 95, 352,…
#> $ confidence <dbl> 0.9859155, 0.9603000, 0.9531000, 0.9482000…
#> $ lift <dbl> 1.2, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.…
#> $ augmented_stats__support <dbl> 77, 122, 245, 282, 376, 305, 84, 111, 390,…
#> $ augmented_stats__confidence <dbl> 0.9220779, 0.9098361, 0.9346939, 0.9113475…
#> $ augmented_stats__lift <dbl> 9.3667749, 1.0091812, 1.0367533, 1.0108577…
# augment with custom aggregator
# The named expression passed through `...` appears as an extra column
# (output_counts) inside augmented_stats.
augment(tidy_c5, output_counts = list(table(Attrition))) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 23
#> Columns: 11
#> $ rule_nbr <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, …
#> $ trial_nbr <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
#> $ LHS <chr> "( Age > 30 ) & ( DistanceFromHome <= 1…
#> $ RHS <fct> No, No, No, No, No, No, No, No, No, No,…
#> $ support <int> 69, 149, 211, 249, 353, 263, 101, 95, 3…
#> $ confidence <dbl> 0.9859155, 0.9603000, 0.9531000, 0.9482…
#> $ lift <dbl> 1.2, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1,…
#> $ augmented_stats__support <dbl> 77, 122, 245, 282, 376, 305, 84, 111, 3…
#> $ augmented_stats__confidence <dbl> 0.9220779, 0.9098361, 0.9346939, 0.9113…
#> $ augmented_stats__output_counts <list> <<table[2]>>, <<table[2]>>, <<table[2]…
#> $ augmented_stats__lift <dbl> 9.3667749, 1.0091812, 1.0367533, 1.0108…
# rpart ----
# Build one classification tree and one regression tree on the same random
# subset of iris; the remaining rows become the validation data for each.
set.seed(100)
train_index <- sample(c(TRUE, FALSE), size = nrow(iris), replace = TRUE)
model_class_rpart <- rpart::rpart(Species ~ ., data = iris[train_index, ])
tidy_class_rpart <- tidy(model_class_rpart) %>%
  set_validation_data(iris[!train_index, ], "Species")
tidy_class_rpart
#> ---- Rulelist --------------------------------
#> ▶ Keys: NULL
#> ▶ Number of rules: 3
#> ▶ Model type: rpart
#> ▶ Estimation type: classification
#> ▶ Is validation data set: TRUE
#>
#>
#> rule_nbr LHS RHS support confidence lift
#> <int> <chr> <fct> <int> <dbl> <dbl>
#> 1 1 ( Petal.Width < 1.7 ) & ( Petal.Lengt… vers… 28 0.967 2.53
#> 2 2 ( Petal.Width < 1.7 ) & ( Petal.Lengt… seto… 21 0.957 3.46
#> 3 3 ( Petal.Width >= 1.7 ) virg… 27 0.931 2.72
#> ----------------------------------------------
model_regr_rpart <- rpart::rpart(Sepal.Length ~ ., data = iris[train_index, ])
tidy_regr_rpart <- tidy(model_regr_rpart) %>%
  set_validation_data(iris[!train_index, ], "Sepal.Length")
tidy_regr_rpart
#> ---- Rulelist --------------------------------
#> ▶ Keys: NULL
#> ▶ Number of rules: 5
#> ▶ Model type: rpart
#> ▶ Estimation type: regression
#> ▶ Is validation data set: TRUE
#>
#>
#> rule_nbr LHS RHS support
#> <int> <chr> <dbl> <int>
#> 1 1 ( Petal.Length >= 4.25 ) & ( Petal.Length < 5.85 ) & (… 6.15 23
#> 2 2 ( Petal.Length < 4.25 ) & ( Petal.Length < 1.65 ) 4.99 17
#> 3 3 ( Petal.Length < 4.25 ) & ( Petal.Length >= 1.65 ) 5.6 16
#> 4 4 ( Petal.Length >= 4.25 ) & ( Petal.Length < 5.85 ) & (… 6.53 12
#> 5 5 ( Petal.Length >= 4.25 ) & ( Petal.Length >= 5.85 ) 7.34 8
#> ----------------------------------------------
# augment (classification case)
augment(tidy_class_rpart) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 3
#> Columns: 9
#> $ rule_nbr <int> 1, 2, 3
#> $ LHS <chr> "( Petal.Width < 1.7 ) & ( Petal.Length >=…
#> $ RHS <fct> versicolor, setosa, virginica
#> $ support <int> 28, 21, 27
#> $ confidence <dbl> 0.9666667, 0.9565217, 0.9310345
#> $ lift <dbl> 2.533333, 3.461698, 2.721485
#> $ augmented_stats__support <dbl> 24, 29, 21
#> $ augmented_stats__confidence <dbl> 0.8333333, 1.0000000, 0.9523810
#> $ augmented_stats__lift <dbl> 2.936508, 2.551724, 3.356009
# augment (regression case)
augment(tidy_regr_rpart) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 5
#> Columns: 7
#> $ rule_nbr <int> 1, 2, 3, 4, 5
#> $ LHS <chr> "( Petal.Length >= 4.25 ) & ( Petal.Length < …
#> $ RHS <dbl> 6.147826, 4.988235, 5.600000, 6.533333, 7.337…
#> $ support <int> 23, 17, 16, 12, 8
#> $ augmented_stats__support <dbl> 20, 27, 13, 9, 5
#> $ augmented_stats__IQR <dbl> 0.60, 0.45, 0.60, 0.50, 0.40
#> $ augmented_stats__RMSE <dbl> 0.5085832, 0.3548713, 0.4497863, 0.3294215, 0…
# party ----
# Drop rows with a missing bill_length_mm, then fit conditional inference
# trees (classification and regression) on a random subset of the penguins.
pen <- palmerpenguins::penguins %>%
  tidytable::drop_na(bill_length_mm)
set.seed(100)
train_index <- sample(c(TRUE, FALSE), size = nrow(pen), replace = TRUE)
model_class_party <- partykit::ctree(species ~ ., data = pen[train_index, ])
tidy_class_party <- tidy(model_class_party) %>%
  set_validation_data(pen[!train_index, ], "species")
tidy_class_party
#> ---- Rulelist --------------------------------
#> ▶ Keys: NULL
#> ▶ Number of rules: 5
#> ▶ Model type: constparty
#> ▶ Estimation type: classification
#> ▶ Is validation data set: TRUE
#>
#>
#> rule_nbr LHS RHS support confidence lift terminal_node_id
#> <int> <chr> <fct> <dbl> <dbl> <dbl> <chr>
#> 1 1 ( flipper_length_mm … Gent… 50 1 3.02 8
#> 2 2 ( flipper_length_mm … Adel… 73 0.986 2.08 3
#> 3 3 ( flipper_length_mm … Chin… 17 0.941 4.86 5
#> 4 4 ( flipper_length_mm … Chin… 7 0.714 3.69 9
#> 5 5 ( flipper_length_mm … Chin… 13 0.692 3.57 6
#> ----------------------------------------------
model_regr_party <- partykit::ctree(
  bill_length_mm ~ .,
  data = pen[train_index, ]
)
tidy_regr_party <- tidy(model_regr_party) %>%
  set_validation_data(pen[!train_index, ], "bill_length_mm")
tidy_regr_party
#> ---- Rulelist --------------------------------
#> ▶ Keys: NULL
#> ▶ Number of rules: 5
#> ▶ Model type: constparty
#> ▶ Estimation type: regression
#> ▶ Is validation data set: TRUE
#>
#>
#> rule_nbr LHS RHS support IQR RMSE terminal_node_id
#> <int> <chr> <dbl> <dbl> <dbl> <dbl> <chr>
#> 1 1 ( species %in% c('Chinstr… 51.6 13 2 1.78 9
#> 2 2 ( species %in% c('Adelie'… 37.3 41 2.80 1.83 3
#> 3 3 ( species %in% c('Adelie'… 40.4 35 2.55 2.08 4
#> 4 4 ( species %in% c('Chinstr… 49.0 29 2 2.12 8
#> 5 5 ( species %in% c('Chinstr… 45.9 42 2.88 2.50 6
#> ----------------------------------------------
# augment (classification case)
augment(tidy_class_party) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 5
#> Columns: 10
#> $ rule_nbr <int> 1, 2, 3, 4, 5
#> $ LHS <chr> "( flipper_length_mm > 205 ) & ( bill_dept…
#> $ RHS <fct> Gentoo, Adelie, Chinstrap, Chinstrap, Chin…
#> $ support <dbl> 50, 73, 17, 7, 13
#> $ confidence <dbl> 1.0000000, 0.9863014, 0.9411765, 0.7142857…
#> $ lift <dbl> 3.018868, 2.076424, 4.857685, 3.686636, 3.…
#> $ terminal_node_id <chr> "8", "3", "5", "9", "6"
#> $ augmented_stats__support <dbl> 69, 66, 19, 4, 24
#> $ augmented_stats__confidence <dbl> 1.0000000, 1.0000000, 0.8947368, 0.2500000…
#> $ augmented_stats__lift <dbl> 2.6000000, 2.4266667, 2.1712281, 0.6066667…
# augment (regression case)
augment(tidy_regr_party) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 5
#> Columns: 10
#> $ rule_nbr <int> 1, 2, 3, 4, 5
#> $ LHS <chr> "( species %in% c('Chinstrap', 'Gentoo') ) & …
#> $ RHS <dbl> 51.59231, 37.29512, 40.37143, 49.04828, 45.87…
#> $ support <dbl> 13, 41, 35, 29, 42
#> $ IQR <dbl> 2.000, 2.800, 2.550, 2.000, 2.875
#> $ RMSE <dbl> 1.778704, 1.827694, 2.080875, 2.124183, 2.499…
#> $ terminal_node_id <chr> "9", "3", "4", "8", "6"
#> $ augmented_stats__support <dbl> 16, 32, 39, 39, 51
#> $ augmented_stats__IQR <dbl> 1.525, 2.900, 2.750, 3.100, 1.500
#> $ augmented_stats__RMSE <dbl> 1.336547, 2.232729, 2.420163, 3.048512, 2.499…
# cubist ----
# Fit a Cubist model for MonthlyIncome on a random subset of the attrition
# rows; predictors are every column except MonthlyIncome and Attrition.
att <- modeldata::attrition
set.seed(100)
train_index <- sample(c(TRUE, FALSE), size = nrow(att), replace = TRUE)
cols_att <- setdiff(colnames(att), c("MonthlyIncome", "Attrition"))
model_cubist <- Cubist::cubist(
  x = att[train_index, cols_att],
  y = att[train_index, "MonthlyIncome"]
)
tidy_cubist <- tidy(model_cubist) %>%
  set_validation_data(att[!train_index, ], "MonthlyIncome")
tidy_cubist
#> ---- Rulelist --------------------------------
#> ▶ Keys: committee
#> ▶ Number of distinct keys: 1
#> ▶ Number of rules: 4
#> ▶ Model type: cubist
#> ▶ Estimation type: regression
#> ▶ Is validation data set: TRUE
#>
#>
#> rule_nbr committee LHS RHS support mean min max error
#> <int> <int> <chr> <chr> <int> <dbl> <dbl> <dbl> <dbl>
#> 1 1 1 ( JobLevel > 1 ) & … (455… 33 4436. 2272 5301 392.
#> 2 2 1 ( JobLevel <= 1 ) (110… 251 2789 1081 4968 563.
#> 3 3 1 ( JobRole %in% c('M… (299… 89 16714 11031 19999 761
#> 4 4 1 ( JobLevel > 1 ) & … (-14… 334 6843. 2306 13973 1036.
#> ----------------------------------------------
# Unnest the augmented_stats dataframe-column to view its columns alongside
# the rulelist's own columns.
augment(tidy_cubist) %>%
  tidytable::unnest(augmented_stats, names_sep = "__") %>%
  tidytable::glimpse()
#> Rows: 4
#> Columns: 12
#> $ rule_nbr <int> 1, 2, 3, 4
#> $ committee <int> 1, 1, 1, 1
#> $ LHS <chr> "( JobLevel > 1 ) & ( TotalWorkingYears <= 5 …
#> $ RHS <chr> "(4559)", "(1108) + (874 * JobLevel) + (48 * …
#> $ support <int> 33, 251, 89, 334
#> $ mean <dbl> 4435.6, 2789.0, 16714.0, 6842.9
#> $ min <dbl> 2272, 1081, 11031, 2306
#> $ max <dbl> 5301, 4968, 19999, 13973
#> $ error <dbl> 391.7, 562.7, 761.0, 1035.9
#> $ augmented_stats__support <dbl> 24, 292, 93, 354
#> $ augmented_stats__IQR <dbl> 439.25, 906.00, 4950.00, 3685.50
#> $ augmented_stats__RMSE <dbl> 283.9729, 754.9632, 865.8137, 1446.1057