/* Steven Andrews, 1/28/97 - 2008 */
/* See documentation called dynsys2_doc */

#include "dynsys2.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* CHECK(A): if A evaluates to zero/NULL, jump to the label "failure" in the
calling function, which must define that label.  Used for allocation checks
that share a single cleanup path.  Wrapped in do/while(0) so that "CHECK(x);"
is exactly one statement and is safe inside an unbraced if/else. */
#define CHECK(A) do { if(!(A)) goto failure; } while(0)


/* dyns_AllocOde
Allocates an integrator that can hold up to maxdim state variables.
Defaults: order 1, no step pointer, dtsugg 0, dtmax 1, eps 0.001, no system
pointer and no equation of motion.  The state/derivative pointer arrays start
out NULL-filled and every scale factor is 1.  Work arrays k1..k4 are not
allocated here (see dyns_SetOrder).
Returns NULL if any allocation fails; partially built objects are released. */
odeptr dyns_AllocOde(int maxdim) {
	odeptr ode;
	int j;

	ode=(odeptr) malloc(sizeof(struct odestruct));
	if(ode==NULL) return NULL;

	ode->maxdim=maxdim;
	ode->dim=0;
	ode->order=1;
	ode->dtptr=NULL;
	ode->dtsugg=0;
	ode->dtmax=1;
	ode->eps=0.001;
	ode->systemptr=NULL;
	ode->eqm=NULL;
	/* every array pointer starts NULL so dyns_FreeOde is safe after any failure below */
	ode->statenow=ode->statewas=ode->deriv=NULL;
	ode->scale=NULL;
	ode->k1=ode->k2=ode->k3=ode->k4=NULL;

	if((ode->statenow=(double**) calloc(maxdim,sizeof(double*)))==NULL) goto failure;
	if((ode->statewas=(double**) calloc(maxdim,sizeof(double*)))==NULL) goto failure;
	if((ode->deriv=(double**) calloc(maxdim,sizeof(double*)))==NULL) goto failure;
	if((ode->scale=(double*) calloc(maxdim,sizeof(double)))==NULL) goto failure;

	for(j=0;j<maxdim;j++) {
		ode->statenow[j]=ode->statewas[j]=ode->deriv[j]=NULL;
		ode->scale[j]=1; }
	return ode;

 failure:
	dyns_FreeOde(ode);
	return NULL; }



/* dyns_FreeOde
Releases an integrator created by dyns_AllocOde, together with any work
arrays allocated by dyns_SetOrder.  The state variables that were registered
with dyns_AddStatePtr belong to the caller and are not freed.
Passing NULL does nothing. */
void dyns_FreeOde(odeptr ode) {
	if(ode) {
		free(ode->statenow);
		free(ode->statewas);
		free(ode->deriv);
		free(ode->scale);
		free(ode->k1);
		free(ode->k2);
		free(ode->k3);
		free(ode->k4);
		free(ode); }}



/* dyns_SetOrder
Selects the integration method used by dyns_StepOde: 1 (Euler), 2 (midpoint),
4 (classical Runge-Kutta) or 5 (Runge-Kutta with step doubling and adaptive
step size).  Zero-filled work arrays of length maxdim are allocated on demand:
k1 for order >= 2, k2 for order >= 4, k3 and k4 for order 5.  Existing work
arrays are kept (never freed here); dyns_FreeOde releases them.
Returns 0 on success, 1 if a work array could not be allocated (the order is
then left unchanged), or 2 if ode is NULL or order is not supported. */
int dyns_SetOrder(odeptr ode,int order) {
	static const int needorder[4]={2,4,5,5};		// lowest order that uses k1, k2, k3, k4
	double **work[4];
	int w,j;

	if(!ode) return 2;
	switch(order) {
		case 1: case 2: case 4: case 5: break;
		default: return 2; }

	work[0]=&ode->k1;
	work[1]=&ode->k2;
	work[2]=&ode->k3;
	work[3]=&ode->k4;
	for(w=0;w<4;w++) {
		if(order<needorder[w] || *work[w]) continue;
		*work[w]=(double*) calloc(ode->maxdim,sizeof(double));
		if(!*work[w]) return 1;
		for(j=0;j<ode->maxdim;j++) (*work[w])[j]=0; }

	ode->order=order;
	return 0; }



/* dyns_SetParamPtr
Sets a pointer-valued integrator parameter selected by name:
  "dtptr"     - address of the caller's time-step variable (double*); every
                step reads it, and order-5 steps overwrite it with the step
                actually taken
  "systemptr" - opaque pointer handed unchanged to the equation of motion
  "eqm"       - the equation-of-motion function, passed as void*
Returns 0 on success, or 2 if ode or param is NULL or param is not a
recognized name (2 is the same "bad argument" code dyns_SetOrder uses). */
int dyns_SetParamPtr(odeptr ode,char *param,void *value) {
	if(!ode || !param) return 2;
	if(!strcmp(param,"dtptr"))
		ode->dtptr=(double*) value;
	else if(!strcmp(param,"systemptr"))
		ode->systemptr=value;
	else if(!strcmp(param,"eqm"))
		ode->eqm=value;		// NOTE(review): if eqm is declared as a function pointer in dynsys2.h, void* -> function pointer is not sanctioned by ISO C (POSIX relies on it) — confirm
	else return 2;
	return 0; }



/* dyns_SetParamDbl
Sets a numeric integrator parameter selected by name:
  "dtsugg" - suggested step size for the next adaptive (order 5) step
  "dtmax"  - upper limit on the adaptive step size
  "eps"    - tolerance on the scaled one-step vs. two-half-step difference
             used for adaptive step-size control
Only strictly positive values are accepted.
Returns 0 on success, or 2 if ode or param is NULL, param is not a
recognized name, or value is not a positive number (zero, negative, NaN). */
int dyns_SetParamDbl(odeptr ode,char *param,double value) {
	if(!ode || !param) return 2;
	if(!(value>0)) return 2;		// written this way so NaN is rejected too (NaN<=0 is false)
	if(!strcmp(param,"dtsugg"))
		ode->dtsugg=value;
	else if(!strcmp(param,"dtmax"))
		ode->dtmax=value;
	else if(!strcmp(param,"eps"))
		ode->eps=value;
	else return 2;
	return 0; }



/* dyns_AddStatePtr
Registers one state variable with the integrator.
  nowptr   - location of the current value; each step writes the new value here
  wasptr   - location the integrator fills with the (trial) state before each
             call to the equation of motion
  derivptr - location where the equation of motion writes the time derivative
  scale    - magnitude used to normalize this variable's error in adaptive
             steps; anything not > 0 (including NaN) selects 1
Returns 0 on success, or 1 if maxdim variables are already registered. */
int dyns_AddStatePtr(odeptr ode,double *nowptr,double *wasptr,double *derivptr,double scale) {
	int slot;

	if(ode->dim==ode->maxdim) return 1;
	slot=ode->dim;
	ode->statenow[slot]=nowptr;
	ode->statewas[slot]=wasptr;
	ode->deriv[slot]=derivptr;
	ode->scale[slot]=scale>0?scale:1;
	ode->dim=slot+1;
	return 0; }



/* dyns_ClearStatePtrs
Unregisters every state variable so a new set can be added.  The used slots
get NULL pointers and a scale of 0 (dyns_AddStatePtr assigns a fresh scale
when a slot is reused).  Parameters and work arrays are left untouched. */
void dyns_ClearStatePtrs(odeptr ode) {
	int slot;

	for(slot=ode->dim-1;slot>=0;slot--) {
		ode->scale[slot]=0;
		ode->deriv[slot]=NULL;
		ode->statewas[slot]=NULL;
		ode->statenow[slot]=NULL; }
	ode->dim=0; }



/* dyns_StepOde
Advances every registered state variable by one time step, using the method
chosen with dyns_SetOrder:
  order 1 - forward Euler with the fixed step *dtptr
  order 2 - midpoint method with the fixed step *dtptr
  order 4 - classical Runge-Kutta with the fixed step *dtptr
  order 5 - classical Runge-Kutta with step doubling: one step of dt is
            compared with two steps of dt/2 (whose result is kept), and the
            scaled difference drives the next suggested step (dtsugg, capped
            at dtmax).  A step whose error exceeds eps, or whose error is
            infinite/NaN, is rejected and retried with a smaller dt.  The step
            actually taken is written back to *dtptr.
Before each call to the equation of motion, the trial state is stored in the
"was" locations; eqm is expected to compute derivatives from those and write
them to the deriv locations.  On success, the "now" locations hold the new state
and the "was" locations hold the state from before the step.
Returns 0 on success, or the nonzero value returned by eqm (the "was" locations
may then hold an intermediate trial state).  Returns -1 if an order-5 step is
rejected 21 times in a row; the "now" state is then restored to its pre-step
values and dtsugg holds the last reduced suggestion.
The retry count is local to each call, so failures never carry over to later
calls or to other integrators. */
int dyns_StepOde(odeptr ode) {
	const int maxretry=20;		// order 5 gives up on the (maxretry+1)-th rejected attempt
	int dim,order,i,er,attempt;
	double **s0,**s1,**deriv,dt,dt2,dt3,dt4,dt6,dt12;
	double *k1,*k2,*k3,*k4;
	double diff,diffmax,*scale;
	void *systemptr;

	dim=ode->dim;
	order=ode->order;
	systemptr=ode->systemptr;
	s0=ode->statenow;
	s1=ode->statewas;
	deriv=ode->deriv;

	if(order==1) {													// *** ORDER 1 ***
		dt=*ode->dtptr;
		for(i=0;i<dim;i++) *s1[i]=*s0[i];
		if((er=(ode->eqm)(systemptr))) return er;
		for(i=0;i<dim;i++) *s0[i]=*s1[i]+dt**deriv[i]; }

	else if(order==2) {											// *** ORDER 2 ***
		dt=*ode->dtptr;
		k1=ode->k1;
		dt2=dt/2.0;
		for(i=0;i<dim;i++) *s1[i]=*s0[i];
		if((er=(ode->eqm)(systemptr))) return er;
		for(i=0;i<dim;i++) {
			k1[i]=*s1[i];
			*s1[i]+=dt2**deriv[i]; }
		if((er=(ode->eqm)(systemptr))) return er;
		for(i=0;i<dim;i++) {
			*s1[i]=k1[i];
			*s0[i]=k1[i]+dt**deriv[i]; }}

	else if(order==4) {											// *** ORDER 4 ***
		dt=*ode->dtptr;
		k1=ode->k1;
		k2=ode->k2;
		dt2=dt/2.0;
		dt3=dt/3.0;
		dt6=dt/6.0;
		for(i=0;i<dim;i++) *s1[i]=*s0[i];
		if((er=(ode->eqm)(systemptr))) return er;
		for(i=0;i<dim;i++) {
			k1[i]=*s1[i];												// k1[i] is old state
			k2[i]=k1[i]+dt6**deriv[i];
			*s1[i]+=dt2**deriv[i]; }
		if((er=(ode->eqm)(systemptr))) return er;
		for(i=0;i<dim;i++) {
			k2[i]+=dt3**deriv[i];
			*s1[i]=k1[i]+dt2**deriv[i]; }
		if((er=(ode->eqm)(systemptr))) return er;
		for(i=0;i<dim;i++) {
			k2[i]+=dt3**deriv[i];
			*s1[i]=k1[i]+dt**deriv[i]; }
		if((er=(ode->eqm)(systemptr))) return er;
		for(i=0;i<dim;i++) {
			*s0[i]=k2[i]+dt6**deriv[i];
			*s1[i]=k1[i]; }}

	else if(order==5) {											// *** ORDER 5 ***
		scale=ode->scale;
		k1=ode->k1;
		k2=ode->k2;
		k3=ode->k3;
		k4=ode->k4;
		for(attempt=0;;attempt++) {							// each pass is one trial step; rejected steps loop
			dt=ode->dtsugg>0?ode->dtsugg:*ode->dtptr;
			if(ode->dtmax && dt>ode->dtmax) dt=ode->dtmax;
			dt2=dt/2.0;
			dt3=dt/3.0;
			dt4=dt/4.0;
			dt6=dt/6.0;
			dt12=dt/12.0;

			for(i=0;i<dim;i++) *s1[i]=*s0[i];
			if((er=(ode->eqm)(systemptr))) return er;		// *** start of dt step
			for(i=0;i<dim;i++) {
				k1[i]=*s1[i];											// k1 is old state
				k3[i]=*deriv[i];										// k3 is derivative at old state
				*s1[i]+=dt2**deriv[i]; }
			if((er=(ode->eqm)(systemptr))) return er;
			for(i=0;i<dim;i++) {
				k2[i]=k1[i]+dt6*k3[i]+dt3**deriv[i];
				*s1[i]=k1[i]+dt2**deriv[i]; }
			if((er=(ode->eqm)(systemptr))) return er;
			for(i=0;i<dim;i++) {
				k2[i]+=dt3**deriv[i];
				*s1[i]=k1[i]+dt**deriv[i]; }
			if((er=(ode->eqm)(systemptr))) return er;
			for(i=0;i<dim;i++)
				k4[i]=k2[i]+dt6**deriv[i];						// k4 is result from dt step

			for(i=0;i<dim;i++)										// start of first dt2 step (reuses old-state derivative)
				*s1[i]=k1[i]+dt4*k3[i];
			if((er=(ode->eqm)(systemptr))) return er;
			for(i=0;i<dim;i++) {
				k2[i]=k1[i]+dt12*k3[i]+dt6**deriv[i];
				*s1[i]=k1[i]+dt4**deriv[i]; }
			if((er=(ode->eqm)(systemptr))) return er;
			for(i=0;i<dim;i++) {
				k2[i]+=dt6**deriv[i];
				*s1[i]=k1[i]+dt2**deriv[i]; }
			if((er=(ode->eqm)(systemptr))) return er;
			for(i=0;i<dim;i++)
				*s1[i]=k3[i]=k2[i]+dt12**deriv[i];			// k3 is result from first dt2 step

			if((er=(ode->eqm)(systemptr))) return er;		// start of second dt2 step
			for(i=0;i<dim;i++) {
				k2[i]=k3[i]+dt12**deriv[i];
				*s1[i]+=dt4**deriv[i]; }
			if((er=(ode->eqm)(systemptr))) return er;
			for(i=0;i<dim;i++) {
				k2[i]+=dt6**deriv[i];
				*s1[i]=k3[i]+dt4**deriv[i]; }
			if((er=(ode->eqm)(systemptr))) return er;
			for(i=0;i<dim;i++) {
				k2[i]+=dt6**deriv[i];
				*s1[i]=k3[i]+dt2**deriv[i]; }
			if((er=(ode->eqm)(systemptr))) return er;
			for(i=0;i<dim;i++) {
				*s1[i]=k1[i];											// s1 is old state
				*s0[i]=k1[i]=k2[i]+dt12**deriv[i]; }		// s0 and k1 are result from second dt2 step

			*ode->dtptr=dt;
			diffmax=diff=0;
			for(i=0;i<dim;i++) {
				diff=fabs(k1[i]-k4[i])/scale[i];			// scaled difference between one step and two half steps
				if(!(diff+1.0>diff)) break;					// infinity or NaN (or beyond double resolution)
				if(diff>diffmax) diffmax=diff; }

			if(!(diff+1.0>diff))									// unusable error estimate: halve dt and retry
				ode->dtsugg=dt*0.5;
			else if(diffmax==0) {									// can't use diffmax to scale with, so just increase dt
				ode->dtsugg=dt*1.1;
				break; }
			else if(!ode->eps) {									// eps wasn't defined, so define it now to the error with the initial step
				ode->eps=diffmax;
				break; }
			else if(diffmax<=ode->eps) {						// less error than permitted so use longer step
				ode->dtsugg=0.90*dt*exp(-0.20*log(diffmax/ode->eps));
				break; }
			else																	// more error than permitted so use shorter step and repeat
				ode->dtsugg=0.90*dt*exp(-0.25*log(diffmax/ode->eps));

			for(i=0;i<dim;i++) *s0[i]=*s1[i];				// rejected: restore the pre-step state
			if(attempt==maxretry) return -1; }}

	if(ode->dtmax>0 && ode->dtsugg>ode->dtmax) ode->dtsugg=ode->dtmax;
	return 0; }



/******************************************************************************/
/** An example program, which integrates a simple damped harmonic oscillator **/
/******************************************************************************/

/* State of the damped harmonic oscillator used by dyns_ShoExample.  Each
dynamic variable has a "now" value (advanced by dyns_StepOde), a "was" value
(the point at which ShoEqm evaluates derivatives) and a derivative slot
written by ShoEqm. */
typedef struct shostate {
	double positionwas;		// position at which ShoEqm evaluates derivatives
	double positionnow;		// current position
	double velocitywas;		// velocity at which ShoEqm evaluates derivatives
	double velocitynow;		// current velocity
  double positionderiv;		// d(position)/dt, written by ShoEqm
  double velocityderiv;		// d(velocity)/dt, written by ShoEqm
	double omega2;		// squared undamped angular frequency (acceleration term -omega2*x)
	double gamma;		// damping coefficient (acceleration term -2*gamma*v)
	double time;		// simulation time, advanced by dyns_ShoExample
	double dt;		// time step; used as dtptr, so order-5 steps overwrite it with the step taken
	} *shostateptr;

int ShoEqm(void *system);

int ShoEqm(void *system) {
	shostateptr sys;

	sys=(shostateptr) system;
	sys->positionderiv=sys->velocitywas;
	sys->velocityderiv=-sys->omega2*sys->positionwas-2.0*sys->gamma*sys->velocitywas;
	return 0; }


/* dyns_ShoExample
Interactive demo.  Asks for an integrator order on stdin, then integrates a
damped harmonic oscillator (x=1, v=-0.1, omega2=0.5, gamma=0.1, initial
dt=0.2, dtmax=1/sqrt(omega2)) from t=0 to t=500.  It prints "t x v" for each
step and stops early if dyns_StepOde reports an error.  If the integrator
cannot be configured, including when the input cannot be read as an integer,
it prints "error".  All memory is released on every path.
Always returns 0. */
int dyns_ShoExample(void) {
	odeptr ode;
	shostateptr sys;
	int er,order;

	sys=(shostateptr) malloc(sizeof(struct shostate));
	if(!sys) return 0;
	sys->positionwas=sys->positionnow=1.0;
	sys->velocitywas=sys->velocitynow=-0.1;
	sys->omega2=0.5;
	sys->gamma=0.1;
	sys->time=0;
	sys->dt=0.2;

	ode=dyns_AllocOde(2);
	if(!ode) {
		free(sys);
		return 0; }

	printf("Enter integrator order: ");
	fflush(stdout);											// show the prompt before blocking on input
	if(scanf("%i",&order)!=1) order=0;		// unreadable input -> unsupported order, reported as an error below
	er=dyns_SetOrder(ode,order);
	if(!er) er=dyns_SetParamPtr(ode,"dtptr",(void*)&sys->dt);
	if(!er) er=dyns_SetParamPtr(ode,"systemptr",(void*)sys);
	if(!er) er=dyns_SetParamPtr(ode,"eqm",(void*)&ShoEqm);
	if(!er) er=dyns_SetParamDbl(ode,"dtmax",1.0/sqrt(sys->omega2));
	if(!er) er=dyns_AddStatePtr(ode,&sys->positionnow,&sys->positionwas,&sys->positionderiv,1);
	if(!er) er=dyns_AddStatePtr(ode,&sys->velocitynow,&sys->velocitywas,&sys->velocityderiv,1);

	if(er)
		printf("error\n");
	else {
		printf("t x v\n");
		for(sys->time=0;sys->time<500;sys->time+=sys->dt) {	// sys->dt is the step just taken (order 5 updates it)
			printf("%g %g %g\n",sys->time,sys->positionnow,sys->velocitynow);
			er=dyns_StepOde(ode);
			if(er) break; }}

	dyns_FreeOde(ode);
	free(sys);
	return 0; }




