/* Steven Andrews, 2/20/93.  Modified substantially 9/98.	*/
/* See documentation called RnLU doc */
/* Copyright 2003 by Steven Andrews.  Permission is granted
   for non-commercial use of and modifications to the code. */

#include <stdlib.h>
#include <math.h>
#include "math2.h"
#include "Rn.h"
#include "RnLU.h"

/* LUdecomp: in-place LU decomposition of the n x n row-major matrix a, using
   scaled partial (row) pivoting.
   On success, a holds both factors: the multipliers of the unit-diagonal
   lower factor sit below the diagonal and the upper factor sits on and above
   it.  *indxptr receives a newly calloc'ed array of n+1 ints: entry j
   (0<=j<n) is the row that was swapped with row j at step j, and entry n is
   +1 or -1 according to whether the number of row swaps was even or odd.
   The caller owns that array.  The return value is the product of the
   diagonal of the decomposed matrix times that parity, i.e. the determinant.
   On failure (allocation failure, an all-zero row, or an exactly zero pivot)
   the function returns 0, sets *indxptr to NULL, and may leave a partially
   overwritten. */
float LUdecomp(float *a,int n,int **indxptr) {
	int row,col,k,pivrow,*perm;
	double best,mag,inv,acc,det;
	float *scale,*ri,*rp,swp;

	perm=NULL;
	*indxptr=NULL;
	scale=allocV(n);								/* per-row scale factors */
	if(!scale) goto failure;
	perm=(int *) calloc(n+1,sizeof(int));			/* pivot rows plus parity slot */
	*indxptr=perm;
	if(!perm) goto failure;
	perm[n]=1;										/* no swaps yet: even parity */

	/* record 1/(largest magnitude) of each row for implicit scaling */
	for(row=0;row<n;row++) {
		ri=a+n*row;
		best=0.0;
		for(col=0;col<n;col++) {
			mag=fabs(ri[col]);
			if(mag>best) best=mag;
		}
		if(best==0.0) goto failure;					/* zero row: singular */
		scale[row]=1.0/best;
	}

	pivrow=0;
	for(col=0;col<n;col++) {
		/* entries above the diagonal in this column */
		for(row=0;row<col;row++) {
			ri=a+n*row;
			acc=ri[col];
			for(k=0;k<row;k++) acc-=ri[k]*a[n*k+col];
			ri[col]=acc;
		}

		/* diagonal and below; track the best scaled pivot candidate */
		best=0.0;
		for(row=col;row<n;row++) {
			ri=a+n*row;
			acc=ri[col];
			for(k=0;k<col;k++) acc-=ri[k]*a[n*k+col];
			ri[col]=acc;
			mag=scale[row]*fabs(acc);
			if(mag>=best) {
				best=mag;
				pivrow=row;
			}
		}

		/* bring the pivot row into position */
		if(pivrow!=col) {
			ri=a+n*col;
			rp=a+n*pivrow;
			for(k=0;k<n;k++) {
				swp=rp[k];
				rp[k]=ri[k];
				ri[k]=swp;
			}
			perm[n]=-perm[n];						/* each swap flips the parity */
			scale[pivrow]=scale[col];
		}
		perm[col]=pivrow;

		ri=a+n*col;
		if(ri[col]==0.0) goto failure;				/* zero pivot: singular */
		inv=1.0/ri[col];
		for(row=col+1;row<n;row++) a[n*row+col]*=inv;
	}

	freeV(scale);
	det=perm[n];
	for(k=0;k<n;k++) det*=a[n*k+k];
	return det;

 failure:
	if(scale) freeV(scale);
	if(perm) free(perm);
	*indxptr=NULL;
	return 0;
}

/* LUsolveV: solves for one right-hand-side vector using a decomposed matrix.
   a is the n x n row-major matrix holding both factors (unit-diagonal lower
   factor below the diagonal, upper factor on and above it) and indx gives,
   for each step i, the row that was exchanged with row i.  b holds the
   right-hand side on entry and is overwritten with the solution.  a and indx
   are only read. */
void LUsolveV(float *a,int *indx,float *b,int n) {
	int row,col,src,first;
	double acc;
	float *arow;

	/* forward substitution, undoing the row exchanges as we go; first is the
	   index of the first non-zero element met so far (-1 while none), which
	   lets the inner loop skip the leading zeros of b */
	first=-1;
	for(row=0;row<n;row++) {
		arow=a+n*row;
		src=indx[row];
		acc=b[src];
		b[src]=b[row];
		if(first>=0) {
			for(col=first;col<row;col++) acc-=arow[col]*b[col];
		}
		else if(acc!=0.0) first=row;
		b[row]=acc;
	}

	/* back substitution through the upper factor */
	for(row=n-1;row>=0;row--) {
		arow=a+n*row;
		acc=b[row];
		for(col=row+1;col<n;col++) acc-=arow[col]*b[col];
		b[row]=acc/arow[row];
	}
}

/* LUsolveM: solves for nb right-hand sides at once using a decomposed matrix.
   a is the n x n row-major matrix holding both factors and indx lists the
   row exchanged with row i at each step i.  b is an n x nb row-major matrix
   whose columns are the right-hand sides; each column is overwritten with
   the corresponding solution.  a and indx are only read. */
void LUsolveM(float *a,int *indx,float *b,int n,int nb) {
	int rhs,row,col,src,first;
	double acc;
	float *arow,*bcol;

	for(rhs=0;rhs<nb;rhs++) {
		bcol=b+rhs;									/* element i of this column is bcol[nb*i] */

		/* forward substitution with the row exchanges undone; first is the
		   index of the first non-zero element seen in this column, or -1 */
		first=-1;
		for(row=0;row<n;row++) {
			arow=a+n*row;
			src=indx[row];
			acc=bcol[nb*src];
			bcol[nb*src]=bcol[nb*row];
			if(first>=0) {
				for(col=first;col<row;col++) acc-=arow[col]*bcol[nb*col];
			}
			else if(acc!=0.0) first=row;
			bcol[nb*row]=acc;
		}

		/* back substitution through the upper factor */
		for(row=n-1;row>=0;row--) {
			arow=a+n*row;
			acc=bcol[nb*row];
			for(col=row+1;col<n;col++) acc-=arow[col]*bcol[nb*col];
			bcol[nb*row]=acc/arow[row];
		}
	}
}
	
/* LUlndetM: returns the sum of log|a[i][i]| over the diagonal of the n x n
   row-major matrix a, i.e. the natural log of the absolute value of the
   product of the diagonal.  If sgn is not NULL, *sgn is set to indx[n]
   multiplied by sign() of every diagonal element.  sign() comes from a
   project header (presumably math2.h) and is assumed to yield +1 or -1 --
   confirm there.  a and indx are only read. */
float LUlndetM(float *a,int *indx,int n,int *sgn) {
	int k,parity;
	double lnabs;
	float diag;

	lnabs=0;
	parity=indx[n];									/* start from the stored swap parity */
	for(k=0;k<n;k++) {
		diag=a[n*k+k];
		lnabs+=log(fabs(diag));
		parity*=sign(diag);
	}
	if(sgn) *sgn=parity;
	return lnabs;
}

/* LUimprV: one step of iterative improvement of a solution vector x of the
   system a.x=b, where a is the original n x n matrix and alu, indx are its
   decomposition.  x is updated in place.  The return value is
   sqrt(corr.corr/x.x), with corr the correction that was applied and x the
   updated solution (1 is used in place of x.x if that is zero).  Returns 0
   if the work vector cannot be allocated, in which case x is untouched.
   NOTE(review): the comments below assume dotMV(a,x,z,n,n) forms z=a.x and
   sumV(p,u,q,v,w,n) forms w=p*u+q*v -- confirm against Rn.h. */
float LUimprV(float *a,float *alu,int *indx,float *x,float *b,int n) {
	float *resid;
	double xnorm2,relerr;

	resid=allocV(n);
	if(!resid) return 0;

	dotMV(a,x,resid,n,n);							/* resid = a.x */
	sumV(1,resid,-1,b,resid,n);						/* resid = a.x - b */
	LUsolveV(alu,indx,resid,n);						/* resid becomes the correction */
	sumV(1,x,-1,resid,x,n);							/* x = x - correction */

	xnorm2=dotVV(x,x,n);
	if(xnorm2==0.0) xnorm2=1.0;						/* avoid dividing by zero */
	relerr=sqrt(dotVV(resid,resid,n)/xnorm2);

	freeV(resid);
	return relerr;
}

/* LUimprM: one step of iterative improvement of a solution matrix x of the
   system a.x=b with nb right-hand sides, where a is the original n x n
   matrix, alu and indx are its decomposition, and x and b are n x nb.  x is
   updated in place.  The return value is sqrt(corr.corr/x.x), with both dot
   products taken over all n*nb elements, corr being the correction that was
   applied and x the updated solution (1 is used in place of x.x if that is
   zero).  Returns 0 if the work matrix cannot be allocated, in which case x
   is untouched.
   NOTE(review): the comments below assume dotMM(a,x,z,n,n,nb) forms z=a.x
   and sumM(p,u,q,v,w,n,nb) forms w=p*u+q*v -- confirm against Rn.h. */
float LUimprM(float *a,float *alu,int *indx,float *x,float *b,int n,int nb) {
	float *resid;
	double xnorm2,relerr;

	resid=allocM(n,nb);
	if(!resid) return 0;

	dotMM(a,x,resid,n,n,nb);						/* resid = a.x */
	sumM(1,resid,-1,b,resid,n,nb);					/* resid = a.x - b */
	LUsolveM(alu,indx,resid,n,nb);					/* resid becomes the correction */
	sumM(1,x,-1,resid,x,n,nb);						/* x = x - correction */

	xnorm2=dotVV(x,x,n*nb);
	if(xnorm2==0.0) xnorm2=1.0;						/* avoid dividing by zero */
	relerr=sqrt(dotVV(resid,resid,n*nb)/xnorm2);

	freeM(resid);
	return relerr;
}

