#include <bla.hpp>

namespace ngbla
{
  using namespace ngbla;
  

  // Magnitude helpers used by the pivot search in CalcInverse.
  //
  // The scalar overloads are declared ahead of the Mat template on purpose:
  // the template calls abs() on its entries, and for a fundamental type such
  // as double argument-dependent lookup contributes no candidates, so only
  // overloads already visible at the template's definition can be selected.

  /// Magnitude of a real scalar (forwards to fabs).
  inline double abs (double a)
  {
    return fabs(a);
  }

  /// Magnitude of a Complex scalar (forwards to std::abs).
  /// NOTE(review): Complex is presumably std::complex<double> -- confirm in bla.hpp.
  inline double abs (Complex a)
  {
    return std::abs(a);
  }

  /**
     Size measure of a matrix-valued entry: the sum of abs() over its
     diagonal entries.  Off-diagonal entries do not contribute, so this is
     a cheap pivot heuristic, not a matrix norm.

     Only min(N, N2) diagonal entries exist, so the loop is bounded by that
     value; for the square blocks used in this file this equals N.

     The argument is taken by non-const reference (unchanged interface), so
     it has to be an lvalue.
  */
  template <int N, int N2, typename SCAL>
  inline double abs (Mat<N,N2,SCAL> & m)
  {
    const int ndiag = (N < N2) ? N : N2;

    double sum = 0;
    for (int i = 0; i < ndiag; i++)
      sum += abs(m(i,i));
    return sum;
  }


  /**
     Inverts the square matrix m by in-place Gauss-Jordan elimination with
     row pivoting and stores the result in inv
     (algorithm: Stoer, Einfuehrung in die Numerische Mathematik, p. 145).

     m is first copied into inv; all further work happens on inv, so m is
     only read.  Entries may be scalars or matrix blocks: the pivot entry is
     inverted with the element-wise CalcInverse overload, and candidate
     pivots are compared via abs().

     @param m    matrix to invert, n x n
     @param inv  target for the inverse, n x n
                 (NOTE(review): FlatMatrix is passed by value but written
                 through -- presumably a non-owning view; confirm in bla.hpp)

     @throws Exception if m is not square or inv does not have the same
                       dimensions as m
     @throws Exception if, in some elimination step, no pivot candidate has
                       abs() of at least 1e-20 (matrix treated as singular)
  */
  template <class T, class T2>
  void CalcInverse (const FlatMatrix<T> m, FlatMatrix<T2> inv)
  {
    const int n = m.Height();

    // The elimination below addresses inv both as inv(i,k) and through the
    // linear index inv(n*i+k).  The two only refer to the same entry for an
    // n x n target, so anything else is rejected instead of silently
    // reading and writing out of range.
    if (m.Width() != n || inv.Height() != n || inv.Width() != n)
      throw Exception ("Inverse matrix: Matrix not square or dimensions do not match");

    // All arithmetic is done on entries of inv, hence the temporaries use
    // inv's element type T2.
    T2 hr;                    // inverse of the current pivot entry
    ngstd::ARRAY<int> p(n);   // pivot permutation (row swaps performed so far)
    Vector<T2> hv(n);         // scratch row for the final column permutation

    inv = m;

    for (int j = 0; j < n; j++)
      p[j] = j;

    for (int j = 0; j < n; j++)
      {
        // pivot search: largest entry (by abs) in column j, rows j..n-1

        int r = j;
        double maxval = abs (inv(j,j));

        for (int i = j+1; i < n; i++)
          {
            const double val = abs (inv(i,j));
            if (val > maxval)
              {
                r = i;
                maxval = val;
              }
          }

        if (maxval < 1e-20)
          {
            throw Exception ("Inverse matrix: Matrix singular");
          }

        // exchange rows j and r, and remember the swap
        if (r > j)
          {
            for (int k = 0; k < n; k++)
              swap (inv(j,k), inv(r,k));
            swap (p[j], p[r]);
          }


        // transformation

        // Column j: multiply every entry from the right by the inverted
        // pivot; the pivot position itself then receives the inverted pivot.
        CalcInverse (inv(j,j), hr);
        for (int i = 0; i < n; i++)
          {
            T2 h = inv(i,j) * hr;
            inv(i,j) = h;
          }
        inv(j,j) = hr;


        for (int k = 0; k < n; k++)
          if (k != j)
            {
              // Row-j entry of column k, read before it is overwritten below.
              T2 help = inv(n*j+k);

              // Eliminate in all rows except the pivot row; the range is
              // split at j so the inner loops need no i != j test.
              for (int i = 0; i < j; i++)
                {
                  T2 h = inv(n*i+j) * help;
                  inv(n*i+k) -= h;
                }
              for (int i = j+1; i < n; i++)
                {
                  T2 h = inv(n*i+j) * help;
                  inv(n*i+k) -= h;
                }

              // Pivot row: multiply from the left by the inverted pivot
              // and negate.
              T2 h = hr * inv(j,k);
              inv(j,k) = -h;
            }
      }

    // col exchange: undo the row swaps by permuting the columns of each row

    for (int i = 0; i < n; i++)
      {
        for (int k = 0; k < n; k++)
          hv(p[k]) = inv(i, k);
        for (int k = 0; k < n; k++)
          inv(i, k) = hv(k);
      }
  }

  // Explicit instantiations for real-valued entries: plain double scalars,
  // followed by square Mat blocks from 1x1 up to 8x8.  Input and output
  // always share the same entry type.
  template void CalcInverse (const FlatMatrix<double>,
                             FlatMatrix<double>);

  template void CalcInverse (const FlatMatrix<Mat<1,1,double> >,
                             FlatMatrix<Mat<1,1,double> >);
  template void CalcInverse (const FlatMatrix<Mat<2,2,double> >,
                             FlatMatrix<Mat<2,2,double> >);
  template void CalcInverse (const FlatMatrix<Mat<3,3,double> >,
                             FlatMatrix<Mat<3,3,double> >);
  template void CalcInverse (const FlatMatrix<Mat<4,4,double> >,
                             FlatMatrix<Mat<4,4,double> >);
  template void CalcInverse (const FlatMatrix<Mat<5,5,double> >,
                             FlatMatrix<Mat<5,5,double> >);
  template void CalcInverse (const FlatMatrix<Mat<6,6,double> >,
                             FlatMatrix<Mat<6,6,double> >);
  template void CalcInverse (const FlatMatrix<Mat<7,7,double> >,
                             FlatMatrix<Mat<7,7,double> >);
  template void CalcInverse (const FlatMatrix<Mat<8,8,double> >,
                             FlatMatrix<Mat<8,8,double> >);




  // Explicit instantiations for complex-valued entries: plain Complex
  // scalars, followed by square Mat blocks from 1x1 up to 8x8.  Input and
  // output always share the same entry type.
  template void CalcInverse (const FlatMatrix<Complex>,
                             FlatMatrix<Complex>);

  template void CalcInverse (const FlatMatrix<Mat<1,1,Complex> >,
                             FlatMatrix<Mat<1,1,Complex> >);
  template void CalcInverse (const FlatMatrix<Mat<2,2,Complex> >,
                             FlatMatrix<Mat<2,2,Complex> >);
  template void CalcInverse (const FlatMatrix<Mat<3,3,Complex> >,
                             FlatMatrix<Mat<3,3,Complex> >);
  template void CalcInverse (const FlatMatrix<Mat<4,4,Complex> >,
                             FlatMatrix<Mat<4,4,Complex> >);
  template void CalcInverse (const FlatMatrix<Mat<5,5,Complex> >,
                             FlatMatrix<Mat<5,5,Complex> >);
  template void CalcInverse (const FlatMatrix<Mat<6,6,Complex> >,
                             FlatMatrix<Mat<6,6,Complex> >);
  template void CalcInverse (const FlatMatrix<Mat<7,7,Complex> >,
                             FlatMatrix<Mat<7,7,Complex> >);
  template void CalcInverse (const FlatMatrix<Mat<8,8,Complex> >,
                             FlatMatrix<Mat<8,8,Complex> >);
}
