# rm(list=ls())
library(ape)
library(MASS)
library(phytools)
## Loading required package: maps
library(geiger)


# CompareRates.multTraitEB: compare the evolutionary rates of several traits
# under an Early-Burst (EB) covariance structure.
#
# The "observed" model gives each trait its own rate (the diagonal of
# Robs.eb) and its own EB parameter, estimated with fitContinuous(). The
# "constrained" model forces a single common value on the diagonal of the
# rate matrix and re-fits it with optim(), together with one EB parameter
# per trait and, when TraitCov is TRUE, the off-diagonal terms. The two fits
# are compared with a likelihood-ratio test on p - 1 df and with AIC.
#
# Arguments
#   phy       tree passed to vcv.phylo() and fitContinuous().
#   x         N x p trait matrix (coerced with as.matrix()); its rownames
#             index the tree covariance. At least two columns are required.
#   TraitCov  TRUE lets traits covary; FALSE uses diagonal rate matrices.
#   ms.err    optional N x p matrix (rows matched by rownames(x)) whose
#             entries are added to the diagonal of the stacked-trait
#             covariance; presumably within-species measurement-error
#             variances - confirm. Ignored unless it is a matrix.
#   ms.cov    optional N x p(p-1)/2 matrix of within-species covariances,
#             columns ordered (1,2), (1,3), ..., (1,p), (2,3), ...; used only
#             by the constrained model and only together with ms.err.
#   r.lower, r.upper  box bounds for the EB parameters in the constrained
#             fit (defaults are the previously hard-coded values).
#
# Returns a list with Robs.eb, Rconstrained.eb, Lobs.eb, Lconstrained.eb,
# LRTest.eb, Prob.eb, AICc.obs.eb, AICc.constrained.eb, RkronSr (covariance
# of the stacked traits under the observed model) and optimmessage.eb.
CompareRates.multTraitEB <- function(phy, x, TraitCov = TRUE, ms.err = NULL,
                                     ms.cov = NULL, r.lower = 1e-5,
                                     r.upper = 2) {
  x <- as.matrix(x)
  N <- nrow(x)
  p <- ncol(x)
  if (p < 2) {
    stop("x must have at least two trait columns.", call. = FALSE)
  }
  trait.cov <- as.logical(TraitCov)
  if (length(trait.cov) != 1L || is.na(trait.cov)) {
    stop("TraitCov must be TRUE or FALSE.", call. = FALSE)
  }
  n.offd <- p * (p - 1) / 2

  C <- vcv.phylo(phy)
  C <- C[rownames(x), rownames(x)]

  has.ms.err <- is.matrix(ms.err)
  if (has.ms.err) {
    ms.err <- as.matrix(ms.err[rownames(x), ])
  }
  if (is.matrix(ms.cov)) {
    ms.cov <- as.matrix(ms.cov[rownames(x), ])
    if (ncol(ms.cov) != n.offd) {
      stop("ms.cov must have p * (p - 1) / 2 columns.", call. = FALSE)
    }
  }

  # ---- helpers (closures over N, p and C) ----

  # Lower-triangular factor L such that L %*% t(L) has equal diagonal
  # entries: b[1] is the log of L[p, p], b[-1] fills the strict lower
  # triangle (column-major), and the remaining diagonal entries are chosen
  # so every row has the same squared norm (clamped at 0 when impossible).
  build.chol <- function(b) {
    c.mat <- matrix(0, nrow = p, ncol = p)
    c.mat[lower.tri(c.mat)] <- b[-1]
    c.mat[p, p] <- exp(b[1])
    c.mat[1, 1] <- sqrt(sum(c.mat[p, ]^2))
    if (p > 2) {
      for (i in 2:(p - 1)) {
        gap <- c.mat[1, 1]^2 - sum(c.mat[i, ]^2)
        c.mat[i, i] <- ifelse(gap > 0, sqrt(gap), 0)
      }
    }
    c.mat
  }

  # EB covariance for one rate: elementwise (exp(rate * C) - 1) / rate,
  # replaced by its rate -> 0 limit C to avoid a 0/0 NaN matrix.
  eb.vcv <- function(rate) {
    if (rate == 0) C else (exp(rate * C) - 1) / rate
  }

  # Symmetric inverse square root S^(-1/2) from the eigen-decomposition;
  # nrow= keeps diag() from building an identity matrix when N == 1.
  inv.sqrt <- function(S) {
    e <- eigen(S)
    e$vectors %*% diag(1 / sqrt(e$values), nrow = length(e$values)) %*%
      t(e$vectors)
  }

  # Covariance of the stacked trait vector: block (i, j) is R[i, j] times
  # the EB covariance at the mean rate (rates[i] + rates[j]) / 2, so the
  # diagonal blocks use each trait's own rate. Works for any p.
  build.RkronSr <- function(R, rates) {
    row.blocks <- lapply(seq_len(p), function(i) {
      do.call(cbind, lapply(seq_len(p), function(j) {
        R[i, j] * eb.vcv((rates[i] + rates[j]) / 2)
      }))
    })
    do.call(rbind, row.blocks)
  }

  # ---- observed model: one EB fit and one rate per trait ----
  one <- matrix(1, N, 1)
  r.array <- numeric(p)
  a.obs.eb <- numeric(p)
  R.obs.eb <- matrix(NA_real_, p, p)
  IIDtrait <- matrix(NA_real_, N, p)
  for (i in seq_len(p)) {
    r <- fitContinuous(phy = phy, dat = x[, i], model = "EB")$opt$a
    r.array[i] <- r
    Sr <- eb.vcv(r)
    Sr.inv <- solve(Sr)  # computed once and reused below
    a.obs.eb[i] <- colSums(Sr.inv) %*% x[, i] / sum(Sr.inv)  # GLS mean
    dev <- x[, i] - one %*% a.obs.eb[i]
    R.obs.eb[i, i] <- t(dev) %*% Sr.inv %*% dev / N
    # Whitened trait; covariances between whitened traits fill the
    # off-diagonal of R.obs.eb below.
    IIDtrait[, i] <- inv.sqrt(Sr) %*% x[, i]
  }
  for (i in seq_len(p - 1)) {
    for (j in (i + 1):p) {
      R.obs.eb[i, j] <- R.obs.eb[j, i] <- cov(IIDtrait[, i], IIDtrait[, j])
    }
  }
  if (!trait.cov) {
    R.obs.eb <- diag(diag(R.obs.eb), p)
  }
  a.obs.eb <- matrix(a.obs.eb, nrow = 1, ncol = p)

  # Stacked data: y is trait-major (all N tips of trait 1, then trait 2, ...)
  # and D maps each trait mean onto its N rows.
  D <- kronecker(diag(p), one)
  y <- as.matrix(as.vector(x))
  resid.y <- y - D %*% t(a.obs.eb)

  # Multivariate normal log-likelihood of y for covariance V (generalized
  # inverse and log|det| as in the original formulation).
  gauss.loglik <- function(V) {
    as.numeric(-t(resid.y) %*% ginv(V) %*% resid.y / 2 -
                 N * p * log(2 * pi) / 2 -
                 determinant(V)$modulus[1] / 2)
  }

  RkronSr <- build.RkronSr(R.obs.eb, r.array)
  V.obs <- if (has.ms.err) {
    RkronSr + diag(as.vector(ms.err), nrow = N * p)
  } else {
    RkronSr
  }
  LLik.obs.eb <- gauss.loglik(V.obs)
  sigma.mn.eb <- mean(diag(R.obs.eb))

  # ---- within-species error for the constrained model ----
  # Diagonal ms.err by default; when ms.cov is also given, block (i, j) is
  # diag() of the matching column of cbind(ms.err, ms.cov).
  m.e <- NULL
  if (has.ms.err) {
    m.e <- diag(as.vector(ms.err), nrow = N * p)
    if (is.matrix(ms.cov)) {
      within.spp <- cbind(ms.err, ms.cov)
      col.index <- diag(seq_len(p), nrow = p)
      k <- p
      for (i in seq_len(p - 1)) {
        for (j in (i + 1):p) {
          k <- k + 1
          col.index[i, j] <- col.index[j, i] <- k
        }
      }
      m.e <- do.call(rbind, lapply(seq_len(p), function(i) {
        do.call(cbind, lapply(seq_len(p), function(j) {
          diag(within.spp[, col.index[i, j]], nrow = N)
        }))
      }))
    }
  }

  # Negative log-likelihood of the constrained model. Any non-finite
  # covariance or likelihood is penalised with 1e+10 (a singular covariance
  # must not be rewarded, and NaN must not abort the optimiser).
  constrained.objective <- function(R, rates) {
    V <- build.RkronSr(R, rates)
    if (has.ms.err) {
      V <- V + m.e
    }
    if (!all(is.finite(V))) {
      return(1e+10)
    }
    LLik <- gauss.loglik(V)
    if (!is.finite(LLik)) {
      return(1e+10)
    }
    -LLik
  }

  # ---- constrained model: common rate on the diagonal ----
  sigma.upper <- 2 * max(apply(x, 2, sd))
  r.names <- paste0("r", seq_len(p))
  offd.names <- paste0("R.offd", seq_len(n.offd))
  if (trait.cov) {
    # "sigma" is the log of the Cholesky entry L[p, p] (see build.chol).
    p0 <- c(r.array, sigma.mn.eb, rep(0, n.offd))
    names(p0) <- c(r.names, "sigma", offd.names)
    lower <- c(rep(r.lower, p), 1e-5, rep(1e-5, n.offd))
    upper <- c(rep(r.upper, p), sigma.upper, rep(sigma.upper, n.offd))
    objective <- function(param) {
      low.chol <- build.chol(param[c("sigma", offd.names)])
      constrained.objective(low.chol %*% t(low.chol), param[r.names])
    }
  } else {
    p0 <- c(r.array, sigma.mn.eb)
    names(p0) <- c(r.names, "sigma")
    lower <- c(rep(r.lower, p), 1e-5)
    upper <- c(rep(r.upper, p), sigma.upper)
    objective <- function(param) {
      constrained.objective(diag(param[["sigma"]], p), param[r.names])
    }
  }
  # NOTE(review): fitContinuous's EB estimate may be negative and so fall
  # outside the default [r.lower, r.upper]; starting values are clipped into
  # the box explicitly (L-BFGS-B would project them onto it anyway).
  p0 <- pmin(pmax(p0, lower), upper)
  fit <- optim(p0, fn = objective, method = "L-BFGS-B",
               lower = lower, upper = upper)

  if (trait.cov) {
    chol.mat <- build.chol(fit$par[c("sigma", offd.names)])
    R.constr.eb <- chol.mat %*% t(chol.mat)
  } else {
    R.constr.eb <- diag(fit$par[["sigma"]], p)
  }

  message.eb <- if (fit$convergence == 0) {
    "Optimization has converged."
  } else {
    paste("Optim may not have converged.",
          "Consider changing start value or lower/upper limits.")
  }

  LLik.constr.eb <- -fit$value
  LRT.eb <- -2 * (LLik.constr.eb - LLik.obs.eb)
  LRT.prob.eb <- pchisq(LRT.eb, p - 1, lower.tail = FALSE)

  # Parameter penalties kept from the earlier version of this function.
  AIC.obs.eb <- -2 * LLik.obs.eb + 2 * p + 2 * p + 2 * p
  AIC.common.eb <- -2 * LLik.constr.eb + 2 + 2 * p + 2 * p

  list(
    Robs.eb = R.obs.eb,
    Rconstrained.eb = R.constr.eb,
    Lobs.eb = LLik.obs.eb,
    Lconstrained.eb = LLik.constr.eb,
    LRTest.eb = LRT.eb,
    Prob.eb = LRT.prob.eb,
    AICc.obs.eb = AIC.obs.eb,
    AICc.constrained.eb = AIC.common.eb,
    RkronSr = RkronSr,
    optimmessage.eb = message.eb
  )
}

### Sample code
# Simulate a 5-tip coalescent tree and three independent traits, then compare
# their EB rates. set.seed() makes the example reproducible; the transcript
# recorded below came from an unseeded run, so its numbers will differ.
set.seed(1)
n.tips <- 5
phy <- rcoal(n.tips)
plot(phy)

x <- matrix(c(rnorm(n.tips, 2, 1), rnorm(n.tips, 0, 0.5),
              rnorm(n.tips, 1, 1.5)), ncol = 3)
rownames(x) <- phy$tip.label
x
##        [,1]       [,2]       [,3]
## t4 1.680277  0.1072781  1.8157956
## t3 2.870034 -0.3845481  1.6683304
## t2 1.686470 -0.1737591 -0.9353785
## t1 2.781093 -0.2330348  1.3164488
## t5 1.792314 -0.4206589  0.4633278
CompareRates.multTraitEB(phy = phy, x = x, TraitCov = TRUE,
                         ms.err = NULL, ms.cov = NULL)
## Warning in fitContinuous(phy = phy, dat = x[, Index], model = "EB"): 
## Parameter estimates appear at bounds:
##  a

## Warning in fitContinuous(phy = phy, dat = x[, Index], model = "EB"): 
## Parameter estimates appear at bounds:
##  a

## Warning in fitContinuous(phy = phy, dat = x[, Index], model = "EB"): 
## Parameter estimates appear at bounds:
##  a
## $Robs.eb
##             [,1]        [,2]       [,3]
## [1,]  0.65566447 -0.07170547  1.6193252
## [2,] -0.07170547  0.03551905 -0.1499272
## [3,]  1.61932525 -0.14992717  2.7805818
## 
## $Rconstrained.eb
##           [,1]      [,2]      [,3]
## [1,] 1.1823859 0.3283496 0.3283496
## [2,] 0.3283496 1.1823859 0.4066177
## [3,] 0.3283496 0.4066177 1.1823859
## 
## $Lobs.eb
## [1] -6.7881
## 
## $Lconstrained.eb
## [1] 1905.863
## 
## $LRTest.eb
## [1] -3825.302
## 
## $Prob.eb
## [1] 1
## 
## $AICc.obs.eb
## [1] 31.5762
## 
## $AICc.constrained.eb
## [1] -3797.726
## 
## $RkronSr
##            t4          t3          t2          t1          t5          t4
## t4  0.9347094  0.00000000  0.00000000  0.00000000  0.00000000 -0.10222269
## t3  0.0000000  0.93470944  0.74273680  0.70932889  0.39615036  0.00000000
## t2  0.0000000  0.74273680  0.93470944  0.70932889  0.39615036  0.00000000
## t1  0.0000000  0.70932889  0.70932889  0.93470944  0.39615036  0.00000000
## t5  0.0000000  0.39615036  0.39615036  0.39615036  0.93470944  0.00000000
## t4 -0.1022227  0.00000000  0.00000000  0.00000000  0.00000000  0.05063565
## t3  0.0000000 -0.10222269 -0.08122797 -0.07757438 -0.04332422  0.00000000
## t2  0.0000000 -0.08122797 -0.10222269 -0.07757438 -0.04332422  0.00000000
## t1  0.0000000 -0.07757438 -0.07757438 -0.10222269 -0.04332422  0.00000000
## t5  0.0000000 -0.04332422 -0.04332422 -0.04332422 -0.10222269  0.00000000
## t4  2.3084957  0.00000000  0.00000000  0.00000000  0.00000000 -0.21373484
## t3  0.0000000  2.30849568  1.83437187  1.75186277  0.97839111  0.00000000
## t2  0.0000000  1.83437187  2.30849568  1.75186277  0.97839111  0.00000000
## t1  0.0000000  1.75186277  1.75186277  2.30849568  0.97839111  0.00000000
## t5  0.0000000  0.97839111  0.97839111  0.97839111  2.30849568  0.00000000
##             t3          t2          t1          t5         t4          t3
## t4  0.00000000  0.00000000  0.00000000  0.00000000  2.3084957  0.00000000
## t3 -0.10222269 -0.08122797 -0.07757438 -0.04332422  0.0000000  2.30849568
## t2 -0.08122797 -0.10222269 -0.07757438 -0.04332422  0.0000000  1.83437187
## t1 -0.07757438 -0.07757438 -0.10222269 -0.04332422  0.0000000  1.75186277
## t5 -0.04332422 -0.04332422 -0.04332422 -0.10222269  0.0000000  0.97839111
## t4  0.00000000  0.00000000  0.00000000  0.00000000 -0.2137348  0.00000000
## t3  0.05063565  0.04023599  0.03842619  0.02146050  0.0000000 -0.21373484
## t2  0.04023599  0.05063565  0.03842619  0.02146050  0.0000000 -0.16983752
## t1  0.03842619  0.03842619  0.05063565  0.02146050  0.0000000 -0.16219831
## t5  0.02146050  0.02146050  0.02146050  0.05063565  0.0000000 -0.09058551
## t4  0.00000000  0.00000000  0.00000000  0.00000000  3.9639726  0.00000000
## t3 -0.21373484 -0.16983752 -0.16219831 -0.09058551  0.0000000  3.96397263
## t2 -0.16983752 -0.21373484 -0.16219831 -0.09058551  0.0000000  3.14984339
## t1 -0.16219831 -0.16219831 -0.21373484 -0.09058551  0.0000000  3.00816506
## t5 -0.09058551 -0.09058551 -0.09058551 -0.21373484  0.0000000  1.68001855
##             t2          t1          t5
## t4  0.00000000  0.00000000  0.00000000
## t3  1.83437187  1.75186277  0.97839111
## t2  2.30849568  1.75186277  0.97839111
## t1  1.75186277  2.30849568  0.97839111
## t5  0.97839111  0.97839111  2.30849568
## t4  0.00000000  0.00000000  0.00000000
## t3 -0.16983752 -0.16219831 -0.09058551
## t2 -0.21373484 -0.16219831 -0.09058551
## t1 -0.16219831 -0.21373484 -0.09058551
## t5 -0.09058551 -0.09058551 -0.21373484
## t4  0.00000000  0.00000000  0.00000000
## t3  3.14984339  3.00816506  1.68001855
## t2  3.96397263  3.00816506  1.68001855
## t1  3.00816506  3.96397263  1.68001855
## t5  1.68001855  1.68001855  3.96397263
## 
## $optimmessage.eb
## [1] "Optim may not have converrged.\n  Consideer changing startt value or lower/upper limits."
#####

# tree<-read.tree("http://tonyjhwueng.info/phymvrates/ple.nwk")
# plot(tree)
# tree$tip.label
# # #
# df<-read.csv("http://tonyjhwueng.info/phymvrates/Adams2012-SystBiolData.csv")
# head(df)
# # #
# spX<-strsplit(as.character(df$X),"_")
# spname<-array(NA,length(spX))
# for(Index in 1:length(spX)){
#   spname[Index]<-paste("Plethodon_", spX[[Index]][2],sep="")
# }
# spname
# df$X<-spname
#
# HeadLength<-df$HeadLength
# names(HeadLength)<-spname
# BodyWidth<-df$BodyWidth
# names(BodyWidth)<-spname
#
# HeadLength<-HeadLength[tree$tip.label]
# BodyWidth<-BodyWidth[tree$tip.label]
# x<-cbind(HeadLength,BodyWidth)
# CompareRates.multTraitEB(phy=tree,x=x,TraitCov=T,ms.err=NULL,ms.cov=NULL)