1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
|
# Bind magrittr's pipe under its usual name so this file can use it without
# attaching magrittr.
`%>%` <- magrittr::`%>%`
# Load the saved tuning results from the data/ directory of the tree named by
# the ABS_TOP_SRCDIR environment variable. The .Rdata file presumably provides
# the `tuning` object that aggregate_tuning_raw() reads — TODO confirm.
load (file = sprintf (fmt = "%s/data/laps3l_tuning.Rdata",
                      Sys.getenv (x = "ABS_TOP_SRCDIR")))
# Aggregate the raw tuning runs into per-metric means and standard deviations
# and flag the best algorithm for each metric.
#
# datasets:   character vector of dataset names to keep.
# algorithms: character vector of algorithm names to keep.
# data:       raw tuning results with columns algorithm, labeled_data,
#             dataset, rmse, mae, rrse and rae. Defaults to the `tuning`
#             object visible from this function's environment (presumably
#             the one created by the load() call at the top of the file).
#
# Returns an ungrouped tibble with columns algorithm, labeled_data, dataset,
# metric, value and is_best. `metric` is the metric name prefixed with "a"
# (mean over runs) or "s" (standard deviation over runs). Non-finite values,
# such as the standard deviation of a single run, are dropped. Within each
# (labeled_data, dataset, metric) group exactly one row has is_best = TRUE:
# the one with the lowest value (the first such row on ties).
aggregate_tuning_raw <- function (datasets, algorithms, data = tuning) {
  aggregated <- (data
    %>% dplyr::filter (dataset %in% datasets, algorithm %in% algorithms)
    %>% dplyr::group_by (algorithm, labeled_data, dataset)
    %>% dplyr::summarize (armse = mean (rmse),
                          srmse = sd (rmse),
                          amae = mean (mae),
                          smae = sd (mae),
                          arrse = mean (rrse),
                          srrse = sd (rrse),
                          arae = mean (rae),
                          srae = sd (rae),
                          .groups = "drop"))
  metrics <- (aggregated
    %>% tidyr::pivot_longer (c (armse, srmse, amae, smae,
                                arrse, srrse, arae, srae),
                             names_to = "metric",
                             values_to = "value")
    %>% dplyr::filter (is.finite (value)))
  # Flag the lowest value of each group with a grouped mutate(). It keeps the
  # rows in place, so no join back onto `metrics` is needed, and it avoids a
  # multi-row summarize(), which dplyr 1.1.0 deprecated. row_number() breaks
  # ties by position, so exactly one row per group is flagged.
  (metrics
    %>% dplyr::group_by (labeled_data, dataset, metric)
    %>% dplyr::mutate (is_best = dplyr::row_number (value) == 1L)
    %>% dplyr::ungroup ())
}
#' Plot the tuning results
#'
#' For one dataset, draws the mean of the chosen metric against the amount of
#' labeled data for LapS3L, SSSL and LapRLS, with a ribbon spanning one
#' standard deviation on either side of the mean.
#'
#' @param metric the metric to show: "rmse", "mae", "rrse", "rae"
#' @param dataset name of the dataset whose tuning results are plotted.
#' @return A ggplot object.
#' @export
plot_tuning <- function (metric = "rmse", dataset = "wine") {
  # Validate up front. An unknown metric would otherwise only surface as an
  # obscure "column doesn't exist" error from rename() further down.
  metric <- match.arg (metric, c ("rmse", "mae", "rrse", "rae"))
  stopifnot (is.character (dataset), length (dataset) == 1L)
  # aggregate_tuning_raw() names the per-metric mean "a<metric>" and the
  # standard deviation "s<metric>".
  mean_col <- sprintf ("a%s", metric)
  sd_col <- sprintf ("s%s", metric)
  # Local copy so the filter below can refer to the argument. Inside filter(),
  # data masking resolves `dataset` to the column of the same name.
  the_dataset <- dataset
  (aggregate_tuning_raw (dataset, c ("laps3l", "sssl", "laprls"))
    %>% dplyr::select (algorithm, labeled_data, dataset, metric, value)
    %>% dplyr::filter (metric %in% c (mean_col, sd_col),
                       dataset == the_dataset)
    %>% tidyr::pivot_wider (id_cols = c (algorithm, labeled_data, dataset),
                            names_from = metric,
                            values_from = value)
    %>% dplyr::rename (mean = dplyr::all_of (mean_col),
                       sd = dplyr::all_of (sd_col))
    %>% dplyr::mutate (`Données labellisées` = labeled_data,
                       Algorithme = dplyr::case_when (
                         algorithm == "sssl" ~ "SSSL",
                         algorithm == "laprls" ~ "LapRLS",
                         TRUE ~ "\\textbf{LapS3L}"),
                       Value = mean,
                       low = mean - sd,
                       high = mean + sd)
    %>% ggplot2::ggplot (ggplot2::aes (x = `Données labellisées`,
                                       y = Value,
                                       ymin = low,
                                       ymax = high,
                                       linetype = Algorithme,
                                       color = Algorithme,
                                       fill = Algorithme))
    + ggplot2::geom_line ()
    # size = 0 hides the ribbon outline drawn because `color` is mapped.
    # NOTE(review): ggplot2 >= 3.4 deprecates `size` here in favour of
    # `linewidth`; kept for compatibility with older ggplot2 versions.
    + ggplot2::geom_ribbon (alpha = 0.2, size = 0)
    + ggplot2::ylab (metric)
    + ggplot2::ggtitle (sprintf ("jeu de données %s, métrique %s", dataset, metric)))
}
|