Friday, April 22, 2005

LDLTDecomposition.java

I have been struggling with remove() for a few weeks. The original C code from ODE is very difficult to understand. Yesterday I finally gave up on writing an efficient version and filled the gap with a naive implementation instead. So the class LDLTDecomposition is finished for now, until I learn a more efficient approach. The function remove() is called by transfer_i_from_C_to_N() in class FastLCP. I compared FastLCP with SlowLCP and found that the test results match precisely. Nice, I finally got it working. I was even happier to find that FastLCP really does run faster than SlowLCP. Now I can move on to finishing step() in class World with some confidence.
package lcp;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Property;

/**
 * LDL' decomposition of a square matrix <tt>A</tt>: <tt>A = L*D*L'</tt>, where
 * <tt>L</tt> is unit lower triangular and <tt>D</tt> is diagonal.
 * <p>
 * Only the strictly lower triangle of <tt>L</tt> is stored; its unit diagonal
 * is implicit (the stored diagonal entries stay zero). <tt>D</tt> is stored as
 * the vector <tt>d</tt>.
 * <p>
 * All solve methods work on the leading <tt>nc</tt> x <tt>nc</tt> block of the
 * factorization only, and they overwrite their argument in place.
 * {@link #remove} deletes one row/column from that leading block and shrinks
 * <tt>nc</tt> by one.
 * <p>
 * No pivoting is performed and zero pivots are not detected: a zero entry in
 * <tt>d</tt> yields infinities or NaNs in the factor and in solve results.
 */
public class LDLTDecomposition implements java.io.Serializable {
    static final long serialVersionUID = 1020;

    /**
     * Shared Colt function objects (used for element-wise division).
     * Static so that this stateless helper is not part of the serialized
     * state of every decomposition instance.
     */
    static final cern.jet.math.Functions F = cern.jet.math.Functions.functions;

    /**
     * Strictly lower triangular part of the factor <tt>L</tt>.
     * @serial internal storage of the triangular factor.
     */
    private final DoubleMatrix2D L;

    /**
     * Diagonal of <tt>D</tt>.
     * @serial internal storage of the diagonal factor.
     */
    private final DoubleMatrix1D d;

    /**
     * Number of clamped DOF: size of the leading block used by the solvers.
     * @serial number of clamped DOF.
     */
    private int nc;

    /**
     * DOF: number of rows of the decomposed matrix.
     * @serial number of rows of the decomposed matrix.
     */
    private final int n;

    /**
     * Symmetry flag of the decomposed matrix.
     * @serial is symmetric flag.
     */
    private final boolean isSymmetric;

    /**
     * Decomposes <tt>A</tt> and uses the whole matrix as the active block
     * (<tt>nc = A.rows()</tt>).
     *
     * @param A Square, symmetric matrix.
     * @throws IllegalArgumentException if <tt>A</tt> is not square.
     */
    public LDLTDecomposition(DoubleMatrix2D A) {
        this(A, A.rows());
    }

    /**
     * Constructs a new LDL' decomposition object for a symmetric matrix.
     * The factors can be retrieved via {@link #getL()} and {@link #getd()}.
     * The full matrix is always decomposed; <tt>nc</tt> only limits the block
     * that the solvers and {@link #remove} operate on.
     * <p>
     * Only the lower triangle of <tt>A</tt> enters the factorization; the
     * upper triangle is read solely to determine the symmetry flag (exact
     * comparison of mirrored entries).
     *
     * @param A  Square, symmetric matrix.
     * @param nc number of leading rows/columns that form the active block.
     * @throws IllegalArgumentException if <tt>A</tt> is not square, or if
     *         <tt>nc</tt> is negative or larger than the number of rows of <tt>A</tt>.
     */
    public LDLTDecomposition(DoubleMatrix2D A, int nc) {
        Property.DEFAULT.checkSquare(A);

        final int size = A.rows();
        if (nc < 0 || nc > size) {
            throw new IllegalArgumentException(
                "nc must be between 0 and the number of rows of A (nc=" + nc + ", rows=" + size + ")");
        }

        this.n = size;
        this.nc = nc;
        this.L = A.like();
        this.d = A.like1D(size);

        // checkSquare() already guarantees rows == columns, so start from "symmetric"
        // and clear the flag as soon as one mirrored pair differs.
        boolean symmetric = true;

        // Row-by-row factorization. The inner sums run from high to low index on
        // purpose: changing the order would change floating-point round-off.
        for (int j = 0; j < size; j++) {
            for (int k = 0; k < j; k++) {
                double sum = A.getQuick(j, k);
                for (int i = k - 1; i >= 0; i--) {
                    sum -= L.getQuick(j, i) * d.getQuick(i) * L.getQuick(k, i);
                }
                L.setQuick(j, k, sum / d.getQuick(k));
                symmetric = symmetric && (A.getQuick(k, j) == A.getQuick(j, k));
            }
            double sum = A.getQuick(j, j);
            for (int i = j - 1; i >= 0; i--) {
                sum -= L.getQuick(j, i) * d.getQuick(i) * L.getQuick(j, i);
            }
            d.setQuick(j, sum);
        }
        this.isSymmetric = symmetric;
    }

    /**
     * Returns the size of the active (clamped) block.
     * @return <tt>nc</tt>
     */
    public int getnc() { return nc; }

    /**
     * Sets the size of the active (clamped) block. The value is not validated;
     * the solvers index rows and columns <tt>0 .. nc-1</tt> of the factors.
     * @param n the new value of <tt>nc</tt>.
     */
    public void setnc(int n) { this.nc = n; }

    /**
     * Returns the triangular factor, <tt>L</tt>. The returned matrix is the
     * internal storage (not a copy) and holds only the strictly lower
     * triangle; the unit diagonal is implicit.
     * @return <tt>L</tt>
     */
    public DoubleMatrix2D getL() { return L; }

    /**
     * Returns the diagonal, <tt>d</tt>. The returned vector is the internal
     * storage (not a copy).
     * @return <tt>d</tt>
     */
    public DoubleMatrix1D getd() { return d; }

    /**
     * Returns whether the matrix <tt>A</tt> passed to the constructor is symmetric.
     * @return true if <tt>A</tt> is symmetric; false otherwise
     */
    public boolean isSymmetric() { return isSymmetric; }

    /**
     * Solves <tt>A*X = B</tt> for the leading <tt>nc</tt> x <tt>nc</tt> block
     * of the factorization, column by column.
     * <p>
     * Operates directly on <tt>B</tt>: the first <tt>nc</tt> rows of every
     * column are overwritten with the solution, any further rows are left
     * untouched, and <tt>B</tt> itself is returned.
     *
     * @param B A matrix with at least <tt>nc</tt> rows and any number of columns.
     * @return <tt>B</tt>, now holding <tt>X</tt> so that <tt>L*D*L'*X = B</tt>.
     * @exception IllegalArgumentException if <tt>B.rows() &lt; nc</tt>.
     */
    public DoubleMatrix2D solve(DoubleMatrix2D B) {
        // Validate before touching B so a bad argument never leaves it half overwritten.
        if (B.rows() < nc) {
            throw new IllegalArgumentException(
                "B must have at least nc rows (nc=" + nc + ", rows=" + B.rows() + ")");
        }
        final int columns = B.columns();
        for (int c = 0; c < columns; c++) {
            // The column view shares cells with B, so the in-place vector solve updates B.
            solve(B.viewColumn(c));
        }
        return B;
    }

    /**
     * Solves <tt>A*x = b</tt> for the leading <tt>nc</tt> x <tt>nc</tt> block
     * of the factorization.
     * <p>
     * Operates directly on <tt>b</tt>: its first <tt>nc</tt> entries are
     * overwritten with the solution and <tt>b</tt> itself is returned.
     *
     * @param b A vector with at least <tt>nc</tt> entries.
     * @return <tt>b</tt>, now holding <tt>x</tt> so that <tt>L*D*L'*x = b</tt>.
     * @exception IllegalArgumentException if <tt>b.size() &lt; nc</tt>.
     */
    public DoubleMatrix1D solve(DoubleMatrix1D b) {
        if (b.size() < nc) {
            throw new IllegalArgumentException(
                "b must have at least nc entries (nc=" + nc + ", size=" + b.size() + ")");
        }
        // Solve L*y = b
        solveL(b, nc);
        // Solve D*z = y
        b.viewPart(0, nc).assign(d.viewPart(0, nc), F.div);
        // Solve L'*x = z
        solveLT(b, nc);
        return b;
    }

    /**
     * Forward substitution with the unit lower triangular factor: solves
     * <tt>L*y = x</tt> on the leading <tt>len</tt> entries, in place.
     *
     * @param x   right-hand side; its first <tt>len</tt> entries are overwritten with <tt>y</tt>.
     * @param len number of leading rows/columns of <tt>L</tt> to use.
     * @return <tt>x</tt>
     */
    public DoubleMatrix1D solveL(DoubleMatrix1D x, int len) {
        for (int i = 0; i < len; i++) {
            double sum = x.getQuick(i);
            for (int k = i - 1; k >= 0; k--) {
                sum -= L.getQuick(i, k) * x.getQuick(k);
            }
            x.setQuick(i, sum);
        }
        return x;
    }

    /**
     * Back substitution with the transposed factor: solves <tt>L'*y = x</tt>
     * on the leading <tt>len</tt> entries, in place.
     *
     * @param x   right-hand side; its first <tt>len</tt> entries are overwritten with <tt>y</tt>.
     * @param len number of leading rows/columns of <tt>L</tt> to use.
     * @return <tt>x</tt>
     */
    public DoubleMatrix1D solveLT(DoubleMatrix1D x, int len) {
        for (int i = len - 1; i >= 0; i--) {
            double sum = x.getQuick(i);
            for (int k = i + 1; k < len; k++) {
                sum -= L.getQuick(k, i) * x.getQuick(k);
            }
            x.setQuick(i, sum);
        }
        return x;
    }

    /**
     * Removes row and column <tt>r</tt> from the active <tt>nc</tt> x
     * <tt>nc</tt> block of the factorization and decrements <tt>nc</tt>.
     * <p>
     * This is a naive update: columns left of <tt>r</tt> are kept, and every
     * entry of <tt>L</tt> and <tt>d</tt> at a row and column index greater
     * than <tt>r</tt> is recomputed from <tt>A</tt> while skipping column
     * <tt>r</tt>. Afterwards the rows below and the columns right of
     * <tt>r</tt> are shifted by one to close the gap.
     *
     * @param A the matrix the factorization refers to; entry <tt>(j, k)</tt>
     *          of the factored block is read from <tt>A(p[j], p[k])</tt>.
     * @param p index map from positions in the factored block to rows/columns
     *          of <tt>A</tt>, in the ordering that is valid before the removal.
     * @param r position (within the active block) of the row/column to remove.
     * @throws LCPException if <tt>r</tt> is negative or not smaller than <tt>nc</tt>;
     *         the factorization is left unchanged in that case.
     */
    public void remove(DoubleMatrix2D A, int[] p, int r) throws LCPException {
        // Reject bad indices before any entry is modified.
        if (r < 0 || r > nc - 1) {
            throw new LCPException("The row to be removed is out of bound!");
        }

        // Rebuild the trailing part of the factor without column r. Entries are still
        // addressed by their old positions here; the shift happens afterwards.
        // When r is the last active row these loops do not execute.
        for (int j = r + 1; j < nc; j++) {
            for (int k = r + 1; k < j; k++) {
                double sum = A.getQuick(p[j], p[k]);
                for (int m = 0; m < k; m++) {
                    if (m != r) {
                        sum -= L.getQuick(j, m) * d.getQuick(m) * L.getQuick(k, m);
                    }
                }
                L.setQuick(j, k, sum / d.getQuick(k));
            }
            double sum = A.getQuick(p[j], p[j]);
            for (int m = 0; m < j; m++) {
                if (m != r) {
                    sum -= L.getQuick(j, m) * d.getQuick(m) * L.getQuick(j, m);
                }
            }
            d.setQuick(j, sum);
        }

        // Close the gap: move rows up, then columns left, then the diagonal entries.
        // Source and target views overlap; this relies on assign() handling views
        // that share cells.
        final int tail = nc - 1 - r;
        L.viewPart(r, 0, tail, nc).assign(L.viewPart(r + 1, 0, tail, nc));
        L.viewPart(0, r, nc, tail).assign(L.viewPart(0, r + 1, nc, tail));
        d.viewPart(r, tail).assign(d.viewPart(r + 1, tail));
        nc--;
    }

    /**
     * Returns a multi-line summary: the symmetry flag, <tt>L</tt>, <tt>d</tt>
     * and the inverse of the active block (obtained by solving against an
     * <tt>nc</tt> x <tt>nc</tt> identity matrix).
     */
    @Override
    public String toString() {
        final String unknown = "Illegal operation or error: ";
        final StringBuilder buf = new StringBuilder();

        buf.append("----------------------------------------------------------\n");
        buf.append("LDLTDecomposition(A) --> isSymmetric(A), L, d, inverse(A)\n");
        buf.append("----------------------------------------------------------\n");

        buf.append("isSymmetric = ");
        buf.append(this.isSymmetric());

        buf.append("\n\nL = ");
        buf.append(String.valueOf(this.getL()));

        buf.append("\n\nd = ");
        buf.append(String.valueOf(this.getd()));

        buf.append("\n\ninverse(A) = ");
        try {
            // solve() overwrites its argument, so a fresh identity matrix is used.
            buf.append(String.valueOf(this.solve(cern.colt.matrix.DoubleFactory2D.dense.identity(nc))));
        } catch (IllegalArgumentException exc) {
            buf.append(unknown + exc.getMessage());
        }

        return buf.toString();
    }
}