Skip to contents

## Scaled UMAP layout stored in the s-curve example object.
umap_scaled <- s_curve_obj[["s_curve_umap_scaled_obj"]][["scaled_nldr"]]

## Limit vectors stored next to the scaled layout
## (presumably the axis ranges of the two embedding dimensions -- confirm).
lim1 <- s_curve_obj[["s_curve_umap_scaled_obj"]][["lim1"]]
lim2 <- s_curve_obj[["s_curve_umap_scaled_obj"]][["lim2"]]

## Ratio of the spread of the second limit vector to that of the first;
## passed on as `r2` to the binning and model-fitting calls.
r2 <- diff(lim2) / diff(lim1)

## Candidate numbers of bins along the x-axis: 5, 6, ..., sqrt(n / r2),
## where n is the number of rows of the training data.
## The upper bound is clamped at 5 because `:` counts downwards when its
## right-hand side is smaller (5:3 is 5, 4, 3), which would fit models for
## bin counts below the intended minimum of 5.
bin1_vec <- 5:max(5, floor(sqrt(NROW(s_curve_noise_training) / r2)))

## One slot per candidate. The per-candidate summaries are stored here and
## bound together once after the loop, rather than re-copying a growing
## data frame on every iteration.
error_df_list <- vector("list", length(bin1_vec))

for (i in seq_along(bin1_vec)) {

  xbins <- bin1_vec[i]

  ## Bin layout for this candidate: number of bins along the y-axis (bin2)
  ## and the values a1 / a2 (a1 is plotted later as the binwidth).
  hb_obj <- calc_bins_y(bin1 = xbins, r2 = r2, q = 0.1)

  bin2 <- hb_obj$bin2
  a1 <- hb_obj$a1
  a2 <- hb_obj$a2

  ## Fit the model for this bin configuration
  scurve_model <- fit_highd_model(
    highd_data = s_curve_noise_training,
    nldr_data = umap_scaled,
    bin1 = xbins,
    r2 = r2,
    q = 0.1,
    is_bin_centroid = TRUE
  )

  df_bin_centroids_scurve <- scurve_model$df_bin_centroids
  df_bin_scurve <- scurve_model$df_bin

  ## Compute error and attach the bin configuration that produced it.
  ## `bin2`, `a1` and `a2` on the right-hand side are the loop variables
  ## above; `b` is built from the freshly added bin1 / bin2 columns.
  error_df <- glance(
    model_2d = df_bin_centroids_scurve,
    model_highd = df_bin_scurve,
    highd_data = s_curve_noise_training) |>
    mutate(bin1 = xbins,
           bin2 = bin2,
           b = bin1 * bin2,
           b_non_empty = NROW(df_bin_centroids_scurve),
           method = "UMAP",
           a1 = round(a1, 2),
           a2 = round(a2, 2))

  error_df_list[[i]] <- error_df

}

## Stack the per-candidate summaries into a single table
error_df_all <- bind_rows(error_df_list)

## Round the binwidth a1 to two decimals and drop fits with fewer than
## five bins along the x-axis.
error_df_all <- error_df_all |>
  mutate(a1 = round(a1, digits = 2)) |>
  filter(bin1 >= 5)

## Several bin counts can share the same rounded binwidth: within each
## a1 value keep only the row(s) attaining the smallest MSE.
error_df_all <- error_df_all |>
  group_by(a1) |>
  filter(MSE == min(MSE)) |>
  ungroup()

## Preview the five rows with the largest binwidths
head(arrange(error_df_all, desc(a1)), n = 5)
#> # A tibble: 5 × 9
#>   Error    MSE  bin1  bin2     b b_non_empty method    a1    a2
#>   <dbl>  <dbl> <int> <dbl> <dbl>       <int> <chr>  <dbl> <dbl>
#> 1 2534. 0.203      5     6    30          22 UMAP    0.25  0.21
#> 2 2262. 0.159      6     7    42          27 UMAP    0.21  0.18
#> 3 1944. 0.112      7     8    56          37 UMAP    0.18  0.15
#> 4 1755. 0.0897     8     9    72          44 UMAP    0.15  0.13
#> 5 1579. 0.0702     9    10    90          53 UMAP    0.14  0.12
## Plot MSE against the binwidth a1: one point per row of error_df_all,
## joined by a thin line.
error_df_all |>
  ggplot(aes(x = a1, y = MSE)) +
  geom_point(size = 0.8) +
  geom_line(linewidth = 0.3) +
  labs(
    x = expression(paste("binwidth (", a[1], ")")),
    y = "MSE"
  ) +
  theme_minimal() +
  theme(
    ## Outline the panel and restore the axis ticks
    panel.border = element_rect(fill = "transparent"),
    axis.ticks.x = element_line(),
    axis.ticks.y = element_line(),
    ## Small text for axis labels and titles
    axis.text.x = element_text(size = 7),
    axis.text.y = element_text(size = 7),
    axis.title.x = element_text(size = 7),
    axis.title.y = element_text(size = 7),
    plot.title = element_text(size = 12, hjust = 0.5, vjust = -0.5),
    legend.position = "none"
  )

MSE vs. binwidths.