summary refs log tree commit diff
path: root/images/lsmr_local_tuning.R
diff options
context:
space:
mode:
Diffstat (limited to 'images/lsmr_local_tuning.R')
-rw-r--r-- images/lsmr_local_tuning.R 329
1 files changed, 329 insertions, 0 deletions
diff --git a/images/lsmr_local_tuning.R b/images/lsmr_local_tuning.R
new file mode 100644
index 0000000..8ad743a
--- /dev/null
+++ b/images/lsmr_local_tuning.R
@@ -0,0 +1,329 @@
+# Load the saved R objects stored under <ABS_TOP_SRCDIR>/data into the current
+# environment; the source tree root is taken from the ABS_TOP_SRCDIR
+# environment variable.
+load (file.path (Sys.getenv ("ABS_TOP_SRCDIR"), "data", "lsmr_local_tuning.Rdata"))
+
+#' Build a linear kernel descriptor
+#'
+#' The descriptor has no parameters: it is an empty list tagged with the S3
+#' class "linear_kernel" so that generics such as \code{cache} can dispatch
+#' on it.
+#'
+#' @return an empty list of class "linear_kernel"
+#' @export
+linear_kernel <- function () {
+  structure (list (), class = "linear_kernel")
+}
+
+#' Build a cosine kernel descriptor
+#'
+#' The descriptor has no parameters: it is an empty list tagged with the S3
+#' class "cosine_kernel" so that generics such as \code{cache} can dispatch
+#' on it.
+#'
+#' @return an empty list of class "cosine_kernel"
+#' @export
+cosine_kernel <- function () {
+  structure (list (), class = "cosine_kernel")
+}
+
+#' Build an RBF kernel descriptor
+#'
+#' At least one of \code{bandwidth} and \code{gamma} must be supplied. When
+#' \code{gamma} is given it is used as is (even if \code{bandwidth} is also
+#' given); otherwise it is derived as \code{1 / (2 * bandwidth^2)}.
+#'
+#' @param bandwidth the sigma parameter of the RBF, used only when
+#'   \code{gamma} is NULL
+#' @param gamma the gamma parameter of the RBF
+#' @return a list with a single element \code{gamma}, of class "rbf_kernel"
+#' @export
+rbf_kernel <- function (bandwidth = NULL, gamma = NULL) {
+  stopifnot (!is.null (bandwidth) || !is.null (gamma))
+  # An explicit gamma wins; the bandwidth is only a fallback parametrization.
+  resolved_gamma <- if (is.null (gamma)) 1 / (2 * bandwidth ^ 2) else gamma
+  structure (list (gamma = resolved_gamma), class = "rbf_kernel")
+}
+
+#' Build a descriptor for a Laplacian matrix with binary relations
+#'
+#' The descriptor only stores its two settings; the matrix itself is computed
+#' when the descriptor is passed to \code{cache}.
+#'
+#' @param kernel the kernel used to compute the base similarities
+#' @param quantile the quantile of the similarities used as the threshold
+#' @return a list with elements \code{kernel} and \code{q}, of class
+#'   "quantile_laplacian"
+#' @export
+quantile_laplacian <- function (kernel = linear_kernel (), quantile = 0.95) {
+  structure (list (kernel = kernel, q = quantile),
+             class = "quantile_laplacian")
+}
+
+#' Evaluate a kernel (or Laplacian) descriptor on data matrices
+#'
+#' S3 generic; the method is selected from the class of \code{x}.
+#'
+#' @param x the kernel descriptor to apply
+#' @param U the first data matrix
+#' @param V the second data matrix (may be missing)
+#' @return the matrix computed by the method selected for \code{x}
+#' @export
+cache <- function (x, U, V) {
+  # Dispatch happens on the first argument, i.e. on x.
+  UseMethod ("cache")
+}
+
+#' Linear kernel matrix: inner products between the rows of U and the rows
+#' of V. When V is NULL, U is compared against itself.
+#'
+#' @method cache linear_kernel
+#' @export
+cache.linear_kernel <- function (x, U, V = NULL) {
+  other <- if (is.null (V)) U else V
+  tcrossprod (U, other)
+}
+
+#' Cosine kernel matrix: inner products between the rows of U and the rows
+#' of V, divided by the product of the corresponding row norms. When V is
+#' NULL, U is compared against itself. Entries whose norm product is zero
+#' are set to 1.
+#'
+#' @method cache cosine_kernel
+#' @export
+cache.cosine_kernel <- function (x, U, V = NULL) {
+  if (is.null (V)) {
+    V <- U
+  }
+  inner <- tcrossprod (U, V)
+  norm_u <- sqrt (rowSums (U^2))
+  norm_v <- sqrt (rowSums (V^2))
+  norm_product <- tcrossprod (norm_u, norm_v)
+  similarity <- inner / norm_product
+  # A zero norm makes the ratio undefined; such entries are forced to 1.
+  similarity[norm_product == 0] <- 1
+  similarity
+}
+
+# Pairwise squared Euclidean distances between the rows of U and the rows of
+# V (V defaults to U). Uses ||u||^2 + ||v||^2 - 2 <u, v>; negative results
+# are clamped to 0.
+pdist <- function (U, V = NULL) {
+  if (is.null (V)) {
+    V <- U
+  }
+  # Squared row norms as one-column matrices.
+  sq_u <- as.matrix (rowSums (U^2))
+  sq_v <- as.matrix (rowSums (V^2))
+  # Repeat the single column so both grids are nrow (U) x nrow (V).
+  grid_u <- sq_u[, rep (1, nrow (V)), drop = FALSE]
+  grid_v <- t (sq_v[, rep (1, nrow (U)), drop = FALSE])
+  sq_dist <- grid_u + grid_v - 2 * tcrossprod (U, V)
+  sq_dist[sq_dist < 0] <- 0
+  sq_dist
+}
+
+#' RBF kernel matrix: exp (-gamma * d), where d is the matrix returned by
+#' pdist (U, V) and gamma is read from the kernel descriptor.
+#'
+#' @method cache rbf_kernel
+#' @export
+cache.rbf_kernel <- function (x, U, V = NULL) {
+  exp (-x$gamma * pdist (U, V))
+}
+
+#' Laplacian matrix of the thresholded similarity graph of U.
+#'
+#' The similarities come from the descriptor's kernel. The threshold is the
+#' requested quantile of the strictly upper triangular similarities. Entries
+#' at or above the threshold become 1 in the adjacency matrix and all others
+#' stay 0. The result is the diagonal matrix of adjacency row sums minus the
+#' adjacency matrix. Passing a second data matrix is an error.
+#'
+#' @method cache quantile_laplacian
+#' @export
+cache.quantile_laplacian <- function (x, U, V = NULL) {
+  if (!is.null (V)) {
+    stop ("Cannot apply the Laplacian matrix on two different data matrices")
+  }
+  similarity <- cache (x$kernel, U)
+  threshold <- stats::quantile (similarity[upper.tri (similarity)], x$q)
+  # Starts at all zeros, so only the entries reaching the threshold are set.
+  adjacency <- matrix (0, nrow (similarity), ncol (similarity))
+  adjacency[similarity >= threshold] <- 1
+  degree <- rowSums (adjacency)
+  diag (degree, nrow (similarity), ncol (similarity)) - adjacency
+}
+
+#' Pick the RBF kernel that best matches a validation dataset
+#'
+#' A target matrix is built as the cosine kernel between the rows of
+#' \code{y}, with negative entries replaced by 0. For each candidate gamma
+#' (1, 2 and 5 times the powers of ten from 1e-4 to 1e+4), the RBF matrix of
+#' \code{x} is flattened and compared to the flattened target with the cosine
+#' kernel. The candidate with the largest score is kept; the first one wins
+#' ties.
+#'
+#' @param x a validation data matrix
+#' @param y the validation label matrix
+#' @return a RBF kernel
+#' @export
+tune_rbf_kernel <- function (x, y) {
+  # t (t (y)) turns a plain vector into a one-column matrix.
+  target <- cache (cosine_kernel (), t (t (y)))
+  target[target < 0] <- 0
+  target_row <- matrix (c (target), nrow = 1)
+  sq_dist <- pdist (x)
+  candidates <- c (1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2,
+                   1e-1, 2e-1, 5e-1, 1e+0, 2e+0, 5e+0, 1e+1, 2e+1, 5e+1,
+                   1e+2, 2e+2, 5e+2, 1e+3, 2e+3, 5e+3, 1e+4, 2e+4, 5e+4)
+  score_of <- function (gamma) {
+    kernel_row <- matrix (c (exp (-gamma * sq_dist)), nrow = 1)
+    # Two single-row inputs give a 1 x 1 cosine matrix; take its only entry.
+    cache.cosine_kernel (NULL, target_row, kernel_row)[1, 1]
+  }
+  alignment <- vapply (candidates, score_of, numeric (1))
+  rbf_kernel (gamma = candidates[which.max (alignment)])
+}
+
+
+#' Return the local tuning results.
+#'
+#' The function reads no file itself: it returns the object named
+#' \code{local_tuning} found in an enclosing environment (presumably created
+#' by the \code{load} call at the top of this file -- to be confirmed).
+#'
+#' @return A table with the following columns: 'dataset', 'kernel',
+#' 'bandwidth', 's', 'semi', 'multi', 'armse_sssl', 'armse_semi',
+#' 'armse_multi', 'armse_both'.
+#' @export
+get_local_tuning_data <- function () {
+ local_tuning
+}
+
+#' Summarize the local tuning results per dataset.
+#'
+#' For each dataset, the mean, median, quartiles, minimum and maximum of
+#' 'armse_both' are divided by the median of 'armse_sssl'. The ratios are
+#' formatted with three decimals, and ratios at or below 1 are wrapped in
+#' asterisks. Rows are ordered by increasing best (minimum) ratio. The
+#' ordering is done on the numeric ratio, not on its formatted text, so it
+#' does not depend on string collation or on the number of digits.
+#'
+#' @return a table with one row per dataset and the columns '*Données*',
+#' 'Moyenne', 'Q1', 'Q3', 'Meilleur' and 'Pire'. The median ratio is computed
+#' but is not part of the returned columns.
+#' @export
+print_tbl_comparison_local <- function () {
+  data <- get_local_tuning_data ()
+  `%>%` <- magrittr::`%>%`
+  # Format each ratio; vapply guarantees a character vector even when empty.
+  number <- function (x) {
+    vapply (x, function (value) {
+      if (value <= 1) {
+        sprintf ("*%.3f*", value)
+      } else {
+        sprintf ("%.3f", value)
+      }
+    }, character (1))
+  }
+  summaries <- (data
+    %>% dplyr::group_by (dataset)
+    %>% dplyr::summarize (median_sssl = stats::median (armse_sssl),
+                          mean = mean (armse_both),
+                          median = stats::median (armse_both),
+                          q1 = stats::quantile (armse_both, .25),
+                          q3 = stats::quantile (armse_both, .75),
+                          min = min (armse_both),
+                          max = max (armse_both))
+    %>% dplyr::mutate (relative_mean = mean / median_sssl,
+                       relative_median = median / median_sssl,
+                       relative_q1 = q1 / median_sssl,
+                       relative_q3 = q3 / median_sssl,
+                       relative_min = min / median_sssl,
+                       relative_max = max / median_sssl)
+    # Sort while the ratios are still numeric: the formatted strings would
+    # compare lexicographically ("10.000" before "2.000").
+    %>% dplyr::arrange (relative_min)
+    %>% dplyr::mutate (`*Données*` = dataset,
+                       `Moyenne` = number (relative_mean),
+                       `Médiane` = number (relative_median),
+                       `Q1` = number (relative_q1),
+                       `Q3` = number (relative_q3),
+                       `Meilleur` = number (relative_min),
+                       `Pire` = number (relative_max))
+    %>% dplyr::select (`*Données*`, `Moyenne`, `Q1`, `Q3`, `Meilleur`, `Pire`))
+  summaries
+}
+
+# Interpolate between min and max on a logarithmic scale: value = 0 yields
+# min, value = 1 yields max, and intermediate values are interpolated
+# linearly in log space. Vectorized over value.
+rescale_log <- function (value, min, max) {
+  log_lower <- log (min)
+  log_upper <- log (max)
+  exp (log_lower + value * (log_upper - log_lower))
+}
+
+# Build a decoder for one row of encoded hyper-parameters.
+#
+# The returned function takes a row with the fields kernel, bandwidth, semi,
+# multi and s, and returns a list with:
+#   kernel: a kernel descriptor ("cosine", "linear", or "rbf" with the
+#           bandwidth rescaled between 0.1 and 300); any other kernel name
+#           is an error
+#   semi:   row$semi rescaled between 1e-08 and 1
+#   multi:  row$multi rescaled between 1e-04 and 10000
+#   s:      row$s rescaled between 1 and max_s, then rounded
+# All rescaling goes through rescale_log, i.e. on a logarithmic scale.
+laps3l_decode_hyper <- function (max_s) {
+  bandwidth_bounds <- c (0.1, 300)
+  semi_bounds <- c (1e-08, 1)
+  multi_bounds <- c (1e-04, 10000)
+  lowest_s <- 1
+  function (row) {
+    # The kernel name may come in as a factor.
+    row$kernel <- as.character (row$kernel)
+    kernel <- if (row$kernel == "cosine") {
+      cosine_kernel ()
+    } else if (row$kernel == "linear") {
+      linear_kernel ()
+    } else {
+      stopifnot (row$kernel == "rbf")
+      rbf_kernel (rescale_log (row$bandwidth,
+                               bandwidth_bounds[1], bandwidth_bounds[2]))
+    }
+    list (kernel = kernel,
+          semi = rescale_log (row$semi, semi_bounds[1], semi_bounds[2]),
+          multi = rescale_log (row$multi, multi_bounds[1], multi_bounds[2]),
+          s = round (rescale_log (row$s, lowest_s, max_s)))
+  }
+}
+
+#' Plot the relative aRMSE of one dataset over the two regularizers
+#'
+#' The tuning results of the dataset are divided by the median
+#' 'armse_sssl' of that dataset, then averaged over neighbouring points
+#' (distance below 0.1 in the encoded semi/multi plane), interpolated on a
+#' regular grid with \code{akima::interp}, and decoded back to the actual
+#' regularizer values. The result is drawn as tiles on log-log axes, with
+#' the color scale centered on a ratio of 1.
+#'
+#' @param graph which graph to plot: one of "atp1d", "atp7d", "edm", "enb",
+#'   "jura", "oes10", "oes97", "osales", "sarcossub", "scpf", "sf1", "sf2",
+#'   "wq". Any other value raises the error "Unknown dataset".
+#' @return a ggplot object.
+#' @export
+print_local_graph <- function (graph = "atp1d") {
+  # Upper bound handed to the hyper-parameter decoder for each known dataset.
+  max_s_by_dataset <- c (atp1d = 262, atp7d = 234, edm = 121, enb = 601,
+                         jura = 281, oes10 = 314, oes97 = 257, osales = 495,
+                         sarcossub = 779, scpf = 889, sf1 = 250, sf2 = 832,
+                         wq = 827)
+  graph_name <- as.character (graph)
+  if (length (graph_name) != 1
+      || !(graph_name %in% names (max_s_by_dataset))) {
+    stop ("Unknown dataset")
+  }
+  max_s <- max_s_by_dataset[[graph_name]]
+  d <- laps3l_decode_hyper (max_s)
+  # Decode the semi and multi columns row by row. The decoder also needs a
+  # kernel and an s value, so placeholders are supplied; their decoded
+  # results are discarded.
+  decode <- function (data) {
+    decoded <- lapply (seq_len (nrow (data)), function (i) {
+      row <- data[i, ]
+      row$kernel <- "linear"
+      row$s <- 0
+      d (row)
+    })
+    data$semi <- vapply (decoded, function (items) items$semi, numeric (1))
+    data$multi <- vapply (decoded, function (items) items$multi, numeric (1))
+    data
+  }
+  # Replace each relative aRMSE by the mean over the points lying within a
+  # distance of 0.1 in the (semi, multi) plane, the point itself included.
+  smooth <- function (data) {
+    X <- as.matrix (cbind (data$semi, data$multi))
+    y <- t (t (data$relative_armse))
+    D <- as.matrix (stats::dist (X, diag = TRUE, upper = TRUE))
+    M <- 0 * D
+    M[D < 0.1] <- 1
+    neighbour_count <- rowSums (M)
+    M <- diag (1 / neighbour_count, nrow (X), nrow (X)) %*% M
+    data$relative_armse <- M %*% y
+    data
+  }
+  data <- get_local_tuning_data ()
+  `%>%` <- magrittr::`%>%`
+  armse_sssl_median <- stats::median ((data
+    %>% dplyr::filter (dataset == graph)
+    %>% dplyr::select (armse_sssl))$armse_sssl, na.rm = TRUE)
+  relative_data <- (data
+    %>% dplyr::filter (dataset == graph)
+    %>% dplyr::mutate (relative_armse = armse_both / armse_sssl_median)
+    %>% dplyr::select (semi, multi, relative_armse))
+  averaged <- smooth (relative_data)
+  intp <- with (averaged,
+                akima::interp (x = semi,
+                               y = multi,
+                               z = relative_armse,
+                               duplicate = "mean"))
+  # Grid values: one row per interpolated semi value, one column per
+  # interpolated multi value. Reshape to one row per grid cell.
+  values <- as.data.frame (as.matrix (intp$z))
+  colnames (values) <- intp$y
+  intp <- (tidyr::gather (cbind (values, semi = intp$x),
+                          multi, armse, seq_len (ncol (intp$z)),
+                          na.rm = TRUE)
+    %>% dplyr::mutate (multi = as.numeric (multi)))
+  interpolated <- (intp %>% decode ())
+  (ggplot2::ggplot (interpolated, ggplot2::aes (x = semi, y = multi, fill = armse))
+    + ggplot2::geom_tile ()
+    + ggplot2::scale_fill_gradient2 (midpoint = 1, name = "aRMSE\nrelative")
+    + ggplot2::xlab ("Régulariseur semi-supervisé $\\alpha$")
+    + ggplot2::ylab ("Régulariseur multi-label $\\beta$")
+    + ggplot2::scale_x_log10 ()
+    + ggplot2::scale_y_log10 ())
+}