/* Steven Andrews, 1/00 */
/* See documentation called Quantum doc */
/* Copyright 2003 by Steven Andrews.  Permission is granted
   for non-commercial use of and modifications to the code. */

#include <math.h>
#include <stdio.h>
#include "dynsys.h"
#include "Plot.h"
#include "Cn.h"
#include "math2.h"
#include "Rn.h"
#include "Constants.h"
#include "random.h"
#include "Spectra.h"
#include "Quantum.h"

/* psiHO: harmonic-oscillator eigenfunction number n evaluated at position x,
   for force constant k and mass m.  alph=sqrt(k*m)/hbar is the inverse squared
   length scale.  The prefactor sqrt(sqrt(4*PI/alph)/(2^n n!)) is combined with
   gauss(x,0,1/sqrt(alph)); assuming gauss() is a unit-area normal density
   (defined elsewhere -- confirm), the product is the usual normalized
   (alph/PI)^(1/4)/sqrt(2^n n!) H_n(sqrt(alph) x) exp(-alph x^2/2). */
float psiHO(float x,int n,float k,float m)
{
	float alph;
	double rootalph,norm;

	alph=sqrt(k*m)/hbar_SSI;
	rootalph=sqrt(alph);										// 1/length scale
	norm=sqrt(sqrt(4*PI/alph)/pow(2,n)/factorial(n));
	return norm*hermite(x*rootalph,n)*gauss(x,0,1/rootalph);
}

/* Bracket: overlap <psib|psik> of two complex kets, each stored as n complex
   values (2*n floats) on the grid xr.  The integrand conj(psib)*psik is summed
   cumulatively with integCV.  The last cumulative value, times the local grid
   spacing from deriv1V, is returned, so an evenly spaced grid is presumably
   assumed. */
complex Bracket(float *psib,float *psik,float *xr,int n)
{
	float *dx,*cum,*prod,*scratch;
	complex result;

	dx=allocV(n);
	cum=allocV(2*n);
	prod=allocV(2*n);
	scratch=allocV(2*n);
	deriv1V(xr,dx,n);												// dx = local grid spacing
	multCV(CompConj(psib,cum,n),psik,prod,n);						// prod = conj(psib)*psik
	integCV(prod,cum,n);											// cum = running sum of prod
	multCV(cum,makecmplx(dx,NULL,scratch,n),prod,n);				// prod = running sum * dx
	result.r=prod[2*n-2];											// last complex point
	result.i=prod[2*n-1];
	freeV(scratch);
	freeV(prod);
	freeV(cum);
	freeV(dx);
	return result;
}

/* NormalKet: rescale the complex ket psi (n complex values, 2*n floats, on
   grid xr) in place so that |<psi|psi>|, as computed by Bracket, becomes 1.
   Returns the factor |<psi|psi>|^(-1/2) that was applied; a zero ket makes
   pow() return an infinite factor. */
float NormalKet(float *psi,float *xr,int n)
{
	complex overlap;
	double magsq;
	float factor;

	overlap=Bracket(psi,psi,xr,n);
	magsq=sqr(overlap.r)+sqr(overlap.i);							// |<psi|psi>|^2
	factor=pow(magsq,-0.25);										// (|z|^2)^(-1/4) == |z|^(-1/2)
	multKV(factor,psi,psi,2*n);
	return factor;
}

/* Hamiltonian: matrix element <psib|H|psik>, with H = -hbar^2/(2m) d^2/dx^2 + U(x).
   psib, psik: kets of n complex values (2*n floats) on the grid xr.
   ur: potential U(x) as n real values.  m: particle mass.
   The second derivative is deriv2CV's result scaled by 1/dx^2.  The integral
   is the last value of the running sum from integCV times the local spacing,
   so an evenly spaced grid is presumably assumed.
   The scratch buffers v1..v4 are reused for different intermediates; the
   trailing comments record what each holds after each step. */
complex Hamiltonian(float *psib,float *ur,float *psik,float *xr,int n,float m)
{
	float *v1,*v2,*v3,*v4;
	float *v1r,*v2r;
	complex ans;
	
	v1=allocV(2*n);
	v2=allocV(2*n);
	v3=allocV(2*n);
	v4=allocV(2*n);
	v1r=allocV(n);
	v2r=allocV(n);
	multCV(makecmplx(ur,NULL,v2,n),psik,v1,n);				// v1=U|psik> (v2 is scratch for complex U)
	deriv2CV(psik,v2,n,1);														// v2=second difference of |psik> (index units)
	deriv1V(xr,v1r,n);																// v1r=dx
	multV(v1r,v1r,v2r,n);															// v2r=dx^2
	divKV(1,v2r,v2r,n);																// v2r=1/dx^2
	multCV(v2,makecmplx(v2r,NULL,v3,n),v4,n);					// v4=d^2/dx^2 |psik>
	sumV(-sqr(hbar_SSI)/(2*m),v4,1,v1,v2,2*n);				// v2=H|psik>
	CompConj(psib,v1,n);															// v1=psib*
	multCV(v1,v2,v3,n);																// v3=psib* H psik
	integCV(v3,v1,n);																	// v1=running sum of psib* H psik
	multCV(v1,makecmplx(v1r,NULL,v3,n),v2,n);					// v2=running integral of psib* H psik dx
	ans.r=v2[2*n-2];																	// ans=<psib|H|psik> (last grid point)
	ans.i=v2[2*n-1];																	// imaginary part of <psib|H|psik>
	freeV(v1);
	freeV(v2);
	freeV(v3);
	freeV(v4);
	freeV(v1r);
	freeV(v2r);
	return ans;
}

/* Dipole: matrix element <psib| q*x |psik> for complex kets of n complex
   values (2*n floats) on the grid xr.  It is integrated the same way as
   Bracket: the last value of integCV's running sum times the local spacing
   from deriv1V.  That presumes an evenly spaced grid. */
complex Dipole(float *psib,float *psik,float *xr,int n,float q)
{
	float *dx,*xket,*tmp,*work;
	complex result;

	dx=allocV(n);
	xket=allocV(2*n);
	tmp=allocV(2*n);
	work=allocV(2*n);
	deriv1V(xr,dx,n);												// dx = local grid spacing
	multCV(makecmplx(xr,NULL,tmp,n),psik,xket,n);					// xket = x|psik>
	multCV(CompConj(psib,tmp,n),xket,work,n);						// work = psib* x psik
	integCV(work,xket,n);											// xket = running sum of psib* x psik
	multCV(xket,makecmplx(dx,NULL,work,n),tmp,n);					// tmp = running integral * dx
	result.r=q*tmp[2*n-2];											// <psib|qx|psik> at the last point
	result.i=q*tmp[2*n-1];
	freeV(work);
	freeV(tmp);
	freeV(xket);
	freeV(dx);
	return result;
}

/* EQMFschrod: derivative function for ODEFrk4 that evaluates the discretized
   time-dependent Schroedinger equation.
   u->fa holds the ket (u->fs/2 complex values).  k points to the parameter
   struct built by TimeEvolve/DipCorr, whose layout must match the one below:
     pot = complex potential [U(x)-hbar*w]/hbar, dx2 = -hbar/(2m dx^2),
     w = reference frequency, v1..v3 = scratch vectors of 2*n floats.
   It writes d|psi>/dt to dudt->fa and the phase rate w to dudt->sa[0].
   NOTE(review): rotateCV(...,1) presumably multiplies by -i (d psi/dt =
   -(i/hbar) H psi); an older comment wrote +i/hbar -- confirm rotateCV's sign. */
void EQMFschrod(phptr u,void *k,phptr dudt)
{
	struct param {float *pot;float dx2;float w;float *v1;float *v2;float *v3;} *p;
	float *ket;
	int npts;

	p=(struct param *) k;
	ket=u->fa;
	npts=u->fs/2;
	multCV(p->pot,ket,p->v1,npts);									// v1 = (V/hbar)|psi>
	deriv2CV(ket,p->v2,npts,1);										// v2 = second difference of |psi>
	multKV(p->dx2,p->v2,p->v3,2*npts);								// v3 = kinetic term / hbar
	sumV(1,p->v3,1,p->v1,p->v2,2*npts);								// v2 = (H/hbar)|psi>
	rotateCV(p->v2,dudt->fa,npts,1);								// d|psi>/dt
	dudt->sa[0]=p->w;												// rate of the removed overall phase
}

/* TKFplotpsi: time-keeping callback passed to ODEFrk4 by TimeEvolve.
   It does something only when t is less than one step (dt) past an integer.
   Then it applies the accumulated phase u->sa[0] to the ket, clears the plot,
   and draws the real part (color 'B') and the imaginary part (color 'R')
   against the grid xr.  Always returns 0. */
int TKFplotpsi(float t,phptr u,void *tkptr)
{
	struct param2 {float dt;float *xr;float *v1r;float *v1;} *tk=(struct param2 *) tkptr;
	int npts=u->fs/2;

	if(!(t-floor(t)<tk->dt)) return 0;							// not at an integer time
	rotate2CV(u->fa,tk->v1,npts,u->sa[0]);
	PlotClear();
	SetColor('B');
	PlotData2(real(tk->v1,tk->v1r,npts),tk->xr,npts,3);
	SetColor('R');
	PlotData2(imag(tk->v1,tk->v1r,npts),tk->xr,npts,3);
	return 0;
}

/* TimeEvolve: propagate the complex ket psi (n complex values, 2*n floats, on
   the grid xr) for time Dt.  The Hamiltonian is H = -hbar^2/(2m) d^2/dx^2 + ur(x).
   Integration uses ODEFrk4 with a fixed step of 0.01 (fs per the step comment)
   and plots the ket through TKFplotpsi.
   The mean energy <psi|H|psi> (psi presumably normalized) is subtracted from the
   potential so the ket rotates slowly.  The removed phase accumulates in
   ket->sa[0] and is applied back when the ket is copied into psi.
   psi is overwritten with the evolved ket and returned.  Assumes equally
   spaced grid points.
   The parameter structs are automatic objects: ODEFrk4 cannot retain them past
   its return (the earlier heap version freed them right afterwards).  This
   avoids an unchecked malloc, which also relied on malloc being declared by a
   project header. */
float *TimeEvolve(float *psi,float *ur,float m,float *xr,int n,float Dt)
{
	phptr ket;
	float intpar[MaxIntPar];
	struct param {float *pot;float dx2;float w;float *v1;float *v2;float *v3;} eqmparam;		// layout must match EQMFschrod
	struct param2 {float dt;float *xr;float *v1r;float *v1;} tkfparam;						// layout must match TKFplotpsi
	float *uc,*v1,*v2,*v3,*v1r,dx2,w;

	uc=allocV(2*n);
	v1=allocV(2*n);
	v2=allocV(2*n);
	v3=allocV(2*n);
	v1r=allocV(n);
	w=Hamiltonian(psi,ur,psi,xr,n,m).r/hbar_SSI;					// w=<E>/hbar, the phase rate removed below
	addKV(-w*hbar_SSI,ur,v1r,n);									// v1r=U(x)-<E>
	eqmparam.w=w;
	multKV(1/hbar_SSI,v1r,v1r,n);									// v1r=[U(x)-<E>]/hbar
	makecmplx(v1r,NULL,uc,n);
	eqmparam.pot=uc;												// potential is [U(x)-<E>]/hbar, as complex
	dx2=(xr[n-1]-xr[0])/(n-1);										// assumes equally spaced points
	dx2=-hbar_SSI/(2*m*sqr(dx2));									// dx2=-hbar/(2m dx^2)
	eqmparam.dx2=dx2;
	eqmparam.v1=v1;
	eqmparam.v2=v2;
	eqmparam.v3=v3;
	tkfparam.dt=intpar[0]=0.01;										// time step in fs
	tkfparam.v1r=v1r;												// v1r is free to reuse as plot scratch now
	tkfparam.v1=v1;
	tkfparam.xr=xr;

	ket=phptalloc(1,1,2*n);
	copyV(psi,ket->fa,2*n);
	ket->sa[0]=0;													// accumulated phase starts at 0
	ODEFrk4(EQMFschrod,(void *) &eqmparam,ket,&Dt,intpar,TKFplotpsi,(void *) &tkfparam);
	rotate2CV(ket->fa,psi,n,ket->sa[0]);							// apply the accumulated phase into psi
	phptfree(ket);
	freeV(uc);
	freeV(v1);
	freeV(v2);
	freeV(v3);
	freeV(v1r);
	return psi;
}

/* TKFdipcorr: time-keeping callback passed to ODEFrk4 by DipCorr.
   It does something only when t is less than one step (dt) past an integer.
   Then it rotates the ket by its accumulated phase u->sa[0] minus wb*t (the
   bra's own phase) and evaluates Dipole(bra, ket, xr, prefact).  The result is
   stored in dip[2*(int)t] (real) and dip[2*(int)t+1] (imaginary), and both
   components are plotted against t.  Always returns 0. */
int TKFdipcorr(float t,phptr u,void *tkptr)
{
	struct param3 {float dt;float *xr;float *bra;float wb;float prefact;float *dip;float *v1;} *tk;
	complex d;
	int npts,slot;

	tk=(struct param3 *) tkptr;
	if(!(t-floor(t)<tk->dt)) return 0;							// not at an integer time
	npts=u->fs/2;
	slot=2*(int) t;													// (re,im) pair index into dip
	rotate2CV(u->fa,tk->v1,npts,u->sa[0]-tk->wb*t);
	d=Dipole(tk->bra,tk->v1,tk->xr,npts,tk->prefact);
	tk->dip[slot]=d.r;
	tk->dip[slot+1]=d.i;
	SetColor('B');
	PlotPt(t,d.r);
	SetColor('R');
	PlotPt(t,d.i);
	return 0;
}


/* DipCorr: dipole correlation function between the ket psik and the bra psib.
   psib, psik: n complex values (2*n floats) on the grid xr.  They are presumably
     normalized, since their energies are taken as <psi|H|psi>.
   ur: potential (n reals).  m: mass.  q: charge.  Dt: total propagation time.
   The propagated ket starts as x|psik>, normalized by NormalKet.  Its scale
   factor normk is divided back out through prefact = q^2/normk.
   dip: output.  Near each integer time t, TKFdipcorr writes the real and
     imaginary parts of prefact*<psib|x|ket(t)> (bra phase wb*t removed) to
     dip[2*(int)t] and dip[2*(int)t+1].  dip therefore presumably needs at
     least 2*((int)Dt+1) floats -- confirm the last t reached by ODEFrk4.
   The parameter structs are automatic objects that live for the whole
   ODEFrk4 call.  This avoids the unchecked malloc/free pair. */
void DipCorr(float *psib,float *psik,float *ur,float m,float q,float *xr,int n,float Dt,float *dip)
{
	phptr ket;
	float intpar[MaxIntPar];
	struct param {float *pot;float dx2;float w;float *v1;float *v2;float *v3;} eqmparam;								// layout must match EQMFschrod
	struct param3 {float dt;float *xr;float *bra;float wb;float prefact;float *dip;float *v1;} tkfparam;		// layout must match TKFdipcorr
	float *uc,*v1,*v2,*v3,*v1r,dx2,wk,wb;
	float *ketv,normk;
	int i;

	uc=allocV(2*n);
	v1=allocV(2*n);
	v2=allocV(2*n);
	v3=allocV(2*n);
	v1r=allocV(n);
	ketv=allocV(2*n);

	for(i=0;i<n*2;i++)	ketv[i]=xr[i/2]*psik[i];					// ketv=x|psik> (re and im of point i/2)
	normk=NormalKet(ketv,xr,n);										// normalize ketv; normk is the applied factor
	wb=Hamiltonian(psib,ur,psib,xr,n,m).r/hbar_SSI;					// bra frequency <E_b>/hbar
	wk=Hamiltonian(psik,ur,psik,xr,n,m).r/hbar_SSI;					// reference frequency <E_k>/hbar
	addKV(-wk*hbar_SSI,ur,v1r,n);
	eqmparam.w=wk;
	multKV(1/hbar_SSI,v1r,v1r,n);
	makecmplx(v1r,NULL,uc,n);
	eqmparam.pot=uc;												// potential is [U(x)-<E>]/hbar
	dx2=(xr[n-1]-xr[0])/(n-1);										// assumes equally spaced points
	dx2=-hbar_SSI/(2*m*sqr(dx2));									// dx2=-hbar/(2m dx^2)
	eqmparam.dx2=dx2;
	eqmparam.v1=v1;
	eqmparam.v2=v2;
	eqmparam.v3=v3;
	tkfparam.dt=intpar[0]=0.01;										// time step in fs
	tkfparam.xr=xr;
	tkfparam.bra=psib;
	tkfparam.wb=wb;
	tkfparam.prefact=sqr(q)/normk;									// undo ketv's normalization, include q^2
	tkfparam.dip=dip;
	tkfparam.v1=v1;
	ket=phptalloc(1,1,2*n);
	copyV(ketv,ket->fa,2*n);
	ket->sa[0]=0;
	ODEFrk4(EQMFschrod,(void *) &eqmparam,ket,&Dt,intpar,TKFdipcorr,(void *) &tkfparam);

	phptfree(ket);
	freeV(uc);
	freeV(v1);
	freeV(v2);
	freeV(v3);
	freeV(v1r);
	freeV(ketv);
}

/* EigenketAHO: variational search for the lowest-energy state of potential ur
   within the span of the first nb=10 harmonic-oscillator eigenfunctions
   (psiHO with force constant k and mass m), evaluated on the grid xr.
   Starting from the HO ground state, it repeatedly picks a random coefficient
   and perturbs it with binomrand scaled by that coefficient's step.  A move is
   kept if it lowers <psi|H|psi>/sum(c^2); the ratio presumes the HO basis is
   orthonormal on the grid.  Steps grow by 1.2 when accepted and shrink by 0.93
   when rejected.  The search ends after roughly 40*nb consecutive rejections.
   psi (2*n floats) receives the ket, normalized by sqrt(sum(c^2)).
   Returns the final energy estimate. */
float EigenketAHO(float *xr,float *ur,float *psi,int n,float k,float m)
{
	float *v1,*psitry;
	float *expans,*step;								// expans is list of expansion parameters
	float *b;											// matrix of complex basis kets
	float e,e2,oldex;
	int i,j,nb,it;

	nb=10;												// number of basis kets
	v1=allocV(2*n);
	psitry=allocV(2*n);
	b=allocV(2*n*nb);									// 2*n rows (re,im per point), nb columns
	expans=allocV(nb);
	step=allocV(nb);
	setstdM(b,2*n,nb,0);
	for(j=0;j<nb;j++)
		for(i=0;i<n;i++)
			b[nb*2*i+j]=psiHO(xr[i],j,k,m);			// real row 2i; imaginary row 2i+1 stays 0
	setstdV(expans,nb,0);
	expans[0]=1;										// starting ket is HO ground state
	columnM(b,psi,2*n,nb,0);
	for(j=0;j<nb;j++)	step[j]=0.01;					// starting step size 1%

	e=(Hamiltonian(psi,ur,psi,xr,n,m)).r/dotVV(expans,expans,nb);
	for(it=0;it<40*nb;it++)	{
		j=intrand(nb);
		oldex=expans[j];
		expans[j]+=binomrand(3,0,step[j]);
		sumV(1,psi,expans[j]-oldex,columnM(b,v1,2*n,nb,j),psitry,2*n);	// psitry=psi+delta*b_j
		if((e2=(Hamiltonian(psitry,ur,psitry,xr,n,m)).r/dotVV(expans,expans,nb))<e)	{
			it=0;										// restart the rejection count
			step[j]*=1.2;
			e=e2;
			copyV(psitry,psi,2*n);
			}
		else {
			expans[j]=oldex;
			step[j]*=0.93;
			}
		}
	multKV(1/sqrt(dotVV(expans,expans,nb)),psi,psi,2*n);
	freeV(v1);
	freeV(psitry);
	freeV(b);											// b came from allocV, so release it with freeV
	freeV(expans);
	freeV(step);
	return e;
}

/* DipTransform: Fourier-transform a dipole correlation function into an
   absorption spectrum.
   dip: nt+1 real samples at times 0, dt, ..., nt*dt.  Read only; the
     trapezoid end-point halving is applied to a local copy, so repeated calls
     on the same array give the same result.
   dt: sample spacing (fs per TimeEvolve/DipCorr).  nt: number of intervals.
   wmin: lowest angular frequency evaluated.  Points are spaced PI/(nt*dt).
   Returns a new spectrum of nt+1 points (x in cm^-1, y in M^-1 cm^-1), or
   NULL if SpectAlloc returns NULL.
   NOTE(review): DipCorr writes interleaved (re,im) pairs, but this function
   reads nt+1 plain reals.  The caller presumably extracts the real parts. */
sptr DipTransform(float *dip,float dt,int nt,float wmin)
{
	float *vt,*vw,*vdc,*vsc,*vs,dw;
	float tmax;
	int i,nv;
	sptr ans;

	ans=SpectAlloc("calc","","","cm-1","M-1cm-1");
	if(!ans) return NULL;
	tmax=nt*dt;											// time of maximum point
	nv=nt+1;
	ans->n=nv;
	vt=allocV(nv);
	ans->x=vw=allocV(nv);
	vdc=allocV(nv*2);
	vsc=allocV(nv*2);
	ans->y=vs=allocV(nv);
	for(i=0;i<nv;i++) vt[i]=dt*i;						// vt is time values
	dw=PI/tmax;
	for(i=0;i<nv;i++) vw[i]=wmin+dw*i;					// vw is angular frequency values
	makecmplx(dip,NULL,vdc,nv);							// vdc is complex dips (interleaved re,im)
	vdc[0]/=2;											// trapezoid end weights, applied to the copy
	vdc[2*nt]/=2;										//   so the caller's dip is left unchanged
	fourier(vt,vdc,vw,vsc,nv,nv);						// vsc is complex spectrum
	real(vsc,vs,nv);
	for(i=0;i<nv;i++)	vs[i]*=vw[i]/(eps_SSI*c_SSI*hbar_SSI)*sqrt(2*PI);	// vs is cross section in (SSI length)^2, presumably A^2 -- confirm
	for(i=0;i<nv;i++) vw[i]/=2*PI*1e-8*c_SSI;			// vw is frequency in cm^-1
	for(i=0;i<nv;i++)	vs[i]*=NA_SI/1e19;				// vs is spectrum in M^-1 cm^-1
	freeV(vdc);
	freeV(vt);
	freeV(vsc);
	return ans;
}



