summaryrefslogtreecommitdiff
path: root/images/rsms_3.R
blob: 52a77e815f8fe84ccd6304fa665dd32f85e6d85d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Load the saved R objects stored under the source tree whose root is given by
# the ABS_TOP_SRCDIR environment variable (presumably this is what provides the
# `test` data frame read by with_labels_graph -- confirm against the .Rdata
# file).
load (file = file.path (Sys.getenv ("ABS_TOP_SRCDIR"),
                        "data",
                        "rsms_test.Rdata"))

# Switch every locale category to French (UTF-8) for the rest of the script.
Sys.setlocale (category = "LC_ALL", locale = "fr_FR.UTF-8")

#' Draw the per-dataset graphs showing the interest of label selection
#'
#' Reads the global data frame `test` (not defined in this function;
#' presumably provided by the `load ()` call at the top of the script).
#' Rows are kept when `frac_labeled` is 0.3, `algorithm` is "formulas",
#' `error` is not missing and `dataset` is one of the eleven datasets listed
#' below. For each dataset two settings are compared: all labels
#' (`frac_labels` equal to 1) and a dataset-specific restricted fraction of
#' labels. The mean `error` per dataset, `frac_features` and setting is drawn
#' as one line plot per dataset; only the "wq" plot keeps its legend.
#'
#' The plots are placed on a 4 x 3 grid by `gridExtra::grid.arrange ()`, which
#' also draws the result on the current graphics device.
#'
#' @return the graph, i.e. the value returned by `gridExtra::grid.arrange ()`.
#' @export
with_labels_graph <- function () {
    # Bind the pipe locally instead of attaching magrittr, so that calling
    # this function does not modify the search path.
    `%>%` <- magrittr::`%>%`
    # Fraction of labels used in the "restricted" setting, per dataset.
    restricted_fractions <- tibble::tibble (
        dataset = c ("atp1d", "atp7d", "edm", "enb", "oes10", "oes97",
                     "osales", "scpf", "sf1", "sf2", "wq"),
        frac_labeled_restricted = c (0.8, 0.8, 0.6, 0.6, 0.8, 0.8,
                                     0.8, 0.6, 0.6, 0.6, 0.8))
    # Single source of truth for the datasets to plot.
    selected_datasets <- restricted_fractions$dataset
    # Fractions are doubles: compare them with a tolerance (dplyr::near)
    # rather than `==`, which can miss values such as 0.6 or 0.8 when they
    # were produced by arithmetic.
    summary_stats <- (test
        %>% dplyr::filter (dplyr::near (frac_labeled, 0.3),
                           algorithm == "formulas",
                           dataset %in% selected_datasets,
                           !is.na (error))
        %>% dplyr::inner_join (restricted_fractions, by = "dataset")
        %>% dplyr::mutate (full = dplyr::near (frac_labels, 1),
                           restricted = dplyr::near (frac_labels,
                                                     frac_labeled_restricted))
        %>% dplyr::filter (full | restricted)
        %>% dplyr::mutate (algorithm = ifelse (full,
                                               "RSMS (tous labels)",
                                               "RSMS (restreint)"))
        %>% dplyr::select (dataset, frac_features, algorithm, error)
        %>% dplyr::group_by (dataset, frac_features, algorithm)
        %>% dplyr::summarize (n = dplyr::n (),
                              mean = mean (error),
                              sd = sd (error),
                              min = min (error),
                              max = max (error),
                              median = median (error),
                              q1 = quantile (error, 0.25),
                              q3 = quantile (error, 0.75))
        %>% dplyr::ungroup ())
    # The two series, in legend order; colors and line types below are given
    # in the same order.
    series <- c ("RSMS (restreint)", "RSMS (tous labels)")
    # Build the line plot of one dataset.
    plot_dataset <- function (dataset_name) {
        with_legend <- (summary_stats
            %>% dplyr::filter (dataset == dataset_name)
            # One standard deviation around the mean. These bounds are kept
            # in the plot data but no layer draws them at the moment.
            %>% dplyr::mutate (ymin = mean - sd, ymax = mean + sd)
            %>% dplyr::select (Algorithme = algorithm,
                               `Variables` = frac_features,
                               `aRMSE moyenne` = mean, ymin, ymax)
            %>% ggplot2::ggplot (ggplot2::aes (x = `Variables`,
                                               y = `aRMSE moyenne`,
                                               color = Algorithme,
                                               linetype = Algorithme))
            + ggplot2::geom_line ()
            + ggplot2::ggtitle (dataset_name)
            + ggplot2::scale_x_continuous (labels = scales::percent)
            + ggplot2::scale_color_manual (limits = series,
                                           values = c ("black", "#e69f00"))
            + ggplot2::scale_linetype_manual (limits = series,
                                              values = c ("solid", "dashed")))
        # Only the "wq" plot, which spans two cells of the grid, shows the
        # legend.
        if (dataset_name == "wq") {
            with_legend
        } else {
            with_legend + ggplot2::theme (legend.position = "none")
        }
    }
    # 4 x 3 grid for eleven plots: the eleventh takes the last two cells.
    arrange_grid <- function (...) {
        gridExtra::grid.arrange (..., layout_matrix = rbind (
                                          c (1, 2, 3),
                                          c (4, 5, 6),
                                          c (7, 8, 9),
                                          c (10, 11, 11)
                                      ))
    }
    do.call (arrange_grid,
             lapply (unique (summary_stats$dataset), plot_dataset))
}

# Render the figure and save it as SVG (6 x 8, in ggsave's default units) to
# the path given by the OUTPUT environment variable.
filename <- Sys.getenv ("OUTPUT")
# Sys.getenv returns "" for an unset variable: fail early with a clear message
# instead of handing an empty path to ggsave.
if (!nzchar (filename)) {
    stop ("the OUTPUT environment variable must be set to the output file",
          call. = FALSE)
}
plot <- with_labels_graph ()
ggplot2::ggsave (filename, plot, device = "svg", width = 6, height = 8)