// http://www.cppblog.com/coreBugZJ/archive/2011/03/28/142886.html
#include <iostream>
#include <cstdio>
using namespace std;
// Row/column capacity of every matrix buffer. Matrices are addressed
// 1-based (indices 1..n), so the largest dimension that fits is L - 1.
#define L 1088
// Not used by live code in this file; it only appears in a commented-out
// debug check inside mul().
#define LIM 999
// Leading dimension of the work matrices declared in matrix_multiply().
#define lda L
// One full L x L matrix of doubles.
typedef double Mat[ L ][ L ];
// Backing storage for mul()'s scratch matrices (64 full matrices).
// NOTE(review): this is roughly 600 MB of zero-initialized static storage.
Mat BUF[64];
// Scratch-slot table handed out by mul() via the `top` index. Each entry
// points at the first row of an L-column region inside BUF; init() fills it.
double (* (buf[1024]))[L];
// Builds the scratch-slot table `buf` that mul() draws from.
//
// Slots 0..31 each own a whole matrix (BUF[0..31]).
// Slots 32..1023 come in groups of 32 (group g = i / 32, g >= 1). Every group
// reuses BUF[32..63], but group g starts (512 >> g) rows into its backing
// matrix. Deeper recursion levels in mul() use smaller sub-matrices, which is
// presumably why the groups are staggered like this: their slices should fall
// on different rows of the shared matrices.
// NOTE(review): this line was reconstructed from code that did not compile
// (missing '+' and cast). The offset is read as (512 >> g) whole rows, which
// matches the original's (512*L >> g) doubles for every group reachable with
// n < L. Confirm the non-overlap assumption before raising L.
//
// Returns 0 so that it can seed the global `top` during static
// initialization of this translation unit.
int init() {
    for (int i = 0; i < 32; ++i) {
        buf[i] = BUF[i];
    }
    for (int i = 32; i < 1024; ++i) {
        const int group = i / 32;
        // Arithmetic on a double (*)[L] advances by whole rows of L doubles.
        buf[i] = BUF[i % 32 + 32] + (512 >> group);
    }
    return 0;
}
int top = init();
// Copies an n x n row-major matrix from the flat buffer `s` into `a`, using
// the 1-based layout the Strassen routines expect: s[(i-1)*n + (j-1)] is
// stored in a[i][j] for 1 <= i, j <= n.
//
// a: destination; must have at least n + 1 rows and columns.
// n: matrix dimension.
// s: source buffer of n * n doubles; only read.
void input( double a[][L], int n, const double s[] ) {
    for ( int i = 1; i <= n; ++i ) {
        const double *row = s + ( i - 1 ) * n;  // 0-based row i-1 of s
        for ( int j = 1; j <= n; ++j ) {
            a[ i ][ j ] = row[ j - 1 ];
        }
    }
}
// Copies the n x n matrix held 1-based in `c` (indices 1..n) into the flat
// row-major buffer `s`: c[i][j] is written to s[(i-1)*n + (j-1)].
//
// c: source matrix in 1-based layout.
// n: matrix dimension.
// s: destination buffer of at least n * n doubles.
void output( double c[][L], int n , double s[] ) {
    for ( int i = 1; i <= n; ++i ) {
        double *row = s + ( i - 1 ) * n;  // 0-based row i-1 of s
        // All n columns are copied, including column n.
        for ( int j = 1; j <= n; ++j ) {
            row[ j - 1 ] = c[ i ][ j ];
        }
    }
}
// Splits the 2n x 2n matrix `a` (1-based) into four n x n quadrants, each
// written 1-based (indices 1..n):
//   a11 = top-left, a12 = top-right, a21 = bottom-left, a22 = bottom-right.
void get( double a[][L], double a11[][L], double a12[][L], double a21[][L], double a22[][L], int n ) {
    for ( int row = 1; row <= n; ++row ) {
        const int low = row + n;          // same row, bottom half of a
        for ( int col = 1; col <= n; ++col ) {
            const int right = col + n;    // same column, right half of a
            a11[ row ][ col ] = a[ row ][ col ];
            a12[ row ][ col ] = a[ row ][ right ];
            a21[ row ][ col ] = a[ low ][ col ];
            a22[ row ][ col ] = a[ low ][ right ];
        }
    }
}
// Assembles the 2n x 2n matrix `a` (1-based) from four n x n quadrants
// (1-based): a11 goes top-left, a12 top-right, a21 bottom-left and
// a22 bottom-right. This is the reverse of the split done by get().
void put( double a[][L], double a11[][L], double a12[][L], double a21[][L], double a22[][L], int n ) {
    for ( int row = 1; row <= n; ++row ) {
        double *upper = a[ row ];         // row in the top half of a
        double *lower = a[ row + n ];     // matching row in the bottom half
        for ( int col = 1; col <= n; ++col ) {
            upper[ col ] = a11[ row ][ col ];
            upper[ col + n ] = a12[ row ][ col ];
            lower[ col ] = a21[ row ][ col ];
            lower[ col + n ] = a22[ row ][ col ];
        }
    }
}
// Element-wise sum over the 1-based n x n region:
// c[i][j] = a[i][j] + b[i][j] for 1 <= i, j <= n.
void add( double c[][L], double a[][L], double b[][L], int n ) {
    for ( int i = 1; i <= n; ++i ) {
        const double *lhs = a[ i ];
        const double *rhs = b[ i ];
        double *dst = c[ i ];
        for ( int j = 1; j <= n; ++j ) {
            dst[ j ] = lhs[ j ] + rhs[ j ];
        }
    }
}
// Element-wise difference over the 1-based n x n region:
// c[i][j] = a[i][j] - b[i][j] for 1 <= i, j <= n.
// The dimension is an int, matching add(). A double here forced an
// int -> double conversion on every loop comparison.
void sub( double c[][L], double a[][L], double b[][L], int n ) {
    for ( int i = 1; i <= n; ++i ) {
        for ( int j = 1; j <= n; ++j ) {
            c[ i ][ j ] = a[ i ][ j ] - b[ i ][ j ];
        }
    }
}
// Computes c = a * b for n x n matrices stored 1-based (indices 1..n in both
// dimensions) using Strassen's algorithm.
//
// * n < 1: nothing to do.
// * n == 1: a single scalar product.
// * odd n > 1: "peels" the last row and column, since halving an odd n would
//   drop them. The leading (n-1) x (n-1) block is multiplied recursively, and
//   the rank-1 term a[i][n] * b[n][j] is added to it. Row n and column n of c
//   are then computed as plain dot products.
// * even n: splits a and b into (n/2) x (n/2) quadrants, forms the seven
//   Strassen products d1..d7 recursively, and assembles c's quadrants.
//
// The even case takes 21 scratch matrices from the global slot table `buf`
// by bumping `top`, and releases them before returning. Because of this
// shared state, mul() is neither reentrant nor thread-safe.
// c must not alias a or b: c is written while a and b are still being read.
void mul( double c[][L], double a[][L], double b[][L], int n ) {
    if ( n < 1 ) {
        return;
    }
    if ( n == 1 ) {
        c[ 1 ][ 1 ] = a[ 1 ][ 1 ] * b[ 1 ][ 1 ];
        return;
    }
    if ( n & 1 ) {
        const int lead = n - 1;  // even (n >= 3), so this recursion takes the Strassen path
        mul( c, a, b, lead );
        // Leading block: add the contribution of column n of a and row n of b.
        for ( int i = 1; i <= lead; ++i ) {
            for ( int j = 1; j <= lead; ++j ) {
                c[ i ][ j ] += a[ i ][ n ] * b[ n ][ j ];
            }
        }
        // Column n, including the corner element c[n][n].
        for ( int i = 1; i <= n; ++i ) {
            double sum = 0.0;
            for ( int k = 1; k <= n; ++k ) {
                sum += a[ i ][ k ] * b[ k ][ n ];
            }
            c[ i ][ n ] = sum;
        }
        // Row n, columns 1..n-1.
        for ( int j = 1; j <= lead; ++j ) {
            double sum = 0.0;
            for ( int k = 1; k <= n; ++k ) {
                sum += a[ n ][ k ] * b[ k ][ j ];
            }
            c[ n ][ j ] = sum;
        }
        return;
    }

    // Each slot is a pointer to the first row of an L-column region. That is
    // exactly the type the helpers receive as `double x[][L]`.
#define ADD(m) double (*m)[L] = buf[ top++ ]
#define ADDS(a) ADD(a##11); ADD(a##12); ADD(a##21); ADD(a##22)
#define ENTER ADDS(a); ADDS(b); ADDS(c); ADD(d1); ADD(d2); ADD(d3); ADD(d4); ADD(d5); ADD(d6); ADD(d7); ADD(t1); ADD(t2)
#define LEAVE top -= 21
    ENTER;
    /*
    if ( top >= LIM ) {
        // for debug
        fprintf( stderr, "buf overflow!!" );
        LEAVE;
        return;
    } */
    n >>= 1;  // quadrant size; exact because n is even here
    get( a, a11, a12, a21, a22, n );
    get( b, b11, b12, b21, b22, n );
    // d1 = (A11 + A22)(B11 + B22)
    add( t1, a11, a22, n );
    add( t2, b11, b22, n );
    mul( d1, t1, t2, n );
    // d2 = (A21 + A22) B11
    add( t1, a21, a22, n );
    mul( d2, t1, b11, n );
    // d3 = A11 (B12 - B22)
    sub( t2, b12, b22, n );
    mul( d3, a11, t2, n );
    // d4 = A22 (B21 - B11)
    sub( t2, b21, b11, n );
    mul( d4, a22, t2, n );
    // d5 = (A11 + A12) B22
    add( t1, a11, a12, n );
    mul( d5, t1, b22, n );
    // d6 = (A21 - A11)(B11 + B12)
    sub( t1, a21, a11, n );
    add( t2, b11, b12, n );
    mul( d6, t1, t2, n );
    // d7 = (A12 - A22)(B21 + B22)
    sub( t1, a12, a22, n );
    add( t2, b21, b22, n );
    mul( d7, t1, t2, n );
    // C11 = d1 + d4 - d5 + d7
    add( t1, d1, d4, n );
    sub( t2, d5, d7, n );
    sub( c11, t1, t2, n );
    // C12 = d3 + d5, C21 = d2 + d4
    add( c12, d3, d5, n );
    add( c21, d2, d4, n );
    // C22 = d1 + d3 - d2 + d6
    add( t1, d1, d3, n );
    sub( t2, d2, d6, n );
    sub( c22, t1, t2, n );
    put( c, c11, c12, c21, c22, n );
    LEAVE;
#undef LEAVE
#undef ENTER
#undef ADDS
#undef ADD
}
// Public entry point: computes _C = _A * _B for n x n row-major matrices
// stored as flat arrays of n * n doubles.
//
// The operands are copied into 1-based lda x lda work matrices and multiplied
// with mul(). The product is then copied back into _C.
// n < 1 is a no-op. n >= lda does not fit the fixed buffers; it is reported
// on stderr and _C is left untouched.
// Not reentrant: the work matrices are static, and mul() uses global scratch
// state.
void matrix_multiply(int n, const double* _A, const double* _B, double* _C) {
    // Static storage: three lda x lda double matrices (~28 MB) would overflow
    // a typical thread stack as locals. Each one gets the alignment attribute
    // (previously only C had it).
    static double A[ lda ][ lda ] __attribute__((aligned(4096)));
    static double B[ lda ][ lda ] __attribute__((aligned(4096)));
    static double C[ lda ][ lda ] __attribute__((aligned(4096)));
    if ( n < 1 ) {
        return;
    }
    if ( n >= lda ) {
        fprintf( stderr, "matrix_multiply: n = %d exceeds the supported maximum %d\n", n, lda - 1 );
        return;
    }
    // input() only reads its source buffer. The cast keeps this call valid
    // whether input() takes a const or a non-const pointer.
    input( A, n, const_cast<double*>( _A ) );
    input( B, n, const_cast<double*>( _B ) );
    mul( C, A, B, n );
    output( C, n, _C );
}
// | Compilation | N/A | N/A | Compile Error | Score: N/A | Show more |