summaryrefslogtreecommitdiff
path: root/images/rsms_2.R
diff options
context:
space:
mode:
Diffstat (limited to 'images/rsms_2.R')
-rw-r--r--images/rsms_2.R58
1 files changed, 58 insertions, 0 deletions
diff --git a/images/rsms_2.R b/images/rsms_2.R
new file mode 100644
index 0000000..8be2cf0
--- /dev/null
+++ b/images/rsms_2.R
@@ -0,0 +1,58 @@
+load (sprintf ("%s/data/rsms_test.Rdata", Sys.getenv ("ABS_TOP_SRCDIR")))
+
+Sys.setlocale ("LC_ALL", "fr_FR.UTF-8")
+
+#' Print the graphs for label selection
+#' @return the graph.
+#' @export
+labels_graph <- function () {
+ library ("magrittr")
+ data <- (test
+ %>% dplyr::filter (frac_labeled == 0.3,
+ frac_features == 0.3,
+ algorithm == "formulas",
+ dataset %in% c ("atp1d", "atp7d", "edm", "enb", "oes10", "oes97", "osales", "scpf", "sf1", "sf2", "wq"),
+ !is.na (error))
+ %>% dplyr::select (dataset, frac_labels, error)
+ %>% dplyr::group_by (dataset, frac_labels)
+ %>% dplyr::summarize (n = dplyr::n (),
+ mean = mean (error),
+ sd = sd (error),
+ min = min (error),
+ max = max (error),
+ median = median (error),
+ q1 = quantile (error, 0.25),
+ q3 = quantile (error, 0.75))
+ %>% dplyr::ungroup ())
+ arrange <- function (...) {
+ gridExtra::grid.arrange (..., layout_matrix = rbind (
+ c (1, 2, 3),
+ c (4, 5, 6),
+ c (7, 8, 9),
+ c (10, 11, 11)
+ ))
+ }
+ do.call (arrange, lapply ((data
+ %>% dplyr::select (dataset)
+ %>% dplyr::distinct ())$dataset, function (dataset_name) {
+ with_legend <- (data
+ %>% dplyr::filter (dataset == dataset_name)
+ %>% dplyr::mutate (ymin = mean - sd, ymax = max + sd)
+ %>% dplyr::select (`Labels sélectionnées` = frac_labels,
+ `aRMSE moyenne` = mean, ymin, ymax)
+ %>% ggplot2::ggplot (ggplot2::aes (x = `Labels sélectionnées`,
+ y = `aRMSE moyenne`))
+ + ggplot2::geom_line ()
+ + ggplot2::ggtitle (dataset_name)
+ + ggplot2::scale_x_continuous (labels = scales::percent))
+ if (dataset_name == "wq") {
+ with_legend
+ } else {
+ with_legend + ggplot2::theme (legend.position = "none")
+ }
+ }))
+}
+
+plot <- labels_graph ()
+filename <- Sys.getenv ("OUTPUT")
+ggplot2::ggsave (filename, plot, device = "svg", width = 6, height = 8)