require("bnlearn")

# Pasted from bnlearn::backend-score.R
arcs.to.be.added = function(amat, nodes, blacklist = NULL, whitelist = NULL,
                            arcs = TRUE) {

  # Delegate to bnlearn's compiled "hc_to_be_added" routine, which is given
  # the current adjacency matrix, the node names, the optional black/white
  # lists and the `convert` flag (forwarded from `arcs`).
  # .Call hands its arguments to C by position, so the order below must not
  # change; the argument names are there only for readability.
  .Call("hc_to_be_added", arcs = amat, blacklist = blacklist,
        whitelist = whitelist, nodes = nodes, convert = arcs,
        PACKAGE = "bnlearn")

}#ARCS.TO.BE.ADDED

# Draw n observations from a discrete fitted Bayesian network by ancestral
# (forward) sampling: each node is sampled after all of its parents.
#
# bn.fitted: a fitted network (e.g. from bnlearn::bn.fit) whose elements
#   expose $parents (character vector) and $prob, a conditional probability
#   table with the node's levels on the first dimension followed by one
#   dimension per parent, in the order of $parents (assumed, as in bnlearn).
# n: number of observations to draw.
#
# Returns a data frame with one factor column per node, in the node order of
# bn.fitted; factor levels follow the first dimension of each node's CPT.
# Stops with an error if the nodes cannot be ordered (a cycle, or a parent
# that is not part of bn.fitted).
gen.dataset.from.fitted.bn = function(bn.fitted, n) {

  # Sample one node given the already-sampled parent values in `sampled`
  # (a list of character vectors of length n, indexed by node name).
  sample.node = function(node.fit, sampled) {

    prob = node.fit$prob
    dn = dimnames(prob)
    node.levels = dn[[1]]
    parents = node.fit$parents

    if (length(parents) == 0)
      return(sample(node.levels, n, prob = as.vector(prob), replace = TRUE))

    # Column index of each row's parent configuration. R stores arrays in
    # column-major order, so the FIRST parent varies fastest.
    config = rep(1L, n)
    stride = 1L
    for (j in seq_along(parents)) {
      par.levels = dn[[j + 1]]
      config = config +
        (match(sampled[[parents[j]]], par.levels) - 1L) * stride
      stride = stride * length(par.levels)
    }

    # One column of conditional probabilities per parent configuration.
    cpt = matrix(as.vector(prob), nrow = length(node.levels))

    # Write into a fresh vector instead of overwriting the parents' values
    # in place, so a sampled level can never be mistaken for a parent
    # configuration and re-sampled. Only configurations that actually occur
    # are sampled, so CPT columns that are never used are not touched.
    values = character(n)
    for (cfg in sort(unique(config))) {
      rows = which(config == cfg)
      values[rows] = sample(node.levels, length(rows), prob = cpt[, cfg],
                            replace = TRUE)
    }

    values
  }

  nodes = names(bn.fitted)
  data = list()
  pending = nodes

  # Each pass samples every node whose parents are all available; a pass
  # that makes no progress means the order can never be completed.
  while (length(pending) > 0) {
    ready = pending[vapply(pending, function(node)
      all(bn.fitted[[node]]$parents %in% names(data)), logical(1))]
    if (length(ready) == 0) {
      stop("cannot sample nodes ", paste(pending, collapse = ", "),
           ": their parents form a cycle or are missing from bn.fitted")
    }
    for (node in ready)
      data[[node]] = sample.node(bn.fitted[[node]], data)
    pending = setdiff(pending, ready)
  }

  columns = lapply(nodes, function(node)
    factor(data[[node]], levels = dimnames(bn.fitted[[node]]$prob)[[1]]))
  names(columns) = nodes

  as.data.frame(columns)
}#GEN.DATASET.FROM.FITTED.BN

# Phase-1 measures: compare the undirected neighbourhoods of a learned
# structure with those of a reference network.
#
# time: timing vector with a "user.self" element (e.g. from system.time()).
# superstruct: learned structure exposing $nodes[[node]]$nbr and
#   $learning$ntests.
# true.bn: reference network exposing $nodes[[node]]$nbr; defaults to the
#   global `alarm`, which the original version always used.
#
# Returns a named numeric vector: user time, number of tests, tp/tn/fp/fn
# edge counts, recall, precision, their distance from (1, 1), specificity,
# false-positive rate and false-negative rate.
ph1.measures = function(time, superstruct, true.bn = alarm) {

  time = unname(time["user.self"])
  nodes = names(true.bn$nodes)

  tp = 0
  fn = 0
  fp = 0

  # Only neighbours not yet visited are compared, so every unordered node
  # pair is counted exactly once.
  done = character(0)
  for (node in nodes) {
    done = c(done, node)
    tnbr = setdiff(true.bn$nodes[[node]]$nbr, done)
    rnbr = setdiff(superstruct$nodes[[node]]$nbr, done)
    found = tnbr %in% rnbr
    tp = tp + sum(found)
    fn = fn + sum(!found)
    fp = fp + sum(!(rnbr %in% tnbr))
  }

  # tp/fn/fp count unordered pairs, so the number of candidate pairs is
  # choose(n, 2), not n * (n - 1).
  tn = choose(length(nodes), 2) - tp - fn - fp

  # The zero guards avoid 0/0 when there is nothing to recover / report.
  recall = if (fn == 0) 1 else tp / (tp + fn)
  precision = if (fp == 0) 1 else tp / (tp + fp)
  error = sqrt((1 - recall)^2 + (1 - precision)^2)

  specificity = if (fp == 0) 1 else tn / (tn + fp)

  c(
    ph1.time = time,
    ph1.nbtests = superstruct$learning$ntests,
    ph1.tp = tp,
    ph1.tn = tn,
    ph1.fp = fp,
    ph1.fn = fn,
    ph1.recall = recall,
    ph1.precision = precision,
    ph1.error = error,
    ph1.specificity = specificity,
    ph1.fpr = 1 - specificity,
    ph1.fnr = 1 - recall)

}#PH1.MEASURES

# Phase-2 measures: score a learned DAG on held-out and training data and
# compare it with a reference network.
#
# time: timing vector with a "user.self" element (e.g. from system.time()).
# dag: learned network exposing $learning$nscores, usable with score().
# test.set, training.set: data frames to score on; default to the globals
#   `test` and `training`, which the original version always used.
# true.bn: reference network for shd(); defaults to the global `alarm`.
# iss: imaginary sample size for the BDe score (previously hard-coded 10).
#
# Returns a named numeric vector: user time, number of scores, BDe and BIC
# on test and training data, and the structural Hamming distance.
ph2.measures = function(time, dag, test.set = test, training.set = training,
                        true.bn = alarm, iss = 10) {

  time = unname(time["user.self"])

  bde.test = score(dag, test.set, type = "bde", iss = iss)
  bde.train = score(dag, training.set, type = "bde", iss = iss)
  # Explicit BIC penalty log(N) / 2, with N the size of the data set scored.
  bic.test = score(dag, test.set, type = "bic", k = log(nrow(test.set)) / 2)
  bic.train = score(dag, training.set, type = "bic",
                    k = log(nrow(training.set)) / 2)
  # Named so that it does not shadow the shd() function.
  hamming = shd(dag, true.bn)

  c(
    ph2.time = time,
    ph2.nbscores = dag$learning$nscores,
    ph2.bde.test = bde.test,
    ph2.bde.train = bde.train,
    ph2.bic.test = bic.test,
    ph2.bic.train = bic.train,
    ph2.shd = hamming)

}#PH2.MEASURES
