# rm(list=ls())
library(ape)
library(MASS)
library(phytools)
## Loading required package: maps
library(geiger)


# Internal helper: OU (cross-)covariance matrix for a pair of traits with OU
# pull strengths a.i and a.j, computed from the phylogenetic covariance C.
# max(C) is used as the root-to-tip depth T, which presumes an ultrametric
# tree. For a.i == a.j this is exp(-2a(T - C)) * (1 - exp(-2aC)) / (2a).
.ou.cross.vcv <- function(C, a.i, a.j) {
  a.sum <- a.i + a.j
  exp(-a.sum * (max(C) - C)) * (1 - exp(-a.sum * C)) / a.sum
}

# Internal helper: (N*p) x (N*p) covariance whose (i, j) N x N block is
# R[i, j] * .ou.cross.vcv(C, alpha[i], alpha[j]). Blocks are ordered trait by
# trait, matching the column stacking of as.vector(x).
.ou.block.vcv <- function(R, alpha, C) {
  p <- length(alpha)
  block.rows <- lapply(seq_len(p), function(i) {
    do.call(cbind, lapply(seq_len(p), function(j) {
      R[i, j] * .ou.cross.vcv(C, alpha[i], alpha[j])
    }))
  })
  do.call(rbind, block.rows)
}

# Internal helper: multivariate-normal log-likelihood of the residual vector
# `resid` under covariance V. A generalized inverse (MASS::ginv) is used for
# the quadratic form, so a singular V does not raise an error the way solve()
# would. A V containing Inf/NaN yields -Inf instead of an SVD error.
.ou.mvn.loglik <- function(resid, V) {
  if (any(!is.finite(V))) {
    return(-Inf)
  }
  quad <- drop(t(resid) %*% MASS::ginv(V) %*% resid)
  log.det <- determinant(V)$modulus[1]
  -quad / 2 - length(resid) * log(2 * pi) / 2 - log.det / 2
}

# Internal helper: lower-triangular factor L for the constrained rate matrix
# L %*% t(L). b[1] is the log of L[p, p]; b[-1] fills the strict lower
# triangle (recycled if shorter). The remaining diagonal entries are chosen so
# that, where possible, every row of L has the same squared norm, i.e. all
# traits share one variance on the diagonal of L %*% t(L).
.ou.build.chol <- function(b, p) {
  L <- matrix(0, nrow = p, ncol = p)
  L[lower.tri(L)] <- b[-1]
  L[p, p] <- exp(b[1])
  L[1, 1] <- sqrt(sum(L[p, ]^2))
  if (p > 2) {
    for (i in 2:(p - 1)) {
      gap <- L[1, 1]^2 - sum(L[i, ]^2)
      L[i, i] <- if (isTRUE(gap > 0)) sqrt(gap) else 0
    }
  }
  L
}

# Internal helper: block measurement-error covariance to add to a model
# covariance, or NULL when ms.err is not a matrix.
# - ms.err (N x p) supplies the diagonal.
# - ms.cov, if a matrix, supplies one column per trait pair. The columns are
#   assumed to be ordered (1,2), (1,3), ..., (1,p), (2,3), ...
#   NOTE(review): confirm this column order against the data source.
.ou.me.matrix <- function(ms.err, ms.cov, N, p) {
  if (!is.matrix(ms.err)) {
    return(NULL)
  }
  err <- as.vector(ms.err)
  M <- diag(err, nrow = length(err))
  if (is.matrix(ms.cov)) {
    n.pairs <- p * (p - 1) / 2
    if (ncol(ms.cov) < n.pairs) {
      stop("ms.cov must have one column per trait pair (", n.pairs, ").",
           call. = FALSE)
    }
    k <- 0
    for (i in seq_len(p - 1)) {
      for (j in (i + 1):p) {
        k <- k + 1
        rows <- (i - 1) * N + seq_len(N)
        cols <- (j - 1) * N + seq_len(N)
        M[rows, cols] <- M[cols, rows] <- diag(ms.cov[, k], nrow = N)
      }
    }
  }
  M
}

#' Compare per-trait OU rates against a common-rate model
#'
#' For each trait, an OU pull strength (alpha) is fitted with
#' geiger::fitContinuous(). Per-trait rates and between-trait covariances then
#' form the "observed" model. A constrained model, in which every trait has the
#' same rate (optionally with one shared off-diagonal term), is fitted with
#' optim(method = "L-BFGS-B"). The two models are compared by a
#' likelihood-ratio test and by AIC.
#'
#' @param phy Tree passed to ape::vcv.phylo() and geiger::fitContinuous(); its
#'   tip labels must include rownames(x).
#' @param x Numeric matrix (species x traits) with at least two columns and
#'   species names as row names.
#' @param TraitCov If TRUE, the observed model keeps between-trait covariances
#'   and the constrained model estimates one off-diagonal parameter
#'   ("R.offd"). If FALSE, both models use a diagonal rate matrix.
#' @param ms.err Optional N x p matrix of measurement-error variances, with row
#'   names matching rownames(x). Ignored unless it is a matrix.
#' @param ms.cov Optional matrix of measurement-error covariances, one column
#'   per trait pair (see .ou.me.matrix for the assumed order). Used only
#'   together with ms.err.
#' @return A list with:
#'   - Robs.ou, Rconstrained.ou: rate matrices of the two models.
#'   - Lobs.ou, Lconstrained.ou: log-likelihoods of the two models.
#'   - LRTest.ou, Prob.ou: LRT statistic and its chi-square upper-tail
#'     probability on p - 1 df.
#'   - AICc.obs.ou, AICc.constrained.ou: AIC values. No small-sample
#'     correction is applied, despite the names.
#'   - RkronSa: observed-model covariance, without measurement error.
#'   - optimmessage.ou: optimizer convergence note.
CompareRates.multTraitOU <- function(phy, x, TraitCov = TRUE,
                                     ms.err = NULL, ms.cov = NULL) {
  x <- as.matrix(x)
  N <- nrow(x)
  p <- ncol(x)
  if (p < 2) {
    stop("x must have at least two trait columns.", call. = FALSE)
  }
  if (is.null(rownames(x))) {
    stop("x must have row names matching the tip labels of phy.",
         call. = FALSE)
  }

  C <- vcv.phylo(phy)
  unmatched <- setdiff(rownames(x), rownames(C))
  if (length(unmatched) > 0) {
    stop("Row names of x not found in phy: ",
         paste(unmatched, collapse = ", "), call. = FALSE)
  }
  C <- C[rownames(x), rownames(x)]

  if (is.matrix(ms.err)) {
    ms.err <- ms.err[rownames(x), , drop = FALSE]
  }
  if (is.matrix(ms.cov)) {
    ms.cov <- ms.cov[rownames(x), , drop = FALSE]
  }
  me <- .ou.me.matrix(ms.err, ms.cov, N, p)
  # The same measurement-error model is added in both the observed and the
  # constrained likelihood, so the LRT compares like with like.
  with.me <- function(V) if (is.null(me)) V else V + me

  ## Observed model: one OU alpha and one rate per trait ---------------------
  alpha.array <- numeric(p)
  a.obs.ou <- numeric(p)
  R.obs.ou <- matrix(NA_real_, p, p)
  IIDtrait <- matrix(NA_real_, N, p)
  for (k in seq_len(p)) {
    alpha.array[k] <-
      fitContinuous(phy = phy, dat = x[, k], model = "OU")$opt$alpha
    Sa <- .ou.cross.vcv(C, alpha.array[k], alpha.array[k])
    Sa.inv <- solve(Sa)
    # GLS mean and rate of trait k under its own OU covariance Sa.
    a.obs.ou[k] <- drop(colSums(Sa.inv) %*% x[, k]) / sum(Sa.inv)
    resid <- x[, k] - a.obs.ou[k]
    R.obs.ou[k, k] <- drop(t(resid) %*% Sa.inv %*% resid) / N
    # Whiten trait k with Sa^(-1/2) so that between-trait covariances can be
    # estimated with cov() below.
    eS <- eigen(Sa)
    Sa.neg.half <- eS$vectors %*% diag(1 / sqrt(eS$values), nrow = N) %*%
      t(eS$vectors)
    IIDtrait[, k] <- Sa.neg.half %*% x[, k]
  }
  off.diag <- row(R.obs.ou) != col(R.obs.ou)
  R.obs.ou[off.diag] <- cov(IIDtrait)[off.diag]
  if (!TraitCov) {
    R.obs.ou <- diag(diag(R.obs.ou), p)
  }

  # Residuals stacked trait by trait, matching as.vector(x).
  resid.all <- as.vector(x) - rep(a.obs.ou, each = N)
  RkronSa <- .ou.block.vcv(R.obs.ou, alpha.array, C)
  LLik.obs.ou <- .ou.mvn.loglik(resid.all, with.me(RkronSa))

  ## Constrained model: a single common rate ---------------------------------
  alpha.names <- paste0("alpha", seq_len(p))
  # Rate matrix implied by a parameter vector. With TraitCov, "sigma" is the
  # log of L[p, p] in the factor from .ou.build.chol(). A single "R.offd"
  # value is recycled into every off-diagonal slot of that factor.
  # Without TraitCov, "sigma" is the common variance itself.
  constrained.R <- function(param) {
    if (TraitCov) {
      L <- .ou.build.chol(param[c("sigma", "R.offd")], p)
      L %*% t(L)
    } else {
      diag(param[["sigma"]], p)
    }
  }
  # Negative log-likelihood of the constrained model. Both the rates and the
  # alphas come from `param`.
  constrained.nll <- function(param) {
    V <- .ou.block.vcv(constrained.R(param), unname(param[alpha.names]), C)
    LLik <- .ou.mvn.loglik(resid.all, with.me(V))
    # L-BFGS-B needs finite values. Treat degenerate covariances as very
    # unlikely instead of passing +Inf/NaN to the optimizer.
    if (!is.finite(LLik)) {
      LLik <- -1e10
    }
    -LLik
  }

  sigma.mn.ou <- mean(diag(R.obs.ou))
  sigma.upper <- 2 * max(apply(x, 2, sd))
  p0 <- c(setNames(alpha.array, alpha.names), sigma = sigma.mn.ou)
  lower <- c(rep(1e-5, p), 1e-5)
  upper <- c(rep(2, p), sigma.upper)
  if (TraitCov) {
    # Here "sigma" is on the log-sd scale, so the variance start value and
    # bounds are mapped through log(v) / 2. The off-diagonal factor entry may
    # be negative.
    p0["sigma"] <- log(sigma.mn.ou) / 2
    lower[p + 1] <- log(1e-5) / 2
    upper[p + 1] <- log(sigma.upper) / 2
    p0 <- c(p0, R.offd = 0)
    lower <- c(lower, -sigma.upper)
    upper <- c(upper, sigma.upper)
  }
  p0 <- pmin(pmax(p0, lower), upper)
  fit <- optim(p0, fn = constrained.nll, method = "L-BFGS-B",
               lower = lower, upper = upper)

  R.constr.ou <- constrained.R(fit$par)
  message.ou <- if (fit$convergence == 0) {
    "Optimization has converged."
  } else {
    "Optim may not have converged.\n  Consider changing start value or lower/upper limits."
  }

  Lconstrained.ou <- -fit$value
  LRT.ou <- -2 * (Lconstrained.ou - LLik.obs.ou)
  LRT.prob.ou <- pchisq(LRT.ou, df = p - 1, lower.tail = FALSE)
  # Parameter counts are kept as originally written.
  # - Observed model: 2p each for rates, ancestral states and alpha ("h").
  # - Constrained model: 2 for the common rate, plus 2p + 2p.
  # NOTE(review): the constrained count does not include R.offd.
  AIC.obs.ou <- -2 * LLik.obs.ou + 2 * p + 2 * p + 2 * p
  AIC.common.ou <- -2 * Lconstrained.ou + 2 + 2 * p + 2 * p

  list(
    Robs.ou = R.obs.ou,
    Rconstrained.ou = R.constr.ou,
    Lobs.ou = LLik.obs.ou,
    Lconstrained.ou = Lconstrained.ou,
    LRTest.ou = LRT.ou,
    Prob.ou = LRT.prob.ou,
    AICc.obs.ou = AIC.obs.ou,
    AICc.constrained.ou = AIC.common.ou,
    RkronSa = RkronSa,
    optimmessage.ou = message.ou
  )
}

### Sample code
# Seeded so that the simulated tree and traits are reproducible. The output
# recorded in the comments below came from an earlier, unseeded run, so the
# exact values will differ.
set.seed(1)
phy <- rcoal(5)
plot(phy)

x <- matrix(c(rnorm(5, 2, 1), rnorm(5, 0, 0.5), rnorm(5, 1, 1.5)), ncol = 3)
rownames(x) <- phy$tip.label # or LETTERS[1:5]
x
##         [,1]        [,2]       [,3]
## t5 1.4262820 -0.63781404 -0.5401856
## t1 3.8073792 -0.13315301 -0.2331614
## t3 2.9802492 -0.01734169  1.3144293
## t2 3.8860375  0.10674253  1.5774422
## t4 0.9339436  0.98535101  1.5090790
CompareRates.multTraitOU(phy = phy, x = x, TraitCov = TRUE,
                         ms.err = NULL, ms.cov = NULL)
## Warning in fitContinuous(phy = phy, dat = x[, Index], model = "OU"): 
## Parameter estimates appear at bounds:
##  alpha

## Warning in fitContinuous(phy = phy, dat = x[, Index], model = "OU"): 
## Parameter estimates appear at bounds:
##  alpha
## $Robs.ou
##            [,1]      [,2]     [,3]
## [1,] 13.6851064 0.2933883 2.622895
## [2,]  0.2933883 1.5444703 1.840147
## [3,]  2.6228954 1.8401472 2.735281
## 
## $Rconstrained.ou
##              [,1]         [,2]         [,3]
## [1,] 2.313735e+02 1.521097e-04 1.521097e-04
## [2,] 1.521097e-04 2.313735e+02 1.521098e-04
## [3,] 1.521097e-04 1.521098e-04 2.313735e+02
## 
## $Lobs.ou
## [1] -34.08428
## 
## $Lconstrained.ou
## [1] -17.23016
## 
## $LRTest.ou
## [1] -33.70823
## 
## $Prob.ou
## [1] 1
## 
## $AICc.obs.ou
## [1] 86.16855
## 
## $AICc.constrained.ou
## [1] 48.46032
## 
## $RkronSa
##            t5         t1         t3          t2          t4         t5
## t5 2.49367909 1.49542021 0.49382957 0.000000000 0.000000000 0.05346077
## t1 1.49542021 2.49367909 0.49382957 0.000000000 0.000000000 0.03205958
## t3 0.49382957 0.49382957 2.49367909 0.000000000 0.000000000 0.01058697
## t2 0.00000000 0.00000000 0.00000000 2.493679093 0.394183940 0.00000000
## t4 0.00000000 0.00000000 0.00000000 0.394183940 2.493679093 0.00000000
## t5 0.05346077 0.03205958 0.01058697 0.000000000 0.000000000 0.28143101
## t1 0.03205958 0.05346077 0.01058697 0.000000000 0.000000000 0.16876976
## t3 0.01058697 0.01058697 0.05346077 0.000000000 0.000000000 0.05573249
## t2 0.00000000 0.00000000 0.00000000 0.053460765 0.008450716 0.00000000
## t4 0.00000000 0.00000000 0.00000000 0.008450716 0.053460765 0.00000000
## t5 0.57214638 0.37286757 0.14599796 0.000000000 0.000000000 0.40140127
## t1 0.37286757 0.57214638 0.14599796 0.000000000 0.000000000 0.26159305
## t3 0.14599796 0.14599796 0.57214638 0.000000000 0.000000000 0.10242792
## t2 0.00000000 0.00000000 0.00000000 0.572146380 0.120364392 0.00000000
## t4 0.00000000 0.00000000 0.00000000 0.120364392 0.572146380 0.00000000
##            t1         t3          t2          t4        t5        t1        t3
## t5 0.03205958 0.01058697 0.000000000 0.000000000 0.5721464 0.3728676 0.1459980
## t1 0.05346077 0.01058697 0.000000000 0.000000000 0.3728676 0.5721464 0.1459980
## t3 0.01058697 0.05346077 0.000000000 0.000000000 0.1459980 0.1459980 0.5721464
## t2 0.00000000 0.00000000 0.053460765 0.008450716 0.0000000 0.0000000 0.0000000
## t4 0.00000000 0.00000000 0.008450716 0.053460765 0.0000000 0.0000000 0.0000000
## t5 0.16876976 0.05573249 0.000000000 0.000000000 0.4014013 0.2615930 0.1024279
## t1 0.28143101 0.05573249 0.000000000 0.000000000 0.2615930 0.4014013 0.1024279
## t3 0.05573249 0.28143101 0.000000000 0.000000000 0.1024279 0.1024279 0.4014013
## t2 0.00000000 0.00000000 0.281431010 0.044486712 0.0000000 0.0000000 0.0000000
## t4 0.00000000 0.00000000 0.044486712 0.281431010 0.0000000 0.0000000 0.0000000
## t5 0.26159305 0.10242792 0.000000000 0.000000000 0.7361223 0.5195283 0.2391263
## t1 0.40140127 0.10242792 0.000000000 0.000000000 0.5195283 0.7361223 0.2391263
## t3 0.10242792 0.40140127 0.000000000 0.000000000 0.2391263 0.2391263 0.7361223
## t2 0.00000000 0.00000000 0.401401271 0.084444159 0.0000000 0.0000000 0.0000000
## t4 0.00000000 0.00000000 0.084444159 0.401401271 0.0000000 0.0000000 0.0000000
##            t2         t4
## t5 0.00000000 0.00000000
## t1 0.00000000 0.00000000
## t3 0.00000000 0.00000000
## t2 0.57214638 0.12036439
## t4 0.12036439 0.57214638
## t5 0.00000000 0.00000000
## t1 0.00000000 0.00000000
## t3 0.00000000 0.00000000
## t2 0.40140127 0.08444416
## t4 0.08444416 0.40140127
## t5 0.00000000 0.00000000
## t1 0.00000000 0.00000000
## t3 0.00000000 0.00000000
## t2 0.73612231 0.20324254
## t4 0.20324254 0.73612231
## 
## $optimmessage.ou
## [1] "Optimization has converged."
#####

# phy<-read.tree("http://tonyjhwueng.info/phymvrates/ple.nwk")
# plot(phy)
# phy$tip.label
# # #
# df<-read.csv("http://tonyjhwueng.info/phymvrates/Adams2012-SystBiolData.csv")
# head(df)
# # #
# spX<-strsplit(as.character(df$X),"_")
# spname<-array(NA,length(phy$tip.label))
# for(Index in 1:length(phy$tip.label)){
#   spname[Index]<-paste("Plethodon_", spX[[Index]][2],sep="")
# }
# spname
# df$X<-spname
#
# HeadLength<-df$HeadLength
# names(HeadLength)<-spname
# BodyWidth<-df$BodyWidth
# names(BodyWidth)<-spname
#
# HeadLength<-HeadLength[phy$tip.label]
# BodyWidth<-BodyWidth[phy$tip.label]
# x<-cbind(HeadLength,BodyWidth)
# CompareRates.multTraitOU(phy=phy,x=x,TraitCov=TRUE,ms.err=NULL,ms.cov=NULL)