#include<bits/stdc++.h>
#define li unsigned long long
#define gc getchar()
#define pc putchar
using namespace std;
// Parameters of the pseudo-random generator implemented by rd():
// multiplier, increment and modulus.
li s1 = 19260817;
li s2 = 23333333;
li s3 = 998244853;
// Current generator state (zero-initialised, as every namespace-scope object is).
li srd;

// Advances the generator and returns the new state.
// The new state is (state * s1 + s2 + rand()) reduced modulo s3, so each call
// also consumes one value from the C library rand() sequence.
inline li rd(){
    const li mixed = srd * s1 + s2 + rand();
    srd = mixed % s3;
    return srd;
}
// Reads the next run of decimal digits from standard input and returns its value.
// Characters in front of the first digit are skipped, and the first character
// after the digits is consumed as well.
// Returns 0 when the end of input is reached before any digit is seen; without
// the EOF test the skip loop would spin forever, because EOF is not a digit.
// Values too large for the return type wrap around (unsigned arithmetic).
// The return type is the one the li macro expands to.
inline unsigned long long read(){
    int c = std::getchar();
    while(c != EOF && (c < '0' || c > '9')) c = std::getchar();
    unsigned long long x = 0;
    while(c >= '0' && c <= '9'){
        x = x * 10 + static_cast<unsigned long long>(c - '0');
        c = std::getchar();
    }
    return x;
}
// Writes x to standard output in decimal, with no leading zeros and no
// trailing separator, and returns x.
// The function is declared with a value return type, so it must return
// something: flowing off the end of it is undefined behaviour. The callers in
// this file ignore the result, so handing back the printed value is harmless.
// Parameter and return type are the type the li macro expands to.
inline unsigned long long print(unsigned long long x){
    // digits10 is the number of decimal digits that always fit; one more can occur.
    char digits[std::numeric_limits<unsigned long long>::digits10 + 1];
    int len = 0;
    unsigned long long rest = x;
    // Collect the digits least-significant first, then emit them in reverse.
    do{
        digits[len++] = static_cast<char>('0' + rest % 10);
        rest /= 10;
    }while(rest != 0);
    while(len > 0) std::putchar(digits[--len]);
    return x;
}
// n: side length of the region that main fills with data and chk() compares.
// m: n rounded up to a power of two in main; used as the allocated matrix size.
int n,m;
struct mtx{
vector<vector<li> > a;
inline mtx(){a.clear();}
inline vector<li>& operator [] (int x){return a[x];}
inline void init(int x = 0){
a.resize(x);
for(int i = 0;i < x;++i){
a[i].resize(x);
for(int j = 0;j < x;++j) a[i][j] = 0;
}
}
inline int getn(){return a.size();}
inline void in(){
int x = a.size();
for(int i = 0;i < x;++i){
for(int j = 0;j < x;++j) a[i][j] = rd() * rd();
}
}
inline void out(){
int x = a.size();
for(int i = 0;i < x;++i){
for(int j = 0;j < x;++j) print(a[i][j]),pc(' ');
pc('\n');
}
}
inline void in(int x){
for(int i = 0;i < x;++i){
for(int j = 0;j < x;++j) a[i][j] = rd() * rd();
}
}
inline void out(int x){
for(int i = 0;i < x;++i){
for(int j = 0;j < x;++j) print(a[i][j]),pc(' ');
pc('\n');
}
}
}a,b;
// Element-wise sum of two square matrices of the same size; the result has
// the size of x. Element arithmetic is unsigned and therefore wraps.
// Both operands are taken by const reference: by-value parameters deep-copy
// each matrix on every call, and the Strassen product invokes this operator
// repeatedly at every recursion level. Elements are read through the public
// member a, which needs no const-qualified accessor.
inline mtx operator + (const mtx& x,const mtx& y){
    const int n = static_cast<int>(x.a.size());
    mtx as;as.init(n);
    for(int i = 0;i < n;++i){
        for(int j = 0;j < n;++j){
            as[i][j] = x.a[i][j] + y.a[i][j];
        }
    }
    return as;
}
// Element-wise difference x - y of two square matrices of the same size; the
// result has the size of x. Element arithmetic is unsigned and therefore wraps.
// Both operands are taken by const reference: by-value parameters deep-copy
// each matrix on every call, and the Strassen product invokes this operator
// repeatedly at every recursion level. Elements are read through the public
// member a, which needs no const-qualified accessor.
inline mtx operator - (const mtx& x,const mtx& y){
    const int n = static_cast<int>(x.a.size());
    mtx as;as.init(n);
    for(int i = 0;i < n;++i){
        for(int j = 0;j < n;++j){
            as[i][j] = x.a[i][j] - y.a[i][j];
        }
    }
    return as;
}
// Schoolbook product of two square matrices of the same size (cubic number of
// multiplications); the result has the size of x and element arithmetic wraps.
// The loops run in i-k-j order so that the innermost loop walks along one row
// of y and one row of the result.
// Both operands are taken by const reference so that no matrix is copied just
// to be read; elements are read through the public member a, which needs no
// const-qualified accessor.
inline mtx tms(const mtx& x,const mtx& y){
    const int n = static_cast<int>(x.a.size());
    mtx as;
    as.init(n);
    for(int i = 0;i < n;++i){
        for(int k = 0;k < n;++k){
            const li p = x.a[i][k]; // reused for the whole row k of y
            for(int j = 0;j < n;++j){
                as[i][j] += p * y.a[k][j];
            }
        }
    }
    return as;
}
// Product of two square matrices of the same size; element arithmetic is
// unsigned and therefore wraps.
// Sizes up to 256 are multiplied with the schoolbook routine tms(). Larger
// even sizes are cut into four quadrants and combined from seven half-size
// products (Strassen's scheme), recursing through this operator.
// Odd sizes also go to tms(): halving an odd n truncates, which would leave
// the last row and column out of the product.
// Operands are taken by const reference and read through the public member a,
// so no matrix is copied merely to be read.
inline mtx operator * (const mtx& x,const mtx& y){
    const int n = static_cast<int>(x.a.size());
    if(n <= 256 || (n & 1)) return tms(x,y);
    const int pn = n >> 1;

    // xq[r][c] / yq[r][c]: quadrant in block-row r, block-column c of x / y.
    mtx xq[2][2],yq[2][2];
    for(int qr = 0;qr < 2;++qr){
        for(int qc = 0;qc < 2;++qc){
            xq[qr][qc].init(pn);
            yq[qr][qc].init(pn);
            for(int i = 0;i < pn;++i){
                for(int j = 0;j < pn;++j){
                    xq[qr][qc][i][j] = x.a[i + qr * pn][j + qc * pn];
                    yq[qr][qc][i][j] = y.a[i + qr * pn][j + qc * pn];
                }
            }
        }
    }

    // Five of the seven products are needed twice and are kept in p; the
    // remaining two appear once each and are formed inline below.
    mtx p[5];
    p[0] = xq[0][0] * (yq[0][1] - yq[1][1]);
    p[1] = (xq[0][0] + xq[0][1]) * yq[1][1];
    p[2] = (xq[1][0] + xq[1][1]) * yq[0][0];
    p[3] = xq[1][1] * (yq[1][0] - yq[0][0]);
    p[4] = (xq[0][0] + xq[1][1]) * (yq[0][0] + yq[1][1]);

    // Quadrants of the result.
    mtx cq[2][2];
    cq[0][0] = p[4] + p[3] - p[1] + (xq[0][1] - xq[1][1]) * (yq[1][0] + yq[1][1]);
    cq[0][1] = p[0] + p[1];
    cq[1][0] = p[2] + p[3];
    cq[1][1] = p[4] + p[0] - p[2] - (xq[0][0] - xq[1][0]) * (yq[0][0] + yq[0][1]);

    // Stitch the four quadrants back into one n-by-n matrix.
    mtx as;as.init(n);
    for(int qr = 0;qr < 2;++qr){
        for(int qc = 0;qc < 2;++qc){
            for(int i = 0;i < pn;++i){
                for(int j = 0;j < pn;++j){
                    as[i + qr * pn][j + qc * pn] = cq[qr][qc][i][j];
                }
            }
        }
    }
    return as;
}
// Products computed in main: as1 by operator*, as2 by tms().
mtx as1;
mtx as2;

// Asserts that as1 and as2 hold the same value at every position of their
// leading n-by-n block (n is the global set in main).
void chk(){
    for(int i = 0;i < n;++i){
        for(int j = 0;j < n;++j){
            assert(as1[i][j] == as2[i][j]);
        }
    }
}
// Benchmark driver: fills two matrices with pseudo-random values and
// multiplies them with operator*.
int main(){
    // When true the program exits right after the first product, which is how
    // the original behaved: an unconditional return sat at that point and left
    // the timing and cross-check code below it unreachable. Set to false to run
    // that code for all 100 rounds.
    const bool single_run_only = true;

    srand(time(0));rd(); // seed rand() and advance the generator once
    for(int qwq = 1;qwq <= 100;++qwq){
        n = 1024;
        //n = rd() % 256 + 1;
        for(m = 1;m < n;m <<= 1); // m = smallest power of two that is >= n
        a.init(m);b.init(m);
        a.in(n);b.in(n); // only the leading n-by-n block gets data; the rest stays zero
        as1 = a * b;
        if(single_run_only) return 0;

        // clock() returns clock_t; keep that type instead of narrowing to int.
        clock_t d = clock();
        as1 = a * b;
        cerr<<"tms1 "<<clock() - d<<endl;
        d = clock();
        as2 = tms(a,b);
        cerr<<"tms2 "<<clock() - d<<endl;
        chk();
    }
    return 0;
}
// Online-judge result for this submission (pasted from the results page):
// Compilation | N/A | N/A | Compile OK | Score: N/A | Show more |
// Testcase #1 | 615.448 ms | 103 MB + 1020 KB | Accepted | Score: 100 | Show more |