# rm(list=ls())
library(ape)
library(MASS)
library(phytools)
## Loading required package: maps
# Internal helper: p x p lower-triangular factor L whose product L %*% t(L)
# has identical diagonal entries (one common rate, free covariances).
# b[1] is log(L[p, p]); b[-1] fills the strictly lower triangle of L in
# column-major order (p * (p - 1) / 2 values).
.pmm_build_chol <- function(b, p) {
  L <- matrix(0, nrow = p, ncol = p)
  L[lower.tri(L)] <- b[-1]
  L[p, p] <- exp(b[1])
  # Give every row the norm of row p so diag(L %*% t(L)) is constant.
  L[1, 1] <- sqrt(sum(L[p, ]^2))
  if (p > 2) {
    for (i in 2:(p - 1)) {
      gap <- L[1, 1]^2 - sum(L[i, ]^2)
      L[i, i] <- if (gap > 0) sqrt(gap) else 0
    }
  }
  L
}

# Internal helper: PMM covariance among taxa for heritability h,
# h^2 * C + (1 - h^2) * I, where C is the phylogenetic vcv matrix.
.pmm_sh <- function(h, C) {
  h^2 * C + (1 - h^2) * diag(1, nrow(C))
}

# Internal helper: (N*p) x (N*p) covariance of vec(x). Block (i, j) is
# R[i, j] * Sh((h[i] + h[j]) / 2); diagonal blocks are R[i, i] * Sh(h[i]).
.pmm_rkron_sh <- function(R, h, C) {
  p <- length(h)
  block.rows <- lapply(seq_len(p), function(i) {
    do.call(cbind, lapply(seq_len(p), function(j) {
      R[i, j] * .pmm_sh((h[i] + h[j]) / 2, C)
    }))
  })
  do.call(rbind, block.rows)
}

# Internal helper: within-species error covariance for vec(x), or NULL when
# ms.err is not a matrix. ms.err is N x p (error variances per trait);
# ms.cov, when a matrix, is N x p*(p-1)/2 error covariances ordered
# (1,2), (1,3), ..., (1,p), (2,3), ..., (p-1,p).
.pmm_me_matrix <- function(ms.err, ms.cov, p) {
  if (!is.matrix(ms.err)) {
    return(NULL)
  }
  N <- nrow(ms.err)
  rows.of <- function(k) ((k - 1) * N + 1):(k * N)
  me <- matrix(0, N * p, N * p)
  for (i in seq_len(p)) {
    me[rows.of(i), rows.of(i)] <- diag(ms.err[, i], N)
  }
  if (is.matrix(ms.cov)) {
    k <- 0
    for (i in seq_len(p - 1)) {
      for (j in (i + 1):p) {
        k <- k + 1
        cov.ij <- diag(ms.cov[, k], N)
        me[rows.of(i), rows.of(j)] <- cov.ij
        me[rows.of(j), rows.of(i)] <- cov.ij
      }
    }
  }
  me
}

# Internal helper: Gaussian log-likelihood of column vector y with mean mu
# and covariance V. ginv() (MASS) is used so a singular V does not abort.
.pmm_loglik <- function(y, mu, V) {
  resid <- y - mu
  quad <- t(resid) %*% ginv(V) %*% resid
  as.numeric(-quad / 2 - length(y) * log(2 * pi) / 2 -
               determinant(V)$modulus[1] / 2)
}

# Compare evolutionary rates of several traits under a phylogenetic mixed
# model (PMM): likelihood of per-trait rates (observed model) versus one
# common rate (constrained model), with per-trait heritabilities h.
#
# Arguments:
#   phy      - phylogeny; vcv.phylo(phy) must have rows for rownames(x).
#   x        - N x p numeric matrix (p >= 2 traits), rownames = tip labels.
#   TraitCov - TRUE: trait covariances are estimated in both models;
#              FALSE: both rate matrices are diagonal.
#   ms.err   - optional N x p matrix of within-species error variances,
#              rownames = tip labels.
#   ms.cov   - optional N x p*(p-1)/2 matrix of within-species error
#              covariances ordered (1,2), (1,3), ..., (p-1,p); only used
#              together with ms.err.
#
# Returns a list with the observed and constrained rate matrices and
# log-likelihoods, the LRT statistic and its chi-square (df = p - 1)
# p-value, information criteria for both models (named AICc.* but computed
# as plain AIC), the observed-model covariance of vec(x) (RkronSh) and an
# optimizer status message.
CompareRates.multTraitPMM <- function(phy, x, TraitCov = TRUE, ms.err = NULL,
                                      ms.cov = NULL) {
  x <- as.matrix(x)
  N <- nrow(x)
  p <- ncol(x)
  if (p < 2) {
    stop("x must have at least two trait columns.", call. = FALSE)
  }
  if (is.null(rownames(x))) {
    stop("x must have rownames matching the tip labels of phy.", call. = FALSE)
  }
  TraitCov <- isTRUE(as.logical(TraitCov))

  C <- vcv.phylo(phy)
  C <- C[rownames(x), rownames(x)]
  one <- matrix(1, N, 1)

  if (is.matrix(ms.err)) {
    ms.err <- as.matrix(ms.err[rownames(x), , drop = FALSE])
  }
  if (is.matrix(ms.cov)) {
    ms.cov <- as.matrix(ms.cov[rownames(x), , drop = FALSE])
  }

  # Per-trait PMM fits: h from phylosig(method = "lambda"), GLS mean and
  # variance under Sh, plus Sh^(-1/2)-whitened traits whose pairwise
  # covariances fill the off-diagonal of R.obs.pmm.
  h.obs <- numeric(p)
  a.obs.pmm <- numeric(p)
  R.obs.pmm <- matrix(NA_real_, p, p)
  IIDtrait <- matrix(NA_real_, N, p)
  for (k in seq_len(p)) {
    xk <- x[, k]
    h.obs[k] <- phylosig(tree = phy, x = xk, method = "lambda")$lambda
    Sh <- .pmm_sh(h.obs[k], C)
    Sh.inv <- solve(Sh)
    a.obs.pmm[k] <- sum(colSums(Sh.inv) * xk) / sum(Sh.inv)
    resid <- xk - a.obs.pmm[k]
    R.obs.pmm[k, k] <- drop(t(resid) %*% Sh.inv %*% resid) / N
    eig <- eigen(Sh)
    Sh.inv.sqrt <- eig$vectors %*% diag(1 / sqrt(eig$values), N) %*%
      t(eig$vectors)
    IIDtrait[, k] <- Sh.inv.sqrt %*% xk
  }
  for (i in seq_len(p - 1)) {
    for (j in (i + 1):p) {
      R.obs.pmm[i, j] <- R.obs.pmm[j, i] <- cov(IIDtrait[, i], IIDtrait[, j])
    }
  }
  if (!TraitCov) {
    R.obs.pmm <- diag(diag(R.obs.pmm), p)
  }

  # Stack traits column-wise: y = vec(x), mean = D %*% a (D = I_p (x) 1_N).
  a.obs.pmm <- matrix(a.obs.pmm, nrow = 1, ncol = p)
  D <- kronecker(diag(p), one)
  y <- as.matrix(as.vector(x))
  mu <- D %*% t(a.obs.pmm)

  # The same error matrix is used by both models so the LRT compares like
  # with like.
  m.e <- .pmm_me_matrix(ms.err, ms.cov, p)
  add.me <- function(V) if (is.null(m.e)) V else V + m.e

  RkronSh <- .pmm_rkron_sh(R.obs.pmm, h.obs, C)
  LLik.obs.pmm <- .pmm_loglik(y, mu, add.me(RkronSh))

  # Negative log-likelihood of a candidate (R, h); non-finite values are
  # clamped so L-BFGS-B always receives a finite number.
  neg.loglik <- function(R, h) {
    LLik <- .pmm_loglik(y, mu, add.me(.pmm_rkron_sh(R, h, C)))
    if (is.na(LLik) || LLik == -Inf) {
      LLik <- -1e+10
    } else if (LLik == Inf) {
      LLik <- 1e+10
    }
    -LLik
  }
  h.names <- paste0("h", seq_len(p))
  offd.names <- paste0("R.offd", seq_len(p * (p - 1) / 2))

  # Common rate, no trait covariance: R = sigma * I.
  lik.covF.pmm <- function(param) {
    neg.loglik(diag(param[["sigma"]], p), unname(param[h.names]))
  }
  # Common rate with free covariances via the constrained Cholesky factor.
  lik.covT.pmm <- function(param) {
    L <- .pmm_build_chol(c(param[["sigma"]], param[offd.names]), p)
    neg.loglik(L %*% t(L), unname(param[h.names]))
  }

  sigma.mn.pmm <- mean(diag(R.obs.pmm))
  sigma.lower <- 1e-8  # keeps the covariance away from exact singularity
  # Upper bound on the common variance, widened if needed so the start
  # value sigma.mn.pmm lies inside the box.
  sigma.upper <- max(2 * max(apply(x, 2, sd)), 2 * max(diag(R.obs.pmm)),
                     2 * sigma.lower)
  sigma.start <- max(sigma.mn.pmm, sigma.lower)
  h.start <- rep(0.5, p)

  if (!TraitCov) {
    p0 <- setNames(c(h.start, sigma.start), c(h.names, "sigma"))
    model.pmm <- optim(p0, fn = lik.covF.pmm, method = "L-BFGS-B",
                       lower = c(rep(0, p), sigma.lower),
                       upper = c(rep(1, p), sigma.upper))
    R.constr.pmm <- diag(model.pmm$par[["sigma"]], p)
  } else {
    # "sigma" is log(L[p, p]) of the Cholesky factor, so it is started and
    # bounded on the log-sd scale; off-diagonal factors may be negative.
    sd.upper <- sqrt(sigma.upper)
    n.offd <- length(offd.names)
    p0 <- setNames(c(h.start, log(sqrt(sigma.start)), rep(0, n.offd)),
                   c(h.names, "sigma", offd.names))
    model.pmm <- optim(p0, fn = lik.covT.pmm, method = "L-BFGS-B",
                       lower = c(rep(0, p), log(sqrt(sigma.lower)),
                                 rep(-sd.upper, n.offd)),
                       upper = c(rep(1, p), log(sd.upper),
                                 rep(sd.upper, n.offd)))
    L <- .pmm_build_chol(model.pmm$par[c("sigma", offd.names)], p)
    R.constr.pmm <- L %*% t(L)
  }

  message.pmm <- if (model.pmm$convergence == 0) {
    "Optimization has converged."
  } else {
    "Optim may not have converged. Consider changing start values or lower/upper limits."
  }

  Lconstr.pmm <- -model.pmm$value
  LRT.pmm <- -2 * (Lconstr.pmm - LLik.obs.pmm)
  LRT.prob.pmm <- pchisq(LRT.pmm, p - 1, lower.tail = FALSE)

  # Parameter counts as in the original: observed = p rates + p ancestral
  # states + p h; constrained = 1 rate + p ancestral states + p h.
  AIC.obs.pmm <- -2 * LLik.obs.pmm + 2 * p + 2 * p + 2 * p
  AIC.common.pmm <- -2 * Lconstr.pmm + 2 + 2 * p + 2 * p

  list(
    Robs.pmm = R.obs.pmm,
    Rconstrained.pmm = R.constr.pmm,
    Lobs.pmm = LLik.obs.pmm,
    Lconstrained.pmm = Lconstr.pmm,
    LRTest.pmm = LRT.pmm,
    Prob.pmm = LRT.prob.pmm,
    AICc.obs.pmm = AIC.obs.pmm,
    AICc.constrained.pmm = AIC.common.pmm,
    RkronSh = RkronSh,
    optimmessage.pmm = message.pmm
  )
}



### Sample code
 # Example input: a random 5-taxon coalescent tree and three traits drawn
 # with different means and spreads. Rows are labelled with the tip labels,
 # because CompareRates.multTraitPMM aligns taxa through rownames(x).
 phy <- rcoal(5)
 plot(phy)

 x <- cbind(rnorm(5, 2, 1), rnorm(5, 0, 0.5), rnorm(5, 1, 1.5))
 rownames(x) <- phy$tip.label
 x
##        [,1]        [,2]       [,3]
## t5 2.222208  0.06964572  0.7181074
## t2 2.118412  0.05805087  1.1441435
## t4 1.147705 -0.79305967  1.0351506
## t1 1.729599 -0.86020561 -0.4626294
## t3 1.599140  0.29699982  1.5556282
 # Fit both models with trait covariances and no measurement error.
 CompareRates.multTraitPMM(phy = phy, x = x, TraitCov = TRUE,
                           ms.err = NULL, ms.cov = NULL)
## $Robs.pmm
##             [,1]      [,2]        [,3]
## [1,]  0.14874846 0.1302880 -0.03540509
## [2,]  0.13028795 0.2326901  0.28399237
## [3,] -0.03540509 0.2839924  0.46912512
## 
## $Rconstrained.pmm
##          [,1]     [,2]     [,3]
## [1,] 1.763045 0.000000 0.000000
## [2,] 0.000000 1.763045 0.000000
## [3,] 0.000000 0.000000 1.763045
## 
## $Lobs.pmm
## [1] -5.100437
## 
## $Lconstrained.pmm
## [1] -5.100437
## 
## $LRTest.pmm
## [1] -1.754832e-07
## 
## $Prob.pmm
## [1] 1
## 
## $AICc.obs.pmm
## [1] 28.20087
## 
## $AICc.constrained.pmm
## [1] 24.20087
## 
## $RkronSh
##             t5            t2            t4            t1            t3
## t5  0.14874846  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## t2  0.00000000  1.487485e-01  9.775354e-10  2.371120e-10  2.371120e-10
## t4  0.00000000  9.775354e-10  1.487485e-01  2.371120e-10  2.371120e-10
## t1  0.00000000  2.371120e-10  2.371120e-10  1.487485e-01  1.080705e-09
## t3  0.00000000  2.371120e-10  2.371120e-10  1.080705e-09  1.487485e-01
## t5  0.13028796  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## t2  0.00000000  1.302880e-01  8.562179e-10  2.076851e-10  2.076851e-10
## t4  0.00000000  8.562179e-10  1.302880e-01  2.076851e-10  2.076851e-10
## t1  0.00000000  2.076851e-10  2.076851e-10  1.302880e-01  9.465836e-10
## t3  0.00000000  2.076851e-10  2.076851e-10  9.465836e-10  1.302880e-01
## t5 -0.03540509  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## t2  0.00000000 -3.540509e-02 -2.326728e-10 -5.643735e-11 -5.643735e-11
## t4  0.00000000 -2.326728e-10 -3.540509e-02 -5.643735e-11 -5.643735e-11
## t1  0.00000000 -5.643735e-11 -5.643735e-11 -3.540509e-02 -2.572293e-10
## t3  0.00000000 -5.643735e-11 -5.643735e-11 -2.572293e-10 -3.540509e-02
##           t5           t2           t4           t1           t3          t5
## t5 0.1302880 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 -0.03540509
## t2 0.0000000 1.302880e-01 8.562179e-10 2.076851e-10 2.076851e-10  0.00000000
## t4 0.0000000 8.562179e-10 1.302880e-01 2.076851e-10 2.076851e-10  0.00000000
## t1 0.0000000 2.076851e-10 2.076851e-10 1.302880e-01 9.465836e-10  0.00000000
## t3 0.0000000 2.076851e-10 2.076851e-10 9.465836e-10 1.302880e-01  0.00000000
## t5 0.2326901 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00  0.28399237
## t2 0.0000000 2.326901e-01 1.529177e-09 3.709188e-10 3.709188e-10  0.00000000
## t4 0.0000000 1.529177e-09 2.326901e-01 3.709188e-10 3.709188e-10  0.00000000
## t1 0.0000000 3.709188e-10 3.709188e-10 2.326901e-01 1.690568e-09  0.00000000
## t3 0.0000000 3.709188e-10 3.709188e-10 1.690568e-09 2.326901e-01  0.00000000
## t5 0.2839924 0.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00  0.46912512
## t2 0.0000000 2.839924e-01 1.866322e-09 4.526970e-10 4.526970e-10  0.00000000
## t4 0.0000000 1.866322e-09 2.839924e-01 4.526970e-10 4.526970e-10  0.00000000
## t1 0.0000000 4.526970e-10 4.526970e-10 2.839924e-01 2.063295e-09  0.00000000
## t3 0.0000000 4.526970e-10 4.526970e-10 2.063295e-09 2.839924e-01  0.00000000
##               t2            t4            t1            t3
## t5  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## t2 -3.540509e-02 -2.326728e-10 -5.643735e-11 -5.643735e-11
## t4 -2.326728e-10 -3.540509e-02 -5.643735e-11 -5.643735e-11
## t1 -5.643735e-11 -5.643735e-11 -3.540509e-02 -2.572293e-10
## t3 -5.643735e-11 -5.643735e-11 -2.572293e-10 -3.540509e-02
## t5  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## t2  2.839924e-01  1.866322e-09  4.526970e-10  4.526970e-10
## t4  1.866322e-09  2.839924e-01  4.526970e-10  4.526970e-10
## t1  4.526970e-10  4.526970e-10  2.839924e-01  2.063295e-09
## t3  4.526970e-10  4.526970e-10  2.063295e-09  2.839924e-01
## t5  0.000000e+00  0.000000e+00  0.000000e+00  0.000000e+00
## t2  4.691251e-01  3.082966e-09  7.478072e-10  7.478072e-10
## t4  3.082966e-09  4.691251e-01  7.478072e-10  7.478072e-10
## t1  7.478072e-10  7.478072e-10  4.691251e-01  3.408344e-09
## t3  7.478072e-10  7.478072e-10  3.408344e-09  4.691251e-01
## 
## $optimmessage.pmm
## [1] "Optimization has converged."
#####

# phy<-read.tree("http://tonyjhwueng.info/phymvrates/ple.nwk")
# plot(phy)
# phy$tip.label
# # #
# df<-read.csv("http://tonyjhwueng.info/phymvrates/Adams2012-SystBiolData.csv")
# head(df)
# spX<-strsplit(as.character(df$X),"_")
# spname<-array(NA,length(phy$tip.label))
# for(Index in 1:length(phy$tip.label)){
#   spname[Index]<-paste("Plethodon_", spX[[Index]][2],sep="")
# }
# spname
# df$X<-spname
#
# HeadLength<-df$HeadLength
# names(HeadLength)<-spname
# BodyWidth<-df$BodyWidth
# names(BodyWidth)<-spname
#
# HeadLength<-HeadLength[phy$tip.label]
# BodyWidth<-BodyWidth[phy$tip.label]
# x<-cbind(HeadLength,BodyWidth)
# CompareRates.multTraitPMM(phy=phy,x=x,TraitCov=TRUE,ms.err=NULL,ms.cov=NULL)