# =====================================================================
# update.missoNet.R - Three-Stage Update Procedure
# =====================================================================

#' Three-stage estimation procedure for the missoNet model
#'
#' Stage I builds (or reuses) the sufficient statistics plus an initial Beta
#' and residual covariance, optionally refined by a short alternating
#' pre-optimization. Stage II estimates Theta with the graphical lasso on the
#' residual covariance. Stage III estimates Beta given Theta: ordinary least
#' squares when `lamB == 0`, the all-zero matrix when the zero-solution KKT
#' check passes, otherwise `updateBeta()`.
#'
#' @param X Predictor matrix (n x p).
#' @param Y Response matrix with potential missing values (n x q).
#' @param lamTh Regularization parameter for Theta (precision matrix);
#'   multiplied element-wise by `init.obj$lamTh.pf`.
#' @param lamB Regularization parameter for Beta (coefficient matrix);
#'   multiplied element-wise by `init.obj$lamB.pf`. Used as a scalar in
#'   `if (lamB == 0)`.
#' @param Beta.maxit Maximum iterations for the Beta update.
#' @param Beta.thr Convergence threshold for the Beta update.
#' @param Theta.maxit Maximum iterations for the Theta update.
#' @param Theta.thr Convergence threshold for the Theta update.
#' @param verbose Verbosity level (0, 1, or 2); only level 2 prints progress
#'   in this function.
#' @param eps Small positive constant used as a numerical floor: ridge jitter
#'   on the OLS path, SVD cutoff in its fallback, and clamp for the diagonal
#'   Theta fallback.
#' @param eta Step size reduction factor, passed through to `updateBeta()`.
#' @param info Pre-computed statistics (optional): list with `n`, `p`, `q`,
#'   `xtx` and `til.xty`. When `NULL` it is built from `X`, `Y`, `init.obj`.
#' @param info.update Current state, required when `info` is supplied: list
#'   with `B.init` and `residual.cov`. Rebuilt when `info` is `NULL`.
#' @param under.cv If `TRUE`, return only the Beta estimate.
#' @param init.obj Initialization object; fields read here are `mx`, `sdx`,
#'   `my`, `sdy`, `obs_prob`, `warm.start`, `lamB.pf` and `lamTh.pf`.
#' @param B.init Initial Beta estimate (optional; only used when `info` is
#'   `NULL`).
#' @return If `under.cv` is `TRUE`, the Beta estimate. Otherwise a list with
#'   `Beta`, `Theta`, and per-component `iterations`, `converged` and
#'   `methods` (each a list with `Beta` and `Theta` entries).
#' @noRd
update.missoNet <- function(X, Y, lamTh, lamB,
                            Beta.maxit = 1000, Beta.thr = 1e-4,
                            Theta.maxit = 1000, Theta.thr = 1e-4,
                            verbose = 1, eps = 1e-8, eta = 0.8,
                            info = NULL, info.update = NULL,
                            under.cv = FALSE,
                            init.obj, B.init = NULL) {

  # Penalty matrices depend only on the lambdas and the penalty factors, so
  # they are built once and shared by the pre-optimization and main stages.
  # Exact zeros become 1e-12 (presumably so the solvers always receive a
  # strictly positive penalty).
  lamB.mat <- lamB * init.obj$lamB.pf
  lamB.mat[lamB.mat == 0] <- 1e-12
  lamTh.mat <- lamTh * init.obj$lamTh.pf
  lamTh.mat[lamTh.mat == 0] <- 1e-12

  # ========================= Stage I: Initialize =========================

  if (is.null(info)) {
    n <- nrow(X); p <- ncol(X); q <- ncol(Y)

    Xs <- robust_scale(X, center = init.obj$mx, scale = init.obj$sdx)
    Ys <- robust_scale(Y, center = init.obj$my, scale = init.obj$sdy)
    # Z is the scaled response with missing entries set to 0.
    Z <- Ys; Z[is.na(Z)] <- 0

    # Observation-probability corrections. Column j of rho.mat.1 holds
    # obs_prob[j] (assumes one probability per response, length q).
    obs_prob <- init.obj$obs_prob
    rho.mat.1 <- matrix(obs_prob, nrow = p, ncol = q, byrow = TRUE)
    rho.mat.2 <- outer(obs_prob, obs_prob, `*`); diag(rho.mat.2) <- obs_prob

    # Sufficient statistics shared by every stage
    info <- list()
    info$n <- n; info$p <- p; info$q <- q
    info$xtx <- make_positive_definite(crossprod(Xs))
    info$til.xty <- crossprod(Xs, Z) / rho.mat.1

    # ---------------------- Initial Beta ----------------------

    if (is.null(B.init)) {
      if (info$p * info$q < 1000) {
        # For small problems, start with ridge regression
        B.init <- solve(info$xtx + diag(0.01, info$p), info$til.xty)
      } else {
        B.init <- matrix(0, info$p, info$q)
      }
    }

    # info was rebuilt, so any caller-supplied info.update is replaced too.
    info.update <- list(B.init = B.init)

    E <- Ys - Xs %*% B.init
    info.update$residual.cov <- getResCov(E, n, rho.mat.2)  # getResCov returns a PSD matrix

    # ---------------------- Warm start (optional) ----------------------

    if (isTRUE(init.obj$warm.start)) {
      if (verbose == 2) {
        cat("\n  -------------- Pre-optimization ---------------\n")
        cat("  Iter | Objective | Rel.Change\n")
      }

      Beta <- info.update$B.init
      residual.cov <- info.update$residual.cov

      # A few cheap alternating Theta/Beta passes (between 3 and 50).
      n.preopt <- min(50L, max(3L, ceiling(Beta.maxit * 0.05)))
      obj_history <- numeric(n.preopt)
      lik.old <- 1e12

      for (s in seq_len(n.preopt)) {
        # Update Theta
        Theta.out <- run_glasso(S = residual.cov, rho = lamTh.mat,
                                thr = min(0.001, Theta.thr),
                                maxIt = max(1000, Theta.maxit))
        Theta <- make_symmetric(Theta.out$wi)

        # Update Beta
        Beta.out <- updateBeta(Theta = Theta, B0 = Beta,
                               n = info$n, xtx = info$xtx,
                               xty = info$til.xty, lamB = lamB.mat,
                               eta = eta, tolin = min(0.001, Beta.thr),
                               maxitrin = max(1000, as.integer(Beta.maxit)))
        Beta <- Beta.out$Bhat

        # Update residual covariance
        E <- Ys - Xs %*% Beta
        residual.cov <- getResCov(E, n, rho.mat.2)

        # Penalized objective
        lik.new <- Q_func(residual.cov, Theta) +
          sum(abs(lamTh.mat * Theta)) +
          sum(abs(lamB.mat * Beta))

        obj_history[s] <- lik.new
        rel_change <- abs(lik.new - lik.old) / max(1, abs(lik.old))

        if (verbose == 2) {
          cat(sprintf("  %3d  | %.6f | %.2e\n", s, lik.new, rel_change))
        }

        # Early stopping: the last two objective differences are both small
        if (s > 2) {
          recent_changes <- diff(obj_history[max(1, s-2):s])
          if (all(abs(recent_changes) < min(0.001, Beta.thr * 100))) {
            if (verbose == 2) cat("  Early convergence detected.\n")
            break
          }
        }

        if (rel_change < min(0.0001, Beta.thr * 10)) break
        lik.old <- lik.new
      }

      info.update$B.init <- Beta
      info.update$residual.cov <- residual.cov
    }
  } else if (is.null(info.update)) {
    # Stages II and III read info.update$residual.cov and info.update$B.init.
    stop("`info.update` must be supplied when `info` is supplied.",
         call. = FALSE)
  }

  # ========================= Main Optimization =========================

  if (verbose == 2) {
    cat("\n  -------------- Main Optimization --------------\n")
    cat(sprintf("  Lambda.beta: %.4f  Lambda.theta: %.4f\n", lamB, lamTh))
    cat("  Stage | Component      | Iterations | Backend\n")
    cat("  ------|----------------|------------|----------\n")
  }

  # ========================= Stage II: Update Theta =========================

  Theta.out <- tryCatch(
    run_glasso(S = info.update$residual.cov, rho = lamTh.mat,
               thr = Theta.thr, maxIt = Theta.maxit),
    error = function(e) {
      warning("Glasso failed, using diagonal approximation: ", e$message)
      d <- diag(info.update$residual.cov)
      # `nrow` is required: for a single response, diag() of a length-1
      # vector would build an identity matrix of that size instead of 1 x 1.
      list(wi = diag(pmin(pmax(1/d, eps), 1/eps), nrow = length(d)),
           niter = 0L, backend = "diagonal_fallback")
    }
  )

  if (verbose == 2) {
    ni <- Theta.out$niter %||% -1L
    backend <- Theta.out$backend %||% "unknown"
    cat(sprintf("    II  | Theta          | %10d | %s\n", ni, backend))
  }

  Theta <- make_symmetric(Theta.out$wi)

  # ========================= Stage III: Update Beta =========================

  if (lamB == 0) {
    # Fast path 1: no regularization -> least-squares solution (Theta unused)
    jitter <- max(eps, 1e-8)
    XtX_reg <- info$xtx + diag(jitter, nrow = info$p)

    # Cholesky solve R'R B = X'Y; SVD fallback if the factorization fails
    Bhat <- tryCatch({
      R <- chol(XtX_reg)
      backsolve(R, forwardsolve(t(R), info$til.xty))
    }, error = function(e) {
      svd_solve(XtX_reg, info$til.xty, eps)
    })

    it.final <- 0L
    converged <- TRUE
    method <- "OLS"

  } else {
    # Fast path 2: B = 0 satisfies the KKT conditions when the gradient at
    # zero lies within the penalty (1.0001 adds slack for rounding). Assumes
    # this gradient scaling matches the one used by updateBeta().
    G0 <- -(info$til.xty %*% Theta) / info$n

    if (all(abs(G0) <= lamB.mat * 1.0001)) {
      Bhat <- matrix(0, nrow = info$p, ncol = info$q)
      it.final <- 0L
      converged <- TRUE
      method <- "KKT_0"

    } else {
      # Standard FISTA optimization, warm-started from B.init
      B.out <- tryCatch(
        updateBeta(Theta = Theta, B0 = info.update$B.init,
                   n = info$n, xtx = info$xtx,
                   xty = info$til.xty, lamB = lamB.mat,
                   eta = eta, tolin = Beta.thr,
                   maxitrin = as.integer(Beta.maxit)),
        error = function(e) {
          warning("Beta update failed, using previous estimate: ", e$message)
          list(Bhat = info.update$B.init, it.final = -1L, converged = FALSE)
        }
      )

      Bhat <- B.out$Bhat
      it.final <- B.out$it.final %||% -1L
      converged <- B.out$converged %||% (it.final < Beta.maxit)
      method <- "FISTA"
    }
  }

  if (verbose == 2) {
    cat(sprintf("   III  | Beta (%s) | %10d | -\n",
                format(method, width = 7), it.final))
    if (isTRUE(converged)) {
      cat("\n  Optimization converged successfully.\n\n")
    } else {
      cat("\n  Warning: Optimization did not fully converge.\n\n")
    }
  }

  # ========================= Return Results =========================

  if (under.cv) {
    # For CV, only return Beta
    return(Bhat)
  }

  # The diagonal fallback replaces a failed glasso solve, so it is never
  # reported as converged. Unknown iteration counts are treated as converged.
  conv_theta <- if (identical(Theta.out$backend, "diagonal_fallback")) {
    FALSE
  } else if (!is.null(Theta.out$niter) && !is.na(Theta.out$niter)) {
    Theta.out$niter < Theta.maxit
  } else {
    TRUE
  }

  list(
    Beta = Bhat,
    Theta = Theta,
    iterations = list(
      Beta = it.final,
      Theta = Theta.out$niter %||% -1L
    ),
    converged = list(
      Beta = converged,
      Theta = conv_theta
    ),
    methods = list(
      Beta = method,
      Theta = Theta.out$backend %||% "unknown"
    )
  )
}


# ========================= Helper Functions =========================

# Null-coalescing operator: yields `x` unless it is NULL, in which case `y`.
# `y` is only evaluated when it is actually needed.
`%||%` <- function(x, y) {
  if (!is.null(x)) {
    return(x)
  }
  y
}

# Check convergence of an iterative procedure from successive objective values.
#
# obj_old, obj_new  Objective values at the previous and current iteration.
# tol               Tolerance on the relative change
#                   |obj_new - obj_old| / max(1, |obj_old|).
# iter, max_iter    Current iteration count and the iteration budget.
#
# Returns a list with `converged` (logical), `reason` (one of "tolerance",
# "max_iterations", "non_finite", "in_progress") and `rel_change` (NA when
# either objective is non-finite).
check_convergence <- function(obj_old, obj_new, tol, iter, max_iter) {
  finite <- is.finite(obj_old) && is.finite(obj_new)
  rel_change <- if (finite) {
    abs(obj_new - obj_old) / max(1, abs(obj_old))
  } else {
    NA_real_
  }

  # Tolerance is checked before the budget, so a run that meets it on its
  # final allowed iteration is still reported as converged.
  if (finite && rel_change < tol) {
    return(list(converged = TRUE, reason = "tolerance",
                rel_change = rel_change))
  }
  if (iter >= max_iter) {
    return(list(converged = FALSE, reason = "max_iterations",
                rel_change = rel_change))
  }
  if (!finite) {
    return(list(converged = FALSE, reason = "non_finite",
                rel_change = rel_change))
  }
  list(converged = FALSE, reason = "in_progress", rel_change = rel_change)
}

# Graphical lasso with backend selection.
#
# Estimates a sparse precision matrix from covariance `S` with penalty `rho`
# and returns a list with `wi` (precision estimate), `niter`, `fastpath` and
# `backend`. A closed-form diagonal solution is returned when every
# off-diagonal |S_ij| is within its penalty. Otherwise the installed backends
# are tried in order (glassoFast, then glasso).
#
# S      Square covariance matrix (symmetrized via make_symmetric()).
# rho    Penalty: a scalar (applied to every entry) or a matrix shaped like S.
# thr    Convergence threshold passed to the backend.
# maxIt  Maximum number of iterations passed to the backend.
#
# Errors if no backend is installed, or if every installed backend fails (the
# message then lists each backend's error).
run_glasso <- function(S, rho, thr, maxIt) {
  S <- make_symmetric(S)

  # Both backends accept a scalar penalty; expand it so the element-wise
  # fast-path check below can index it like S.
  if (length(rho) == 1L) {
    rho <- matrix(rho, nrow = nrow(S), ncol = ncol(S))
  }

  # Fast path: if all off-diagonal |S_ij| <= rho_ij (0.01% slack), the
  # solution is diagonal. The diagonal is penalized, as in the backends
  # (penalize.diagonal = TRUE). Stationarity then gives
  # Theta_ii = 1 / (S_ii + rho_ii).
  off <- row(S) != col(S)
  if (all(abs(S[off]) <= rho[off] * 1.0001)) {
    d <- pmax(S[!off] + rho[!off], 1e-8)
    Theta <- diag(1 / d, nrow = nrow(S))
    return(list(wi = Theta, niter = 0L, fastpath = TRUE, backend = "diagonal_fast"))
  }

  # Each backend returns NULL when its package is not installed.
  backends <- list(
    # glassoFast
    glassofast = function() {
      if (!requireNamespace("glassoFast", quietly = TRUE)) return(NULL)
      fit <- glassoFast::glassoFast(S = S, rho = rho, thr = thr,
                                    maxIt = as.integer(maxIt), trace = FALSE)
      list(wi = fit$wi, niter = fit$niter %||% -1L, fastpath = FALSE)
    },

    # glasso
    glasso = function() {
      if (!requireNamespace("glasso", quietly = TRUE)) return(NULL)
      fit <- glasso::glasso(s = S, rho = rho, thr = thr,
                            maxit = as.integer(maxIt),
                            penalize.diagonal = TRUE, trace = FALSE)
      list(wi = fit$wi, niter = fit$niter %||% -1L, fastpath = FALSE)
    }

    # QUIC - good for large problems
    # quic = function() {
    #   if (!requireNamespace("QUIC", quietly = TRUE)) return(NULL)
    #   fit <- QUIC::QUIC(S = S, rho = rho, maxIter = as.integer(maxIt),
    #                     tol = thr, msg = 0)
    #   list(wi = fit$X, niter = fit$iter %||% -1L, fastpath = FALSE)
    # },

    # BigQUIC - for very large sparse problems
    # bigquic = function() {
    #   if (!requireNamespace("BigQuic", quietly = TRUE)) return(NULL)
    #   fit <- BigQuic::BigQuic(X = NULL, inputS = S, lambda = rho,
    #                           numthreads = 1, maxit = maxIt, epsilon = thr)
    #   list(wi = fit$precision_matrices[[1]], niter = fit$iterations %||% NA_integer_)
    # },

    # flare - alternative implementation
    # flare = function() {
    #   if (!requireNamespace("flare", quietly = TRUE)) return(NULL)
    #   fit <- flare::sugm(data = S, lambda = rho, max.ite = as.integer(maxIt),
    #                      prec = thr, standardize = FALSE, verbose = FALSE)
    #   list(wi = fit$icov[[1]], niter = -1L, fastpath = FALSE)
    # }
  )

  # Try backends in order. Errors are recorded (not silently dropped) so a
  # failing installed backend is not misreported as a missing one.
  failures <- character(0)
  for (backend_name in names(backends)) {
    result <- tryCatch(backends[[backend_name]](), error = function(e) e)
    if (inherits(result, "error")) {
      failures[backend_name] <- conditionMessage(result)
      next
    }
    if (!is.null(result)) {
      result$backend <- backend_name
      return(result)
    }
  }

  if (length(failures) > 0L) {
    stop("All available graphical lasso backends failed: ",
         paste0(names(failures), ": ", failures, collapse = "; "),
         call. = FALSE)
  }
  stop("No graphical lasso backend found. Install one of: glassoFast, glasso")
}

# Solve A x = b for ill-conditioned A with a truncated-SVD pseudo-inverse.
#
# Singular values below `tol * max(d)` are treated as zero. Their reciprocals
# are set to 0 rather than inverted, so near-null directions of A are dropped
# instead of amplified by ~1/tol. An all-zero A therefore yields x = 0 rather
# than Inf/NaN. `b` may be a vector or a matrix; the result is a matrix with
# ncol(A) rows.
svd_solve <- function(A, b, tol = 1e-8) {
  decomp <- svd(A)
  d <- decomp$d

  # Relative cutoff; max(d, 0) also guards against an empty d.
  cutoff <- tol * max(d, 0)
  keep <- d > cutoff
  d_inv <- numeric(length(d))
  d_inv[keep] <- 1 / d[keep]

  # x = V diag(d_inv) U' b; d_inv scales the rows of U'b.
  decomp$v %*% (d_inv * crossprod(decomp$u, b))
}
