summaryrefslogtreecommitdiff
path: root/images/laps3l_graph_code.R
diff options
context:
space:
mode:
Diffstat (limited to 'images/laps3l_graph_code.R')
-rw-r--r--images/laps3l_graph_code.R74
1 files changed, 74 insertions, 0 deletions
diff --git a/images/laps3l_graph_code.R b/images/laps3l_graph_code.R
new file mode 100644
index 0000000..06b01b0
--- /dev/null
+++ b/images/laps3l_graph_code.R
@@ -0,0 +1,74 @@
+`%>%` <- magrittr::`%>%`
+
+load (sprintf ("%s/data/laps3l_tuning.Rdata", Sys.getenv ("ABS_TOP_SRCDIR")))
+
+aggregate_tuning_raw <- function (datasets, algorithms) {
+ data <- tuning
+ aggregated <- (data
+ %>% dplyr::filter (dataset %in% datasets, algorithm %in% algorithms)
+ %>% dplyr::group_by (algorithm, labeled_data, dataset)
+ %>% dplyr::summarize (armse = mean (rmse),
+ srmse = sd (rmse),
+ amae = mean (mae),
+ smae = sd (mae),
+ arrse = mean (rrse),
+ srrse = sd (rrse),
+ arae = mean (rae),
+ srae = sd (rae))
+ %>% dplyr::ungroup ())
+ metrics <- (aggregated
+ %>% tidyr::pivot_longer (c (armse, srmse, amae, smae, arrse, srrse, arae, srae),
+ names_to = "metric",
+ values_to = "value")
+ %>% dplyr::filter (is.finite (value)))
+ annotations <- (metrics
+ %>% dplyr::group_by (labeled_data, dataset, metric)
+ %>% dplyr::arrange (value)
+ %>% dplyr::summarize (algorithm = algorithm, is_best = c (
+ TRUE,
+ rep (FALSE, dplyr::n () - 1)
+ ))
+ %>% dplyr::ungroup ())
+ (metrics
+ %>% dplyr::inner_join (annotations))
+}
+
+#' Plot the tuning results
+#' @param metric the metric to show: "rmse", "mae", "rrse", "rae"
+#' @export
+plot_tuning <- function (metric = "rmse", dataset = "wine") {
+ the_metric <- metric
+ the_dataset <- dataset
+ (aggregate_tuning_raw (dataset, c ("laps3l", "sssl", "laprls"))
+ %>% dplyr::select (algorithm, labeled_data, dataset, metric, value)
+ %>% dplyr::filter (metric %in% c (sprintf ("a%s", the_metric),
+ sprintf ("s%s", the_metric)),
+ dataset == the_dataset)
+ %>% tidyr::pivot_wider (id_cols = c (algorithm, labeled_data, dataset),
+ names_from = metric,
+ values_from = value)
+ %>% dplyr::rename (mean = sprintf ("a%s", the_metric),
+ sd = sprintf ("s%s", the_metric))
+ %>% dplyr::mutate (low = mean - sd, high = mean + sd)
+ %>% dplyr::mutate (`Données labellisées` = labeled_data,
+ Algorithme =
+ ifelse (algorithm == "sssl",
+ "SSSL",
+ ifelse (algorithm == "laprls",
+ "LapRLS",
+ "\\textbf{LapS3L}")),
+ Value = mean,
+ low = low,
+ high = high)
+ %>% ggplot2::ggplot (ggplot2::aes (x = `Données labellisées`,
+ y = Value,
+ ymin = low,
+ ymax = high,
+ linetype = Algorithme,
+ color = Algorithme,
+ fill = Algorithme))
+ + ggplot2::geom_line ()
+ + ggplot2::geom_ribbon (alpha = 0.2, size = 0)
+ + ggplot2::ylab (metric)
+ + ggplot2::ggtitle (sprintf ("jeu de données %s, métrique %s", dataset, metric)))
+}