# rm(list=ls())
library(ape)
library(MASS)
library(phytools)
## Loading required package: maps
# Likelihood-ratio test of "one common rate for all traits" against
# "one rate per trait", with the among-species covariance C fixed to the
# identity matrix (the "ID" variant).
#
# Arguments
#   phy      Not read by this variant (C is the identity); kept so existing
#            calls that pass a tree keep working.
#   x        Numeric matrix (or coercible): one row per species, one column
#            per trait. Its row names are used to reorder ms.err / ms.cov.
#   TraitCov TRUE  -> traits may covary; the constrained rate matrix is built
#                     from a Cholesky factor (see build.chol).
#            FALSE -> off-diagonals of the observed rate matrix are zeroed and
#                     only a single common diagonal value is optimised.
#   ms.err   Optional N x p matrix of within-species variances, added to the
#            diagonal of the N*p x N*p covariance.
#   ms.cov   Optional matrix of within-species covariances with p*(p-1)/2
#            columns ordered (1,2), (1,3), ..., (2,3), ...; only used when
#            ms.err is also a matrix.
#
# Value: a list with
#   Robs, Rconstrained          observed and common-diagonal rate matrices
#   Lobs, Lconstrained          their log-likelihoods
#   LRTest, Prob                likelihood-ratio statistic and its chi-square
#                               tail probability on p-1 degrees of freedom
#   AICc.obs, AICc.constrained  -2*logLik + 2*(number of parameters)
#   optimmessage                text describing optim's convergence code
CompareRates.multTraitID<-function(phy,x,TraitCov=TRUE,ms.err=NULL,ms.cov=NULL){
  x<-as.matrix(x)
  N<-nrow(x)
  p<-ncol(x)
  use.cov<-isTRUE(as.logical(TraitCov))

  # Lower-triangular factor whose cross-product has a constant diagonal.
  # b[1] is the log of element [p,p]; b[-1] fills the strict lower triangle.
  # Element [1,1] is set to the norm of row p, and each interior diagonal is
  # chosen so its row has that same norm (or 0 when the off-diagonals alone
  # already exceed it).
  build.chol<-function(b){
    c.mat<-matrix(0,nrow=p,ncol=p)
    c.mat[lower.tri(c.mat)]<-b[-1]
    c.mat[p,p]<-exp(b[1])
    c.mat[1,1]<-sqrt(sum(c.mat[p,]^2))
    if(p>2){
      for(i in 2:(p-1)){
        gap<-c.mat[1,1]^2-sum(c.mat[i,]^2)
        c.mat[i,i]<-sqrt(max(gap,0))
      }
    }
    c.mat
  }

  # Identity covariance among species. The phylogenetic version would be
  #   C<-vcv.phylo(phy); C<-C[rownames(x),rownames(x)]
  C<-diag(1,N)
  C.inv<-solve(C)

  # Put the measurement-error rows in the same species order as x. Without
  # row names on x there is nothing to match on, so the given order is kept.
  if(!is.null(rownames(x))){
    if(is.matrix(ms.err)){ms.err<-ms.err[rownames(x),,drop=FALSE]}
    if(is.matrix(ms.cov)){ms.cov<-ms.cov[rownames(x),,drop=FALSE]}
  }

  # GLS estimate of the trait means and the unconstrained rate matrix.
  one<-matrix(1,N,1)
  a.obs<-colSums(C.inv)%*%x/sum(C.inv)
  centred<-x-one%*%a.obs
  R.obs<-t(centred)%*%C.inv%*%centred/N
  if(!use.cov){R.obs<-diag(diag(R.obs),p)}

  # Stack the traits into one long vector; D maps each trait mean onto its
  # block of N rows.
  D<-kronecker(diag(p),one)
  y<-as.matrix(as.vector(x))
  resid<-y-D%*%t(a.obs)

  # Multivariate-normal log-likelihood of the stacked residuals under V.
  mvn.loglik<-function(V){
    quad<-t(resid)%*%ginv(V)%*%resid
    as.numeric(-quad/2-N*p*log(2*pi)/2-determinant(V)$modulus[1]/2)
  }

  # Observed log-likelihood, with within-species variances when supplied.
  # NOTE(review): this uses only the variances (ms.err) even when ms.cov is
  # given, whereas the constrained fit below uses the full m.e matrix --
  # confirm that asymmetry is intended.
  if(is.matrix(ms.err)){
    LLik.obs<-mvn.loglik(kronecker(R.obs,C)+diag(as.vector(ms.err),nrow=N*p))
  }else{
    LLik.obs<-mvn.loglik(kronecker(R.obs,C))
  }

  # Within-species error matrix added to the covariance in the constrained
  # fits. With variances and covariances it is a p x p grid of N x N diagonal
  # blocks; with variances only it reduces to a plain diagonal matrix.
  m.e<-NULL
  if(is.matrix(ms.err)&&is.matrix(ms.cov)){
    within.spp<-cbind(ms.err,ms.cov)
    if(ncol(within.spp)!=p*(p+1)/2){
      stop("ms.err must have p columns and ms.cov p*(p-1)/2 columns",call.=FALSE)
    }
    # col.of[i,j] = column of within.spp holding the (i,j) term: variances
    # first, then covariances in the order (1,2), (1,3), ..., (2,3), ...
    col.of<-matrix(0L,p,p)
    diag(col.of)<-seq_len(p)
    next.col<-p
    if(p>1){
      for(i in 1:(p-1)){
        for(j in (i+1):p){
          next.col<-next.col+1L
          col.of[i,j]<-next.col
          col.of[j,i]<-next.col
        }
      }
    }
    m.e<-do.call(rbind,lapply(seq_len(p),function(i){
      do.call(cbind,lapply(seq_len(p),function(j){
        diag(within.spp[,col.of[i,j]],nrow=N)
      }))
    }))
  }else if(is.matrix(ms.err)){
    m.e<-diag(as.vector(ms.err),nrow=N*p)
  }

  # Negative log-likelihood for a candidate rate matrix R; infinite values are
  # replaced by +/-1e10 because L-BFGS-B requires finite objective values.
  neg.loglik<-function(R){
    V<-kronecker(R,C)
    if(!is.null(m.e)){V<-V+m.e}
    LLik<-mvn.loglik(V)
    if(LLik==-Inf){LLik<- -1e+10}
    if(LLik==Inf){LLik<-1e+10}
    -LLik
  }

  # Objective without trait covariation: overwrite the diagonal of R.obs
  # with the single candidate rate.
  lik.covF<-function(sigma){
    R<-R.obs
    diag(R)<-sigma
    neg.loglik(R)
  }

  ################################
  #h12<-(h1+h2)/2# HERE IS AN ISSUE
  ################################
  #Sh12 <- h12^2*C + (1-h12^2)*I

  # Objective with trait covariation: rebuild R from the Cholesky parameters.
  lik.covT<-function(sigma){
    low.chol<-build.chol(sigma)
    neg.loglik(low.chol%*%t(low.chol))
  }

  # Start the search from the mean observed rate.
  sigma.mn<-mean(diag(R.obs))
  if(use.cov){
    R.offd<-rep(0,p*(p-1)/2)
    # NOTE(review): optim recycles `lower` over every parameter, so this
    # bounds b[1] (a log-scale value) and all off-diagonal Cholesky entries
    # at 0 -- confirm that restriction is intended.
    model1<-optim(par=c(sigma.mn,R.offd),fn=lik.covT,method="L-BFGS-B",lower=0)
    chol.mat<-build.chol(model1$par)
    R.constr<-chol.mat%*%t(chol.mat)
  }else{
    model1<-optim(sigma.mn,fn=lik.covF,method="L-BFGS-B",hessian=TRUE,lower=0)
    R.constr<-diag(model1$par,p)
  }

  if(model1$convergence==0){
    optim.msg<-"Optimization has converged."
  }else{
    optim.msg<-"Optim may not have converged.  Consider changing start value or lower/upper limits."
  }

  # optim minimised the negative log-likelihood, so flip the sign back.
  L.constr<- -model1$value
  LRT<- -2*(L.constr-LLik.obs)
  LRT.prob<-pchisq(LRT,p-1,lower.tail=FALSE)

  AIC.obs<- -2*LLik.obs+2*p+2*p #(2p twice: 1x for rates, 1x for anc.states)
  AIC.common<- -2*L.constr+2+2*p #(2*1: for 1 rate, 2p for anc.states)

  list(
    Robs=R.obs,
    Rconstrained=R.constr,
    Lobs=LLik.obs,
    Lconstrained=L.constr,
    LRTest=LRT,
    Prob=LRT.prob,
    AICc.obs=AIC.obs,
    AICc.constrained=AIC.common,
    optimmessage=optim.msg
  )
}

 # Demo: draw a random five-tip coalescent tree and show it.
 phy <- rcoal(5)
 plot(phy)

 # Three simulated traits, one column each, one row per tip of the tree.
 x <- matrix(
   c(
     rnorm(5, 2, 1),
     rnorm(5, 0, 0.5),
     rnorm(5, 1, 1.5)
   ),
   ncol = 3
 )
 rownames(x) <- phy$tip.label # LETTERS[1:N]
 head(x)
##         [,1]        [,2]       [,3]
## t5 0.2122762 -0.56047734  1.5637595
## t1 1.9123808 -0.04181873  2.8498638
## t3 0.7808833 -0.19449493  0.4449976
## t2 3.7273135  0.42708806 -0.8072073
## t4 0.4835225 -0.41982362  1.0017717
 # Run the common-rate test on the simulated traits, without measurement error.
 CompareRates.multTraitID(
   phy = phy,
   x = x,
   TraitCov = TRUE,
   ms.err = NULL,
   ms.cov = NULL
 )
## $Robs
##            [,1]       [,2]       [,3]
## [1,]  1.6620276  0.4323567 -0.7173888
## [2,]  0.4323567  0.1175395 -0.2099141
## [3,] -0.7173888 -0.2099141  1.4626568
## 
## $Rconstrained
##           [,1]      [,2]      [,3]
## [1,] 1.2388156 0.8241701 0.0000000
## [2,] 0.8241701 1.2388156 0.4060828
## [3,] 0.0000000 0.4060828 1.2388156
## 
## $Lobs
## [1] -9.453943
## 
## $Lconstrained
## [1] -20.86455
## 
## $LRTest
## [1] 22.82121
## 
## $Prob
## [1] 1.10774e-05
## 
## $AICc.obs
## [1] 30.90789
## 
## $AICc.constrained
## [1] 49.72909
## 
## $optimmessage
## [1] "Optimization has converged."
# tree<-read.tree("http://tonyjhwueng.info/phyrates/ple.nwk")
# plot(tree)
#
# tree$tip.label
#
# df<-read.csv("http://tonyjhwueng.info/phyrates/Adams2012-SystBiolData.csv")
# head(df)
#
# spX<-strsplit(as.character(df$X),"_")
# spname<-array(NA,length(tree$tip.label))
# for(Index in 1:length(tree$tip.label)){
#   spname[Index]<-paste("Plethodon_", spX[[Index]][2],sep="")
# }
# spname
# df$X<-spname
#
# HeadLength<-df$HeadLength
# names(HeadLength)<-spname
# BodyWidth<-df$BodyWidth
# names(BodyWidth)<-spname
#
# HeadLength<-HeadLength[tree$tip.label]
# BodyWidth<-BodyWidth[tree$tip.label]
# x<-cbind(HeadLength,BodyWidth)
# CompareRates.multTrait(phy=tree,x=x,TraitCov=T,ms.err=NULL,ms.cov=NULL)