summaryrefslogtreecommitdiff
path: root/images/laps3l_graph_code.R
blob: 06b01b05b77f1a5e95da4d033782020b489a58a1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
`%>%` <- magrittr::`%>%`

# The raw tuning results are read from the top source directory given by the
# ABS_TOP_SRCDIR environment variable (presumably exported by the build
# system -- confirm). The .Rdata file is expected to define `tuning`, which
# aggregate_tuning_raw() reads.
# Without this guard an unset variable silently turns the path into
# "/data/laps3l_tuning.Rdata" and load() fails with an unhelpful message.
if (!nzchar (Sys.getenv ("ABS_TOP_SRCDIR"))) {
    stop ("ABS_TOP_SRCDIR is not set; cannot locate data/laps3l_tuning.Rdata",
          call. = FALSE)
}
load (sprintf ("%s/data/laps3l_tuning.Rdata", Sys.getenv ("ABS_TOP_SRCDIR")))

#' Aggregate the raw tuning results
#'
#' For each (algorithm, labeled_data, dataset) combination, computes the mean
#' ("a" prefix) and standard deviation ("s" prefix) of the rmse, mae, rrse and
#' rae columns. The result is reshaped to long format. Within every
#' (labeled_data, dataset, metric) group, the row with the smallest value is
#' flagged. This is done for every metric, including the standard-deviation
#' ones.
#'
#' @param datasets character vector of dataset names to keep.
#' @param algorithms character vector of algorithm names to keep.
#' @param data raw tuning results; defaults to the `tuning` object loaded from
#'   laps3l_tuning.Rdata at the top of this file.
#' @return A tibble with columns algorithm, labeled_data, dataset, metric
#'   (e.g. "armse", "srmse"), value and is_best. Rows whose value is not
#'   finite (for instance the sd of a single observation) are dropped. In
#'   each group exactly one row has is_best == TRUE; ties go to the row that
#'   appears first.
#' @noRd
aggregate_tuning_raw <- function (datasets, algorithms, data = tuning) {
    aggregated <- (data
        %>% dplyr::filter (dataset %in% datasets, algorithm %in% algorithms)
        %>% dplyr::group_by (algorithm, labeled_data, dataset)
        %>% dplyr::summarize (armse = mean (rmse),
                              srmse = sd (rmse),
                              amae = mean (mae),
                              smae = sd (mae),
                              arrse = mean (rrse),
                              srrse = sd (rrse),
                              arae = mean (rae),
                              srae = sd (rae),
                              .groups = "drop"))
    metrics <- (aggregated
        %>% tidyr::pivot_longer (c (armse, srmse, amae, smae, arrse, srrse, arae, srae),
                                 names_to = "metric",
                                 values_to = "value")
        %>% dplyr::filter (is.finite (value)))
    # Flag the smallest value of each group in place. row_number() breaks ties
    # by order of appearance, so exactly one row per group is flagged. This
    # replaces a multi-row summarize() (deprecated in dplyr 1.1.0) followed by
    # an un-keyed join back onto `metrics`; row and column order are unchanged.
    (metrics
        %>% dplyr::group_by (labeled_data, dataset, metric)
        %>% dplyr::mutate (is_best = dplyr::row_number (value) == 1L)
        %>% dplyr::ungroup ())
}

#' Plot the tuning results
#'
#' For the LapS3L, SSSL and LapRLS algorithms, plots the mean of the chosen
#' error metric against the amount of labeled data. A shaded ribbon spans
#' mean - sd to mean + sd.
#'
#' @param metric the metric to show: "rmse", "mae", "rrse", "rae"
#' @param dataset name of the dataset whose results are plotted
#' @return a ggplot object
#' @export
plot_tuning <- function (metric = "rmse", dataset = "wine") {
    # Reject unknown metrics up front; otherwise the failure surfaces later as
    # an obscure "column doesn't exist" error from rename().
    metric <- match.arg (metric, c ("rmse", "mae", "rrse", "rae"))
    # Copies under names that the `metric` / `dataset` columns cannot shadow
    # inside dplyr's data mask.
    the_metric <- metric
    the_dataset <- dataset
    mean_name <- sprintf ("a%s", the_metric)
    sd_name <- sprintf ("s%s", the_metric)
    results <- (aggregate_tuning_raw (dataset, c ("laps3l", "sssl", "laprls"))
        %>% dplyr::select (algorithm, labeled_data, dataset, metric, value)
        %>% dplyr::filter (metric %in% c (mean_name, sd_name),
                           dataset == the_dataset))
    if (nrow (results) == 0) {
        stop (sprintf ("no tuning results for dataset %s and metric %s",
                       the_dataset, the_metric),
              call. = FALSE)
    }
    (results
        %>% tidyr::pivot_wider (id_cols = c (algorithm, labeled_data, dataset),
                                names_from = metric,
                                values_from = value)
        %>% dplyr::rename (mean = dplyr::all_of (mean_name),
                           sd = dplyr::all_of (sd_name))
        %>% dplyr::mutate (low = mean - sd, high = mean + sd)
        %>% dplyr::mutate (`Données labellisées` = labeled_data,
                           Algorithme = dplyr::case_when (
                               algorithm == "sssl" ~ "SSSL",
                               algorithm == "laprls" ~ "LapRLS",
                               TRUE ~ "\\textbf{LapS3L}"),
                           Value = mean)
        %>% ggplot2::ggplot (ggplot2::aes (x = `Données labellisées`,
                                           y = Value,
                                           ymin = low,
                                           ymax = high,
                                           linetype = Algorithme,
                                           color = Algorithme,
                                           fill = Algorithme))
        + ggplot2::geom_line ()
        # NOTE(review): ggplot2 >= 3.4.0 deprecates `size` for line widths in
        # favour of `linewidth`; switch once that version can be required.
        + ggplot2::geom_ribbon (alpha = 0.2, size = 0)
        + ggplot2::ylab (metric)
        + ggplot2::ggtitle (sprintf ("jeu de données %s, métrique %s", dataset, metric)))
}