summaryrefslogtreecommitdiff
path: root/images/rsms_4.R
diff options
context:
space:
mode:
Diffstat (limited to 'images/rsms_4.R')
-rw-r--r--images/rsms_4.R42
1 files changed, 42 insertions, 0 deletions
diff --git a/images/rsms_4.R b/images/rsms_4.R
new file mode 100644
index 0000000..94045d2
--- /dev/null
+++ b/images/rsms_4.R
@@ -0,0 +1,42 @@
+load (sprintf ("%s/data/rsms_protocol_4.Rdata", Sys.getenv ("ABS_TOP_SRCDIR")))
+
+Sys.setlocale ("LC_ALL", "fr_FR.UTF-8")
+
+#' Print the convergence graphs
+#' @return the graphs
+#' @export
+convergence_graph <- function () {
+ library ("magrittr")
+ data <- (protocol_4
+ %>% dplyr::group_by (dataset, maxiter)
+ %>% dplyr::summarize (mloss = mean (loss),
+ sloss = sd (loss))
+ %>% dplyr::mutate (dataset,
+ `Nombre d'itérations` = maxiter,
+ `Coût` = mloss,
+ mini = mloss - sloss,
+ maxi = mloss + sloss))
+ arrange <- function (...) {
+ gridExtra::grid.arrange (..., layout_matrix = rbind (
+ c (1, 1, 2, 2, 3, 3),
+ c (4, 4, 5, 5, 6, 6),
+ c (7, 7, 8, 8, 9, 9),
+ c (10, 10, 10, 11, 11, 11)
+ ))
+ }
+ do.call (arrange, lapply (c ("atp1d", "atp7d", "edm", "enb", "oes10", "oes97", "osales", "scpf", "sf1", "sf2", "wq"), function (dataset_name) {
+ (data
+ %>% dplyr::filter (dataset == dataset_name)
+ %>% ggplot2::ggplot (ggplot2::aes (x = `Nombre d'itérations`,
+ y = `Coût`,
+ ymin = mini,
+ ymax = maxi))
+ + ggplot2::geom_line ()
+ + ggplot2::geom_ribbon (alpha = 0.2)
+ + ggplot2::ggtitle (dataset_name))
+ }))
+}
+
+plot <- convergence_graph ()
+filename <- Sys.getenv ("OUTPUT")
+ggplot2::ggsave (filename, plot, device = "svg", width = 6, height = 8)