diff options
Diffstat (limited to 'images/lsmr_local_tuning.R')
-rw-r--r-- | images/lsmr_local_tuning.R | 329 |
1 files changed, 329 insertions, 0 deletions
diff --git a/images/lsmr_local_tuning.R b/images/lsmr_local_tuning.R new file mode 100644 index 0000000..8ad743a --- /dev/null +++ b/images/lsmr_local_tuning.R @@ -0,0 +1,329 @@ +load (sprintf ("%s/data/lsmr_local_tuning.Rdata", Sys.getenv ("ABS_TOP_SRCDIR"))) + +#' Construct a linear kernel +#' +#' @return a linear kernel +#' @export +linear_kernel <- function () { + ret <- list () + class (ret) <- "linear_kernel" + ret +} + +#' Construct a cosine kernel +#' +#' @return a cosine kernel +#' @export +cosine_kernel <- function () { + ret <- list () + class (ret) <- "cosine_kernel" + ret +} + +#' Construct a RBF kernel +#' +#' @param bandwidth the sigma parameter for the RBF... +#' @param gamma ... or alternatively the gamma parameter +#' @return an RBF kernel +#' @export +rbf_kernel <- function (bandwidth = NULL, gamma = NULL) { + stopifnot (!is.null (bandwidth) || !is.null (gamma)) + if (is.null (gamma)) { + gamma <- 1 / (2 * bandwidth ^ 2) + } + ret <- list (gamma = gamma) + class (ret) <- "rbf_kernel" + ret +} + +#' Construct a Laplacian matrix with binary relations +#' +#' @param kernel the kernel to compute base similarities +#' @param quantile used to compute the threshold. +#' @return a Laplacian matrix generator +#' @export +quantile_laplacian <- function (kernel = linear_kernel (), quantile = 0.95) { + ret <- list (kernel = kernel, q = quantile) + class (ret) <- "quantile_laplacian" + ret +} + +#' Apply a kernel over two data matrices +#' +#' @param x the kernel to apply +#' @param U the first data matrix +#' @param V the second data matrix (may be missing) +#' @return the kernel matrix +#' @export +cache <- function (x, U, V) { + UseMethod ("cache", x) +} + +#' @method cache linear_kernel +#' @export +cache.linear_kernel <- function (x, U, V = NULL) { + if (is.null (V)) { + V <- U + } + tcrossprod (U, V) +} + +#' @method cache cosine_kernel +#' @export +cache.cosine_kernel <- function (x, U, V = NULL) { + if (is.null (V)) { + V <- U + } + num <- tcrossprod (U, V) + nu <- sqrt (rowSums (U^2)) + nv <- sqrt (rowSums (V^2)) + denom <- tcrossprod (nu, nv) + ret <- num / denom + ret[denom == 0] <- 1 + ret +} + +pdist <- function (U, V = NULL) { + if (is.null (V)) { + V <- U + } + rsu <- as.matrix (rowSums (U^2), nrow (U), 1) + rsv <- as.matrix (rowSums (V^2), nrow (V), 1) + Du <- rsu[, array (1, nrow (V)), drop = FALSE] + Dv <- t (rsv[, array (1, nrow (U)), drop = FALSE]) + D <- Du + Dv - 2 * tcrossprod (U, V) + D[D < 0] <- 0 + D +} + +#' @method cache rbf_kernel +#' @export +cache.rbf_kernel <- function (x, U, V = NULL) { + gamma <- x$gamma + exp (- gamma * pdist (U, V)) +} + +#' @method cache quantile_laplacian +#' @export +cache.quantile_laplacian <- function (x, U, V = NULL) { + if (!is.null (V)) { + stop ("Cannot apply the Laplacian matrix on two different data matrices") + } + K <- cache (x$kernel, U) + q <- stats::quantile (K[upper.tri (K)], x$q) + M <- matrix (0, nrow (K), ncol (K)) + M[K < q] <- 0 + M[K >= q] <- 1 + D <- rowSums (M) + diag (D, nrow (K), ncol (K)) - M +} + +#' Construct a RBF kernel fit for a validation dataset +#' +#' @param x a validation data matrix +#' @param y the validation label matrix +#' @return a RBF kernel +#' @export +tune_rbf_kernel <- function (x, y) { + B <- cache (cosine_kernel (), t (t (y))) + B[B < 0] <- 0 + b <- t (t (c (B))) + D <- pdist (x) + candidates <- c (1e-4, 2e-4, 5e-4, + 1e-3, 2e-3, 5e-3, + 1e-2, 2e-2, 5e-2, + 1e-1, 2e-1, 5e-1, + 1e+0, 2e+0, 5e+0, + 1e+1, 2e+1, 5e+1, + 1e+2, 2e+2, 5e+2, + 1e+3, 2e+3, 5e+3, + 1e+4, 2e+4, 5e+4) + alignment <- sapply (candidates, function (gamma) { + K <- exp (-gamma * D) + k <- t (t (c (K))) + alignment <- cache.cosine_kernel (NULL, t (b), t (k)) + alignment[1, 1] + }) + rbf_kernel (gamma = candidates[which.max (alignment)]) +} + + +#' Load the local tuning results. +#' +#' @return A table with the following columns: 'dataset', 'kernel', +#' 'bandwidth', 's', 'semi', 'multi', 'armse_sssl', 'armse_semi', +#' 'armse_multi', 'armse_both'. +#' @export +get_local_tuning_data <- function () { + local_tuning +} + +#' Print the results for the local tuning. +#' +#' @return the data. +#' @export +print_tbl_comparison_local <- function () { + data <- get_local_tuning_data () + `%>%` <- magrittr::`%>%` + number <- function (x) { + sapply (x, function (x) { + if (x <= 1) { + sprintf ("*%.3f*", x) + } else { + sprintf ("%.3f", x) + } + }) + } + summaries <- (data + %>% dplyr::group_by (dataset) + %>% dplyr::summarize (median_sssl = median (armse_sssl), + mean = mean (armse_both), + median = median (armse_both), + q1 = quantile (armse_both, .25), + q3 = quantile (armse_both, .75), + min = min (armse_both), + max = max (armse_both)) + %>% dplyr::mutate (relative_mean = mean / median_sssl, + relative_median = median / median_sssl, + relative_q1 = q1 / median_sssl, + relative_q3 = q3 / median_sssl, + relative_min = min / median_sssl, + relative_max = max / median_sssl) + %>% dplyr::mutate (`*Données*` = dataset, + `Moyenne` = number (relative_mean), + `Médiane` = number (relative_median), + `Q1` = number (relative_q1), + `Q3` = number (relative_q3), + `Meilleur` = number (relative_min), + `Pire` = number (relative_max)) + %>% dplyr::select (`*Données*`, `Moyenne`, `Q1`, `Q3`, `Meilleur`, `Pire`) + %>% dplyr::arrange (`Meilleur`)) + summaries +} + +rescale_log <- function (value, min, max) { + log_min <- log (min) + log_max <- log (max) + log_value <- log_min + value * (log_max - log_min) + exp (log_value) +} + +laps3l_decode_hyper <- function (max_s) { + min_bandwidth <- 0.1 + max_bandwidth <- 300 + min_semi <- 1e-08 + max_semi <- 1 + min_multi <- 1e-04 + max_multi <- 10000 + min_s <- 1 + function (row) { + kernel <- NULL + row$kernel <- as.character (row$kernel) + if (row$kernel == "cosine") { + kernel <- cosine_kernel () + } + else if (row$kernel == "linear") { + kernel <- linear_kernel () + } + else { + stopifnot (row$kernel == "rbf") + bw <- rescale_log (row$bandwidth, min_bandwidth, max_bandwidth) + kernel <- rbf_kernel (bw) + } + list (kernel = kernel, + semi = rescale_log (row$semi, min_semi, max_semi), + multi = rescale_log (row$multi, min_multi, max_multi), + s = round (rescale_log (row$s, min_s, max_s))) + } +} + +#' Print a local graph +#' +#' @param graph which graph to plot +#' @return a ggplot object. +#' @export +print_local_graph <- function (graph = "atp1d") { + max_s <- NA + if (graph == "atp1d") { + max_s <- 262 + } else if (graph == "atp7d") { + max_s <- 234 + } else if (graph == "edm") { + max_s <- 121 + } else if (graph == "enb") { + max_s <- 601 + } else if (graph == "jura") { + max_s <- 281 + } else if (graph == "oes10") { + max_s <- 314 + } else if (graph == "oes97") { + max_s <- 257 + } else if (graph == "osales") { + max_s <- 495 + } else if (graph == "sarcossub") { + max_s <- 779 + } else if (graph == "scpf") { + max_s <- 889 + } else if (graph == "sf1") { + max_s <- 250 + } else if (graph == "sf2") { + max_s <- 832 + } else if (graph == "wq") { + max_s <- 827 + } else { + stop ("Unknown dataset") + } + d <- laps3l_decode_hyper (max_s) + decode_row <- function (data, i) { + row <- data[i,] + row$kernel <- "linear" + row$s <- 0 + items <- d (row) + data$semi[i] <- items$semi + data$multi[i] <- items$multi + data[i,] + } + decode <- function (data) { + do.call (rbind, lapply (seq_len (nrow (data)), function (i) decode_row (data, i))) + } + smooth <- function (data) { + X <- as.matrix (cbind (data$semi, data$multi)) + y <- t (t (data$relative_armse)) + D <- as.matrix (dist (X, diag = T, upper = T)) + M <- 0 * D + M[D < 0.1] <- 1 + sum <- rowSums (M) + M <- diag (1 / sum, nrow (X), nrow (X)) %*% M + data$relative_armse <- M %*% y + data + } + data <- get_local_tuning_data () + `%>%` <- magrittr::`%>%` + armse_sssl_median <- median ((data + %>% dplyr::filter (dataset == graph) + %>% dplyr::select (armse_sssl))$armse_sssl, na.rm = TRUE) + relative_data <- (data + %>% dplyr::filter (dataset == graph) + %>% dplyr::mutate (relative_armse = armse_both / armse_sssl_median) + %>% dplyr::select (semi, multi, relative_armse)) + averaged <- smooth (relative_data) + intp <- with (averaged, + akima::interp (x = semi, + y = multi, + z = relative_armse, + duplicate = "mean")) + values <- as.data.frame (as.matrix (intp$z)) + colnames (values) <- intp$y + intp <- (tidyr::gather (cbind (values, semi = intp$x), + multi, armse, seq_len (ncol (intp$z)), + na.rm = TRUE) + %>% dplyr::mutate (multi = as.numeric (multi))) + interpolated <- (intp %>% decode ()) + (ggplot2::ggplot (interpolated, ggplot2::aes (x = semi, y = multi, fill = armse)) + + ggplot2::geom_tile () + + ggplot2::scale_fill_gradient2 (midpoint = 1, name = "aRMSE\nrelative") + + ggplot2::xlab ("Régulariseur semi-supervisé $\\alpha$") + + ggplot2::ylab ("Régulariseur multi-label $\\beta$") + + ggplot2::scale_x_log10 () + + ggplot2::scale_y_log10 ()) +} |