rm(list=ls())
library(phangorn)
## Loading required package: ape
library(ape)
library(sde)
## Loading required package: MASS
## Loading required package: stats4
## Loading required package: fda
## Loading required package: splines
## Loading required package: Matrix
## 
## Attaching package: 'fda'
## The following object is masked from 'package:graphics':
## 
##     matplot
## Loading required package: zoo
## 
## Attaching package: 'zoo'
## The following objects are masked from 'package:base':
## 
##     as.Date, as.Date.numeric
## sde 2.0.15
## Companion package to the book
## 'Simulation and Inference for Stochastic Differential Equations With R Examples'
## Iacus, Springer NY, (2008)
## To check the errata corrige of the book, type vignette("sde.errata")
library(phytools)
## Loading required package: maps
library(phyclust)

# 2-D Brownian motion demo ---------------------------------------------------
# Each axis is a Gaussian random walk with unit-variance steps, started at the
# origin. On every step the X increment is drawn before the Y increment.
size <- 5000
X.bm <- array(NA, size)
Y.bm <- array(NA, size)
X.bm[1] <- 0
Y.bm[1] <- 0
for (i in 2:size) {
  X.bm[i] <- X.bm[i - 1] + rnorm(1)
  Y.bm[i] <- Y.bm[i - 1] + rnorm(1)
}
plot(Y.bm ~ X.bm, type = "l", main = "BM 2D")
# Start of the walk (blue triangle), then its end (red circle).
points(c(X.bm[1], X.bm[size]), c(Y.bm[1], Y.bm[size]),
       col = c("blue", "red"), pch = c(17, 16), cex = 2)

# 2-D discrete-time Ornstein-Uhlenbeck demo ----------------------------------
# Per axis: a.* is the fraction of the gap to the optimum m.* closed on each
# step, and s.* scales the Gaussian noise added on each step. r.* are the
# rate exponents used by the EB demo; this section does not read them.
a.x <- 0.12; m.x <- 0.5;  s.x <- 0.1;  r.x <- 0.1
a.y <- 0.16; m.y <- 0.12; s.y <- 0.15; r.y <- 0.05


X.ou <- array(NA, size)
Y.ou <- array(NA, size)
X.ou[1] <- 0
Y.ou[1] <- 0
for (i in 2:size) {
  # X is drawn before Y on every step.
  X.ou[i] <- X.ou[i - 1] + a.x * (m.x - X.ou[i - 1]) + s.x * rnorm(1)
  Y.ou[i] <- Y.ou[i - 1] + a.y * (m.y - Y.ou[i - 1]) + s.y * rnorm(1)
}

# Optional per-axis views (disabled):
# plot(X.ou, type = "l")
# points(1, X.ou[1], col = "red", pch = 17, cex = 2)
# points(size, X.ou[size], col = "blue", pch = 16, cex = 2)
#
# plot(Y.ou, type = "l")
# points(1, Y.ou[1], col = "red", pch = 17, cex = 2)
# points(size, Y.ou[size], col = "blue", pch = 16, cex = 2)


plot(Y.ou ~ X.ou, type = "l", main = "OU 2D")
# Start of the walk (blue triangle), then its end (red circle).
points(c(X.ou[1], X.ou[size]), c(Y.ou[1], Y.ou[size]),
       col = c("blue", "red"), pch = c(17, 16), cex = 2)

# 2-D "EB" demo --------------------------------------------------------------
# Each step adds s.* * exp(r.*) * N(0, 1) to the previous position.
# NOTE(review): exp(r.*) is the same on every step, so this is a Brownian walk
# with a rescaled step size rather than a rate that changes over time --
# confirm that is the intended model.
X.eb <- array(NA, size)
Y.eb <- array(NA, size)
X.eb[1] <- 0
Y.eb[1] <- 0
for (i in 2:size) {
  X.eb[i] <- X.eb[i - 1] + s.x * exp(r.x) * rnorm(1)
  Y.eb[i] <- Y.eb[i - 1] + s.y * exp(r.y) * rnorm(1)
}

plot(Y.eb ~ X.eb, type = "l", main = "EB 2D")
# Start of the walk (blue triangle), then its end (red circle).
points(c(X.eb[1], X.eb[size]), c(Y.eb[1], Y.eb[size]),
       col = c("blue", "red"), pch = c(17, 16), cex = 2)

# Side-by-side comparison of the three 2-D walks -----------------------------
par(mfrow = c(1, 3), mar = c(2, 2, 2, 2))
invisible(local({
  # Draw one walk, then mark its first point (blue triangle) and the point at
  # index `size` (red circle). Axis labels match the original variable names.
  draw_walk <- function(x, y, title, xlab, ylab) {
    plot(y ~ x, type = "l", main = title, xlab = xlab, ylab = ylab)
    points(x[1], y[1], col = "blue", pch = 17, cex = 2)
    points(x[size], y[size], col = "red", pch = 16, cex = 2)
  }
  draw_walk(X.bm, Y.bm, "BM 2D", "X.bm", "Y.bm")
  draw_walk(X.ou, Y.ou, "OU 2D", "X.ou", "Y.ou")
  draw_walk(X.eb, Y.eb, "EB 2D", "X.eb", "Y.eb")
}))

# Simulate one Brownian-motion path with an Euler scheme:
#   x[k + 1] = x[k] + sigma * dW[k],  dW[k] ~ N(0, T / N).
#
# model.params - numeric vector; only model.params[3] (sigma) is used here.
# T            - total time spanned by the path.
# N            - number of steps.
# x0           - starting value.
# The defaults (T = T, ...) refer to themselves and cannot be evaluated, so
# T, N and x0 must always be supplied explicitly.
# Returns a plain numeric vector of length N + 1 whose first element is x0.
# N = 0 (a zero-length branch) returns x0 alone.
sim.bm.one.path <- function(model.params, T = T, N = N, x0 = x0) {
  sigma <- model.params[3]
  dw <- rnorm(N, 0, sqrt(T / N))
  # Preallocate; seq_len() keeps the loop empty when N is 0.
  path <- numeric(N + 1)
  path[1] <- x0
  for (k in seq_len(N)) {
    path[k + 1] <- path[k] + sigma * dw[k]
  }
  path
}

# Simulate Brownian-motion trait paths along every edge of `phy`, starting from
# the value `root` at the root node, by calling sim.bm.one.path() per edge.
#
# model.params - passed unchanged to sim.bm.one.path().
# phy          - tree with $edge, $edge.length, $tip.label and $Nnode. Edges
#                are visited from the last row to the first, so the edge into a
#                node must appear after the edges leaving it (true for ape's
#                "postorder" ordering). The root node id is Ntips + 1.
# N            - accepted for interface compatibility but not used: each edge
#                is simulated with one step per unit of branch length.
# root         - trait value at the root node.
# Returns list(path.objects, sim.node.data): path.objects[[k]] is the path for
# edge nrow(phy$edge) - k + 1 (reverse edge order), and sim.node.data holds the
# simulated value at every node, indexed by node number.
sim.bm.tree.path <- function(model.params, phy = phy, N = N, root = root) {
  ntips <- length(phy$tip.label)
  n.edges <- nrow(phy$edge)
  anc <- phy$edge[, 1]
  des <- phy$edge[, 2]
  sim.node.data <- numeric(ntips + phy$Nnode)
  sim.node.data[ntips + 1] <- root
  path.objects <- vector("list", n.edges)
  for (slot in seq_len(n.edges)) {
    edgeIndex <- n.edges - slot + 1
    brnlen <- phy$edge.length[edgeIndex]
    path <- sim.bm.one.path(model.params, T = brnlen, N = brnlen,
                            x0 = sim.node.data[anc[edgeIndex]])
    # The end of this edge's path is the starting value for its child edges.
    sim.node.data[des[edgeIndex]] <- path[length(path)]
    path.objects[[slot]] <- path
  }
  list(path.objects = path.objects, sim.node.data = sim.node.data)
}


# Simulate one Ornstein-Uhlenbeck path with the Euler-Maruyama scheme
#   x[k + 1] = x[k] + alpha * (mu - x[k]) * dt + sigma * dW[k],  dt = T / N.
#
# model.params - numeric vector read as mu = [1], alpha = [2], sigma = [3].
# T            - total time spanned by the path.
# N            - number of steps.
# x0           - starting value.
# The defaults (T = T, ...) refer to themselves and cannot be evaluated, so
# T, N and x0 must always be supplied explicitly.
# Returns a ts of length N + 1 whose time axis starts at 0 with spacing dt.
# The time origin is 0 rather than x0, because x0 is a trait value, not a time.
# N < 1 (a zero-length branch) returns x0 alone as a ts starting at time 0.
sim.ou.one.path <- function(model.params, T = T, N = N, x0 = x0) {
  mu <- model.params[1]
  alpha <- model.params[2]
  sigma <- model.params[3]
  if (N < 1) {
    # No increments to draw; T / N would also be undefined here.
    return(ts(x0, start = 0))
  }
  dw <- rnorm(N, 0, sqrt(T / N))
  dt <- T / N
  path <- numeric(N + 1)
  path[1] <- x0
  for (k in seq_len(N)) {
    path[k + 1] <- path[k] + alpha * (mu - path[k]) * dt + sigma * dw[k]
  }
  ts(path, start = 0, deltat = dt)
}

# Simulate Ornstein-Uhlenbeck trait paths along every edge of `phy`, starting
# from the value `root` at the root node, by calling sim.ou.one.path() per edge.
#
# model.params - passed unchanged to sim.ou.one.path().
# phy          - tree with $edge, $edge.length, $tip.label and $Nnode. Edges
#                are visited from the last row to the first, so the edge into a
#                node must appear after the edges leaving it (true for ape's
#                "postorder" ordering). The root node id is Ntips + 1.
# N            - accepted for interface compatibility but not used: each edge
#                is simulated with one step per unit of branch length.
# root         - trait value at the root node.
# Returns list(path.objects, sim.node.data): path.objects[[k]] is the path for
# edge nrow(phy$edge) - k + 1 (reverse edge order), and sim.node.data holds the
# simulated value at every node, indexed by node number.
sim.ou.tree.path <- function(model.params, phy = phy, N = N, root = root) {
  ntips <- length(phy$tip.label)
  n.edges <- nrow(phy$edge)
  anc <- phy$edge[, 1]
  des <- phy$edge[, 2]
  sim.node.data <- numeric(ntips + phy$Nnode)
  sim.node.data[ntips + 1] <- root
  path.objects <- vector("list", n.edges)
  for (slot in seq_len(n.edges)) {
    edgeIndex <- n.edges - slot + 1
    brnlen <- phy$edge.length[edgeIndex]
    path <- sim.ou.one.path(model.params, T = brnlen, N = brnlen,
                            x0 = sim.node.data[anc[edgeIndex]])
    # The end of this edge's path is the starting value for its child edges.
    sim.node.data[des[edgeIndex]] <- path[length(path)]
    path.objects[[slot]] <- path
  }
  list(path.objects = path.objects, sim.node.data = sim.node.data)
}




# Simulate one "EB" path:
#   x[k + 1] = x[k] + sigma * exp(r) * dW[k],  dW[k] ~ N(0, T / N).
# NOTE(review): exp(r) does not depend on time, so this is Brownian motion with
# diffusion scale sigma * exp(r). Early-burst models usually let the rate
# change with elapsed time -- confirm this constant rate is intended.
#
# model.params - numeric vector read as sigma = [3], r = [4].
# T            - total time spanned by the path.
# N            - number of steps.
# x0           - starting value.
# The defaults (T = T, ...) refer to themselves and cannot be evaluated, so
# T, N and x0 must always be supplied explicitly.
# Returns a ts of length N + 1 whose time axis starts at 0 with spacing T / N.
# The time origin is 0 rather than x0, because x0 is a trait value, not a time.
# N < 1 (a zero-length branch) returns x0 alone as a ts starting at time 0.
sim.eb.one.path <- function(model.params, T = T, N = N, x0 = x0) {
  sigma <- model.params[3]
  r <- model.params[4]
  if (N < 1) {
    # No increments to draw; T / N would also be undefined here.
    return(ts(x0, start = 0))
  }
  dw <- rnorm(N, 0, sqrt(T / N))
  dt <- T / N
  # Same value as sigma * exp(r) evaluated inside the loop; computed once.
  rate <- sigma * exp(r)
  path <- numeric(N + 1)
  path[1] <- x0
  for (k in seq_len(N)) {
    path[k + 1] <- path[k] + rate * dw[k]
  }
  ts(path, start = 0, deltat = dt)
}

# Simulate "EB" trait paths along every edge of `phy`, starting from the value
# `root` at the root node, by calling sim.eb.one.path() per edge.
# (This previously called sim.ou.one.path(), so the "EB" tree histories were
# really OU histories.)
#
# model.params - passed unchanged to sim.eb.one.path().
# phy          - tree with $edge, $edge.length, $tip.label and $Nnode. Edges
#                are visited from the last row to the first, so the edge into a
#                node must appear after the edges leaving it (true for ape's
#                "postorder" ordering). The root node id is Ntips + 1.
# N            - accepted for interface compatibility but not used: each edge
#                is simulated with one step per unit of branch length.
# root         - trait value at the root node.
# Returns list(path.objects, sim.node.data): path.objects[[k]] is the path for
# edge nrow(phy$edge) - k + 1 (reverse edge order), and sim.node.data holds the
# simulated value at every node, indexed by node number.
sim.eb.tree.path <- function(model.params, phy = phy, N = N, root = root) {
  ntips <- length(phy$tip.label)
  n.edges <- nrow(phy$edge)
  anc <- phy$edge[, 1]
  des <- phy$edge[, 2]
  sim.node.data <- numeric(ntips + phy$Nnode)
  sim.node.data[ntips + 1] <- root
  path.objects <- vector("list", n.edges)
  for (slot in seq_len(n.edges)) {
    edgeIndex <- n.edges - slot + 1
    brnlen <- phy$edge.length[edgeIndex]
    path <- sim.eb.one.path(model.params, T = brnlen, N = brnlen,
                            x0 = sim.node.data[anc[edgeIndex]])
    # The end of this edge's path is the starting value for its child edges.
    sim.node.data[des[edgeIndex]] <- path[length(path)]
    path.objects[[slot]] <- path
  }
  list(path.objects = path.objects, sim.node.data = sim.node.data)
}



# Draw every simulated edge path of a tree trait history on one plot, with the
# trait value on the y axis and integer time steps on the x axis.
#
# phy       - tree; $tip.label and $edge are read directly, and Ntip(),
#             nodeheight() and get.rooted.tree.height() are applied to it.
# path.data - list shaped like the value returned by the sim.*.tree.path()
#             functions in this file. $path.objects holds one path per edge in
#             REVERSE edge order, which is why the edge table is reversed
#             below. unlist(path.data) (every path plus any other numeric
#             entries such as $sim.node.data) sets the y limits.
# main      - label drawn inside the plotting region with text().
# colors    - line colour for each entry of path.data$path.objects.
# Called for its drawing side effect on the current graphics device.
plot.history.dt<-function(phy=phy,path.data=path.data,main=main,colors=colors){
  #path.data<-bm.path.data
  #main="BM"
  # ntips is not used below.
  ntips<-length(phy$tip.label)
  edge.number<-dim(phy$edge)[1]
  # Per-edge bookkeeping table with columns anc, des, step (number of stored
  # path points) and x.start / x.end (x positions of the edge), filled below.
  x.start<-array(0,c(edge.number,1))
  x.end<-array(0,c(edge.number,1))
  step<-array(0,c(edge.number,1))
  anc.des.start.end<-cbind(phy$edge,step,x.start, x.end)
  colnames(anc.des.start.end)<-c("anc","des","step","x.start","x.end")
  # Reverse the rows so row k lines up with path.data$path.objects[[k]].
  # NOTE(review): with a single edge apply() returns a plain vector, not a
  # matrix, and the $anc/$des lookups below would then fail.
  anc.des.start.end<-apply(anc.des.start.end,2,rev)
  anc.des.start.end<-data.frame(anc.des.start.end)
  anc<- anc.des.start.end$anc
  des<- anc.des.start.end$des
  # step = number of points stored for each edge's path. unlist() turns the
  # one-element list (possibly holding a ts) into a plain vector.
  for(edgeIndex in 1:edge.number){
    path<-unlist(path.data$path.objects[edgeIndex])
    #print(length(path))
    anc.des.start.end$step[edgeIndex]<-length(path)
  }
  # Empty canvas: x spans 1 .. ceiling(tree height) + 1 time steps.
  plot( NA,type="l",xlim=c(1,ceiling(get.rooted.tree.height(phy))+1 ),ylim=c(min(unlist(path.data)), max(unlist(path.data))),ylab="Trait value", xlab="Time Steps")
  # Dashed horizontal reference line at y = 1e-8 (effectively zero).
  abline(a=1e-8,b=0,lty=2,lwd=2)
  # x positions of each edge from the node heights of its end points. Edges
  # leaving the root (node Ntip + 1) use round(), all others ceiling();
  # NOTE(review): the reason for the different rounding is not documented.
  for(edgeIndex in 1:edge.number){
    if(anc[edgeIndex]== Ntip(phy)+1){
      anc.des.start.end$x.start[edgeIndex]<-  round(nodeheight(phy,node=anc.des.start.end$anc[edgeIndex]))
    }else{
      anc.des.start.end$x.start[edgeIndex]<- ceiling(nodeheight(phy,node=anc.des.start.end$anc[edgeIndex]))
    }
    anc.des.start.end$x.end[edgeIndex]<-  ceiling(nodeheight(phy,node=anc.des.start.end$des[edgeIndex]))
  }
  
  # Bare expression; its value is discarded inside a function (debug leftover).
  anc.des.start.end
  # Running maximum over all paths, used to place the `main` label.
  max.path.value<- -Inf
  for(edgeIndex in 1:edge.number){
    
    path<-unlist(path.data$path.objects[edgeIndex])
    max.path.value<-max(max.path.value,max(path))
    #starting.x<-ceiling(nodeheight(phy,node=anc[edgeIndex]))
    x.start <-  anc.des.start.end$x.start[edgeIndex]
    x.end   <-  anc.des.start.end$x.end[edgeIndex]
    # Thinning stride: keep about one stored point per integer time step
    # between x.start and x.end.
    # NOTE(review): x.end == x.start divides by zero, and a stride that rounds
    # to 0 is rejected by seq(); very short edges would hit these cases.
    gap <- round(anc.des.start.end$step[edgeIndex]/( x.end-x.start))
    #points((starting.x+1): (starting.x+length(path)),path[seq(1, anc.des.start.end$step[edgeIndex], by = gap)],type="l")
    point.to.use<-(anc.des.start.end$x.start[edgeIndex]: anc.des.start.end$x.end[edgeIndex])
    sample.path<-path[seq(1,  anc.des.start.end$step[edgeIndex],by=gap)]
    #        print("before")
    #        print( c(length(point.to.use), length(sample.path) ) )
    # Rounding can leave the x positions and the thinned path with different
    # lengths. Only a one-element mismatch is repaired here, by repeating the
    # last y value or extending x by one step. A larger mismatch reaches
    # lines() with unequal x/y lengths, which is an error.
    if(length(point.to.use)!=length(sample.path)){
      if(length(point.to.use) > length(sample.path)){
        #point.to.use<-point.to.use[1:length(sample.path)]
        sample.path<-c(sample.path,sample.path[length(sample.path)])
      }else{
        #sample.path<-sample.path[1:length(point.to.use)]
        point.to.use<-c(point.to.use, point.to.use[length(point.to.use)]+1 )
      }
    }
    #        print("After")
    #        print( c(length(point.to.use), length(sample.path) ) )
    #print(c(length(point.to.use), length(sample.path)) )
    lines(point.to.use , sample.path, type="l",lwd=2,col=colors[edgeIndex])
    # TODO: the fixed-stride thinning above is approximate; a sampling scheme
    # that yields exactly one point per x position would avoid the padding.
    #path[seq(1, anc.des.start.end$step[edgeIndex], by = gap)  # we need to find a better way to sample this and get right number with x - axis
    #path[sample(    )    ]  LOOK UP SAMPLE FOR GOOD
  }
  #    abline(v=ceiling(get.rooted.tree.height(phy))+1)
  # Label at 10% of the (rounded-up) tree height and 90% of the largest value.
  text(x=0.1*ceiling(get.rooted.tree.height(phy)),y= 0.9*max.path.value,labels=main, lwd=2,lty=1,cex=2)
}#end of plot history


# Simulation settings ---------------------------------------------------------
# T is the total tree height in time steps after the rescaling below; N is
# forwarded to the sim.*.tree.path() functions. Assigning to T masks the T
# shorthand for TRUE, so logical arguments here are spelled TRUE/FALSE.
T <- N <- 1000

x0 <- 0
true.alpha <- 0.01
true.mu <- 1
true.sigma <- 0.1
true.r <- 0.5
# Position matters: the simulators read mu = [1], alpha = [2], sigma = [3] and
# r = [4].
model.params <- c(true.mu, true.alpha, true.sigma, true.r)
#plot(sim.cir.path(model.params, T=T, N=N,x0=0.1))
root <- 0
# we can do sims to different regimes
size <- 3
phy <- rcoal(size)
#phy<-rtree(size)
# The tree simulators visit edges from the last row to the first, so the edge
# into a node must come after the edges leaving it; postorder guarantees this.
phy <- reorder(phy, "postorder")
min.length <- 10
# NOTE(review): this loop scales every edge by a common factor until the
# shortest reaches min.length. The rescaling to height T just below divides
# that factor out again, so min.length does not constrain the final edge
# lengths. Kept unchanged; confirm what minimum branch length was intended.
while (min(phy$edge.length) < min.length) {
  phy$edge.length <- 1.005 * phy$edge.length
}

# Rescale to total height T and round, so every edge is a whole number of time
# steps (the tree simulators use the edge length as the step count).
phy$edge.length <- round(T / get.rooted.tree.height(phy) * phy$edge.length)
# The hand-written labels below (and the four colours) only fit a 3-tip tree.
stopifnot(length(phy$tip.label) == 3L)
phy$tip.label <- c("   y", "   x", "   z")

#cir.path.data<-sim.cir.tree.path(model.params,phy=phy,N=N,root=root)

#print(path.data)

#setwd("~/GitHub/BridgePCM")
#pdf("tree.path.plot.pdf")
# colors[k] is used for path.objects[[k]], i.e. edges in reverse order, which
# is why the tree below is drawn with rev(colors). One colour per edge.
colors <- c("black", "orange", "blue", "red")
stopifnot(length(colors) >= nrow(phy$edge))

# `op` keeps the previous graphics settings so they can be restored with par(op).
op <- par(mfrow = c(2, 2),
          oma = c(2, 2, 0, 0) + 0.1,
          mar = c(1, 1, 1, 1) + 0.1)

# Rebuild the tree for display: normalise the shared-path (vcv) matrix to a
# maximum of 1, turn it into pairwise tip distances and cluster with UPGMA.
# NOTE(review): pairing rev(colors) with phy1's edges assumes upgma() lists the
# edges in the same order as phy -- confirm.
vcv.phy <- vcv(phy)
C <- vcv.phy / max(vcv.phy)
D <- 2 * (max(C) - C)
phy1 <- upgma(D)

plot(phy1, edge.width = 5, cex = 1.5, edge.color = rev(colors))
#tiplabels(pch=21, col="black", adj=1, bg="black", cex=2)
axisPhylo(1, root.time = NULL, backward = FALSE)
# Simulate the x and y trait coordinates along the tree under each model ------
# Each x/y pair of calls uses identical arguments, so the two coordinates
# differ only through their random draws; together they form a 2-D history
# along every edge.
bm.path.data.x <- sim.bm.tree.path(model.params, phy = phy, N = N, root = root)
bm.path.data.y <- sim.bm.tree.path(model.params, phy = phy, N = N, root = root)

ou.path.data.x <- sim.ou.tree.path(model.params, phy = phy, N = N, root = root)
ou.path.data.y <- sim.ou.tree.path(model.params, phy = phy, N = N, root = root)

eb.path.data.x <- sim.eb.tree.path(model.params, phy = phy, N = N, root = root)
eb.path.data.y <- sim.eb.tree.path(model.params, phy = phy, N = N, root = root)


# Shared axis limits so the BM, OU and EB panels are drawn on the same scale.
xlims <- range(unlist(bm.path.data.x$path.objects),
               unlist(ou.path.data.x$path.objects),
               unlist(eb.path.data.x$path.objects))
ylims <- range(unlist(bm.path.data.y$path.objects),
               unlist(ou.path.data.y$path.objects),
               unlist(eb.path.data.y$path.objects))

# BM 2D tree plot: each edge's (x, y) history as a line, coloured per edge.
# The first edge opens the panel; the remaining edges are overlaid.
plot(bm.path.data.y$path.objects[[1]] ~ bm.path.data.x$path.objects[[1]],
     col = 1, xlim = xlims, ylim = ylims, type = "l",
     xlab = "", ylab = "", lwd = 1)
#points(bm.path.data.x$path.objects[[1]][1],bm.path.data.y$path.objects[[1]][1],col="black",pch=16,cex=2)
for (Index in seq_len(nrow(phy$edge))[-1]) {
  lines(bm.path.data.y$path.objects[[Index]] ~
          bm.path.data.x$path.objects[[Index]],
        col = colors[Index], lwd = 1)
}

# Panel label. The placement assumes xlims[1] <= 0 <= ylims[2], which holds here
# because the root value (0) is the first point of the root edges' paths.
text(0.8 * xlims[1], 0.8 * ylims[2], "BM 2D", cex = 2)

# OU 2D tree plot: each edge's (x, y) history as a line, coloured per edge.
# The first edge opens the panel; the remaining edges are overlaid.

plot(ou.path.data.y$path.objects[[1]] ~ ou.path.data.x$path.objects[[1]],
     col = 1, xlim = xlims, ylim = ylims, type = "l",
     xlab = "", ylab = "", lwd = 1)
#points(ou.path.data.x$path.objects[[1]][1],ou.path.data.y$path.objects[[1]][1],col="black",pch=16,cex=2)
for (Index in seq_len(nrow(phy$edge))[-1]) {
  lines(ou.path.data.y$path.objects[[Index]] ~
          ou.path.data.x$path.objects[[Index]],
        col = colors[Index], lwd = 1)
}

text(0.8 * xlims[1], 0.8 * ylims[2], "OU 2D", cex = 2)


# EB 2D tree plot: each edge's (x, y) history as a line, coloured per edge.
# The first edge opens the panel; the remaining edges are overlaid.

plot(eb.path.data.y$path.objects[[1]] ~ eb.path.data.x$path.objects[[1]],
     col = 1, xlim = xlims, ylim = ylims, type = "l",
     xlab = "", ylab = "", lwd = 1)
#points(eb.path.data.x$path.objects[[1]][1],eb.path.data.y$path.objects[[1]][1],col="black",pch=16,cex=2)
for (Index in seq_len(nrow(phy$edge))[-1]) {
  lines(eb.path.data.y$path.objects[[Index]] ~
          eb.path.data.x$path.objects[[Index]],
        col = colors[Index], lwd = 1)
}

text(0.8 * xlims[1], 0.8 * ylims[2], "EB 2D", cex = 2)