diff options
Diffstat (limited to 'images/laps3l_graph_code.R')
-rw-r--r-- | images/laps3l_graph_code.R | 74 |
1 files changed, 74 insertions, 0 deletions
# Contents of images/laps3l_graph_code.R (file added by diff 0000000..06b01b0).

# Make the magrittr pipe available without attaching the package.
`%>%` <- magrittr::`%>%`

# Load the saved tuning results from the source tree rooted at the
# ABS_TOP_SRCDIR environment variable.
# NOTE(review): this file presumably defines the `tuning` data frame read by
# aggregate_tuning_raw() below -- confirm against the .Rdata contents.
load(sprintf("%s/data/laps3l_tuning.Rdata", Sys.getenv("ABS_TOP_SRCDIR")))

#' Aggregate the raw tuning results into long-format summary metrics
#'
#' Keeps the rows of `tuning` for the requested datasets and algorithms,
#' computes the mean (`a` prefix) and standard deviation (`s` prefix) of the
#' `rmse`, `mae`, `rrse` and `rae` columns for every
#' (algorithm, labeled_data, dataset) combination, and reshapes the result to
#' one row per summary metric.
#'
#' @param datasets values of the `dataset` column to keep
#' @param algorithms values of the `algorithm` column to keep
#' @return an ungrouped tibble with columns `algorithm`, `labeled_data`,
#'   `dataset`, `metric`, `value` and `is_best`. `is_best` is `TRUE` for the
#'   row holding the smallest `value` within each
#'   (labeled_data, dataset, metric) group; ties go to the earliest row.
aggregate_tuning_raw <- function(datasets, algorithms) {
  aggregated <- tuning %>%
    dplyr::filter(dataset %in% datasets, algorithm %in% algorithms) %>%
    dplyr::group_by(algorithm, labeled_data, dataset) %>%
    dplyr::summarize(
      armse = mean(rmse),
      srmse = sd(rmse),
      amae = mean(mae),
      smae = sd(mae),
      arrse = mean(rrse),
      srrse = sd(rrse),
      arae = mean(rae),
      srae = sd(rae),
      .groups = "drop"
    )
  metrics <- aggregated %>%
    tidyr::pivot_longer(
      c(armse, srmse, amae, smae, arrse, srrse, arae, srae),
      names_to = "metric",
      values_to = "value"
    ) %>%
    # Drop NA/NaN/Inf summaries, e.g. the standard deviation of a single run.
    dplyr::filter(is.finite(value))
  # Flag the smallest value of each group in place. row_number() ranks ties
  # in order of appearance, which matches a stable sort followed by "first
  # row wins", without a multi-row summarize() and a join back.
  metrics %>%
    dplyr::group_by(labeled_data, dataset, metric) %>%
    dplyr::mutate(is_best = dplyr::row_number(value) == 1L) %>%
    dplyr::ungroup()
}

#' Plot the tuning results
#'
#' Draws, for one dataset, the mean of the chosen metric against the amount
#' of labeled data as one line per algorithm ("laps3l", "sssl", "laprls"),
#' with a ribbon spanning mean - sd to mean + sd.
#'
#' @param metric the metric to show: "rmse", "mae", "rrse", "rae"
#' @param dataset the value of the `dataset` column to plot
#' @return a ggplot object
#' @export
plot_tuning <- function(metric = "rmse", dataset = "wine") {
  # Fail early with a clear message instead of an obscure rename() error.
  metric <- match.arg(metric, c("rmse", "mae", "rrse", "rae"))
  # Copies under distinct names: inside the dplyr verbs below, `metric` and
  # `dataset` refer to the data columns, not to these arguments.
  the_metric <- metric
  the_dataset <- dataset
  mean_col <- sprintf("a%s", the_metric)
  sd_col <- sprintf("s%s", the_metric)

  aggregate_tuning_raw(dataset, c("laps3l", "sssl", "laprls")) %>%
    dplyr::select(algorithm, labeled_data, dataset, metric, value) %>%
    dplyr::filter(
      metric %in% c(mean_col, sd_col),
      dataset == the_dataset
    ) %>%
    tidyr::pivot_wider(
      id_cols = c(algorithm, labeled_data, dataset),
      names_from = metric,
      values_from = value
    ) %>%
    dplyr::rename(
      mean = sprintf("a%s", the_metric),
      sd = sprintf("s%s", the_metric)
    ) %>%
    dplyr::mutate(
      low = mean - sd,
      high = mean + sd,
      `Données labellisées` = labeled_data,
      # NOTE(review): the \textbf{} markup presumably targets a LaTeX-based
      # graphics device -- confirm how the figure is rendered.
      Algorithme =
        ifelse(
          algorithm == "sssl",
          "SSSL",
          ifelse(
            algorithm == "laprls",
            "LapRLS",
            "\\textbf{LapS3L}"
          )
        ),
      Value = mean
    ) %>%
    ggplot2::ggplot(ggplot2::aes(
      x = `Données labellisées`,
      y = Value,
      ymin = low,
      ymax = high,
      linetype = Algorithme,
      color = Algorithme,
      fill = Algorithme
    )) +
    ggplot2::geom_line() +
    # NOTE(review): ggplot2 >= 3.4.0 prefers `linewidth` for outline widths;
    # `size` is kept so the code still works with older ggplot2 releases.
    ggplot2::geom_ribbon(alpha = 0.2, size = 0) +
    ggplot2::ylab(metric) +
    ggplot2::ggtitle(
      sprintf("jeu de données %s, métrique %s", dataset, metric)
    )
}