`%>%` <- magrittr::`%>%`

# Load the raw tuning results from $ABS_TOP_SRCDIR/data. The file is expected
# to define the `tuning` data frame used as the default input of
# aggregate_tuning_raw(). This runs whenever the file is sourced.
load(file.path(Sys.getenv("ABS_TOP_SRCDIR"), "data", "laps3l_tuning.Rdata"))

#' Aggregate raw tuning results into per-metric summaries
#'
#' For every combination of algorithm, amount of labeled data and dataset,
#' computes the mean (prefix `a`) and standard deviation (prefix `s`) of the
#' `rmse`, `mae`, `rrse` and `rae` columns. The result is returned in long
#' format, keeping only finite values, and each row is flagged when it holds
#' the smallest value among the compared algorithms.
#'
#' @param datasets character vector of dataset names to keep.
#' @param algorithms character vector of algorithm names to keep.
#' @param data raw tuning results; defaults to the `tuning` object loaded from
#'   `laps3l_tuning.Rdata`. Must contain the columns `algorithm`,
#'   `labeled_data`, `dataset`, `rmse`, `mae`, `rrse` and `rae`.
#' @return A tibble with columns `algorithm`, `labeled_data`, `dataset`,
#'   `metric` (one of `armse`, `srmse`, `amae`, `smae`, `arrse`, `srrse`,
#'   `arae`, `srae`), `value` and `is_best`. `is_best` is `TRUE` for exactly
#'   one row per (`labeled_data`, `dataset`, `metric`): the one with the
#'   smallest value, ties going to the row that comes first.
#' @noRd
aggregate_tuning_raw <- function(datasets, algorithms, data = tuning) {
  aggregated <- data %>%
    dplyr::filter(dataset %in% datasets, algorithm %in% algorithms) %>%
    dplyr::group_by(algorithm, labeled_data, dataset) %>%
    dplyr::summarize(
      armse = mean(rmse), srmse = sd(rmse),
      amae = mean(mae), smae = sd(mae),
      arrse = mean(rrse), srrse = sd(rrse),
      arae = mean(rae), srae = sd(rae),
      .groups = "drop"
    )

  metrics <- aggregated %>%
    tidyr::pivot_longer(
      c(armse, srmse, amae, smae, arrse, srrse, arae, srae),
      names_to = "metric",
      values_to = "value"
    ) %>%
    # Drop NA/NaN/Inf values, e.g. sd() over a single repetition.
    dplyr::filter(is.finite(value))

  # The smallest value in each (labeled_data, dataset, metric) group is the
  # best one. row_number() breaks ties by row order, so exactly one row per
  # group is flagged. A grouped mutate replaces the former multi-row
  # summarize() + inner_join(), which is deprecated in dplyr >= 1.1.0.
  metrics %>%
    dplyr::group_by(labeled_data, dataset, metric) %>%
    dplyr::mutate(is_best = dplyr::row_number(value) == 1L) %>%
    dplyr::ungroup()
}

#' Plot the tuning results
#'
#' Draws, for LapS3L, SSSL and LapRLS, the mean of the chosen metric as a
#' function of the amount of labeled data. A shaded band shows the mean plus
#' or minus one standard deviation.
#'
#' @param metric the metric to show: "rmse", "mae", "rrse", "rae"
#' @param dataset name of the dataset to plot (a single string).
#' @return A ggplot object.
#' @export
plot_tuning <- function(metric = "rmse", dataset = "wine") {
  # Inside dplyr verbs `metric` and `dataset` refer to columns, so compute
  # everything that depends on the arguments outside the data mask.
  the_dataset <- dataset
  mean_col <- sprintf("a%s", metric)
  sd_col <- sprintf("s%s", metric)

  curves <- aggregate_tuning_raw(dataset, c("laps3l", "sssl", "laprls")) %>%
    dplyr::select(algorithm, labeled_data, dataset, metric, value) %>%
    dplyr::filter(metric %in% c(mean_col, sd_col), dataset == the_dataset) %>%
    tidyr::pivot_wider(
      id_cols = c(algorithm, labeled_data, dataset),
      names_from = metric,
      values_from = value
    ) %>%
    dplyr::rename(mean = dplyr::all_of(mean_col), sd = dplyr::all_of(sd_col)) %>%
    dplyr::mutate(
      `Données labellisées` = labeled_data,
      # LaTeX markup in the label (presumably for a tikz device).
      Algorithme = dplyr::case_when(
        algorithm == "sssl" ~ "SSSL",
        algorithm == "laprls" ~ "LapRLS",
        TRUE ~ "\\textbf{LapS3L}"
      ),
      Value = mean,
      low = mean - sd,
      high = mean + sd
    )

  ggplot2::ggplot(
    curves,
    ggplot2::aes(
      x = `Données labellisées`, y = Value, ymin = low, ymax = high,
      linetype = Algorithme, color = Algorithme, fill = Algorithme
    )
  ) +
    ggplot2::geom_line() +
    # colour = NA hides the ribbon outline on every ggplot2 version. The former
    # size = 0 is deprecated (>= 3.4.0) and may still draw a hairline.
    ggplot2::geom_ribbon(alpha = 0.2, colour = NA) +
    ggplot2::ylab(metric) +
    ggplot2::ggtitle(sprintf("jeu de données %s, métrique %s", dataset, metric))
}