Actual source code: baijfact81.c


  2: /*
  3:    Factorization code for BAIJ format.
  4:  */
  5: #include <../src/mat/impls/baij/seq/baij.h>
  6: #include <petsc/private/kernels/blockinvert.h>
  7: #if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
  8: #include <immintrin.h>
  9: #endif
 10: /*
 11:    Version for when blocks are 9 by 9
 12:  */
 13: #if defined(PETSC_HAVE_IMMINTRIN_H) && defined(__AVX2__) && defined(__FMA__) && defined(PETSC_USE_REAL_DOUBLE) && !defined(PETSC_USE_COMPLEX) && !defined(PETSC_USE_64BIT_INDICES)
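     /*
        Numeric phase: for each block row i, the blocks of row i are scattered into the dense
        work row rtmp, eliminated against the previously factored rows listed in L(i,:), and
        gathered back into the factor b with the diagonal block stored already inverted, so the
        triangular solves below only multiply.
      */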
 14: PetscErrorCode MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering(Mat B,Mat A,const MatFactorInfo *info)
 15: {
  16:   Mat            C = B;
 17:   Mat_SeqBAIJ    *a=(Mat_SeqBAIJ*)A->data,*b=(Mat_SeqBAIJ*)C->data;
 18:   PetscInt       i,j,k,nz,nzL,row;
 19:   const PetscInt n=a->mbs,*ai=a->i,*aj=a->j,*bi=b->i,*bj=b->j;
 20:   const PetscInt *ajtmp,*bjtmp,*bdiag=b->diag,*pj,bs2=a->bs2;
 21:   MatScalar      *rtmp,*pc,*mwork,*v,*pv,*aa=a->a;
 22:   PetscInt       flg;
 23:   PetscReal      shift = info->shiftamount;
 24:   PetscBool      allowzeropivot,zeropivotdetected;

 26:   allowzeropivot = PetscNot(A->erroriffailure);

 28:   /* generate work space needed by the factorization */
 29:   PetscMalloc2(bs2*n,&rtmp,bs2,&mwork);
 30:   PetscArrayzero(rtmp,bs2*n);
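       /* rtmp holds one expanded block row (n blocks of bs2 = 81 scalars each); mwork is a
          single 81-scalar scratch block used by the dense block-multiply kernel below */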

 32:   for (i=0; i<n; i++) {
 33:     /* zero rtmp */
 34:     /* L part */
 35:     nz    = bi[i+1] - bi[i];
 36:     bjtmp = bj + bi[i];
 37:     for  (j=0; j<nz; j++) {
 38:       PetscArrayzero(rtmp+bs2*bjtmp[j],bs2);
 39:     }

 41:     /* U part */
 42:     nz    = bdiag[i] - bdiag[i+1];
 43:     bjtmp = bj + bdiag[i+1]+1;
 44:     for  (j=0; j<nz; j++) {
 45:       PetscArrayzero(rtmp+bs2*bjtmp[j],bs2);
 46:     }

  48:     /* load in initial (unfactored) row */
 49:     nz    = ai[i+1] - ai[i];
 50:     ajtmp = aj + ai[i];
 51:     v     = aa + bs2*ai[i];
 52:     for (j=0; j<nz; j++) {
 53:       PetscArraycpy(rtmp+bs2*ajtmp[j],v+bs2*j,bs2);
 54:     }

 56:     /* elimination */
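         /* For each nonzero block of L(i,:): pc points at that block in the work row; if it is
            entirely zero the update is skipped, otherwise it is replaced by pc*inv(D_row) using
            the already-inverted diagonal block of row `row`, and pc times every block of U(row,:)
            is subtracted from the corresponding block of the work row. */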
 57:     bjtmp = bj + bi[i];
 58:     nzL   = bi[i+1] - bi[i];
 59:     for (k=0; k < nzL; k++) {
 60:       row = bjtmp[k];
 61:       pc  = rtmp + bs2*row;
 62:       for (flg=0,j=0; j<bs2; j++) {
 63:         if (pc[j]!=0.0) {
 64:           flg = 1;
 65:           break;
 66:         }
 67:       }
 68:       if (flg) {
 69:         pv = b->a + bs2*bdiag[row];
 70:         /* PetscKernel_A_gets_A_times_B(bs,pc,pv,mwork); *pc = *pc * (*pv); */
 71:         PetscKernel_A_gets_A_times_B_9(pc,pv,mwork);

 73:         pj = b->j + bdiag[row+1]+1; /* beginning of U(row,:) */
 74:         pv = b->a + bs2*(bdiag[row+1]+1);
  75:         nz = bdiag[row] - bdiag[row+1] - 1; /* num of entries in U(row,:), excluding diag */
 76:         for (j=0; j<nz; j++) {
 77:           /* PetscKernel_A_gets_A_minus_B_times_C(bs,rtmp+bs2*pj[j],pc,pv+bs2*j); */
 78:           /* rtmp+bs2*pj[j] = rtmp+bs2*pj[j] - (*pc)*(pv+bs2*j) */
 79:           v    = rtmp + bs2*pj[j];
 80:           PetscKernel_A_gets_A_minus_B_times_C_9(v,pc,pv+81*j);
  81:           /* pv is not advanced by the kernel; block j of U(row,:) is addressed explicitly as pv+81*j (81 = bs2) */
 82:         }
  83:         PetscLogFlops(1458*nz+1377); /* flops = 2*bs^3*nz + 2*bs^3 - bs^2 */
 84:       }
 85:     }

 87:     /* finished row so stick it into b->a */
 88:     /* L part */
 89:     pv = b->a + bs2*bi[i];
 90:     pj = b->j + bi[i];
 91:     nz = bi[i+1] - bi[i];
 92:     for (j=0; j<nz; j++) {
 93:       PetscArraycpy(pv+bs2*j,rtmp+bs2*pj[j],bs2);
 94:     }

 96:     /* Mark diagonal and invert diagonal for simpler triangular solves */
 97:     pv   = b->a + bs2*bdiag[i];
 98:     pj   = b->j + bdiag[i];
 99:     PetscArraycpy(pv,rtmp+bs2*pj[0],bs2);
100:     PetscKernel_A_gets_inverse_A_9(pv,shift,allowzeropivot,&zeropivotdetected);
101:     if (zeropivotdetected) C->factorerrortype = MAT_FACTOR_NUMERIC_ZEROPIVOT;
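         /* pv now holds the inverse of the diagonal block, which is what the triangular solves
            below multiply by; a detected zero pivot is recorded in C->factorerrortype. */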

103:     /* U part */
104:     pv = b->a + bs2*(bdiag[i+1]+1);
105:     pj = b->j + bdiag[i+1]+1;
106:     nz = bdiag[i] - bdiag[i+1] - 1;
107:     for (j=0; j<nz; j++) {
108:       PetscArraycpy(pv+bs2*j,rtmp+bs2*pj[j],bs2);
109:     }
110:   }
111:   PetscFree2(rtmp,mwork);

113:   C->ops->solve          = MatSolve_SeqBAIJ_9_NaturalOrdering;
114:   C->ops->solvetranspose = MatSolveTranspose_SeqBAIJ_N;
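       /* triangular solves use the AVX2 kernel defined below; transpose solves fall back to the
          generic (runtime block size) BAIJ routine */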
115:   C->assembled           = PETSC_TRUE;

117:   PetscLogFlops(1.333333333333*9*9*9*n); /* from inverting diagonal blocks */
118:   return 0;
119: }
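     /*
        Not part of this file: a minimal sketch of how these kernels are typically reached,
        assuming a sequential BAIJ matrix with block size 9 factored by PETSc's built-in LU
        with the natural ordering (m, nnz, b, x are placeholders; assembly and error checking
        omitted):

          Mat           A,F;
          Vec           b,x;
          IS            rperm,cperm;
          MatFactorInfo info;

          MatCreateSeqBAIJ(PETSC_COMM_SELF,9,m,m,0,nnz,&A);
          ...set values, assemble A, create b and x...
          MatGetFactor(A,MATSOLVERPETSC,MAT_FACTOR_LU,&F);
          MatGetOrdering(A,MATORDERINGNATURAL,&rperm,&cperm);
          MatFactorInfoInitialize(&info);
          MatLUFactorSymbolic(F,A,rperm,cperm,&info);
          MatLUFactorNumeric(F,A,&info);    reaches MatLUFactorNumeric_SeqBAIJ_9_NaturalOrdering when AVX2/FMA are enabled
          MatSolve(F,b,x);                  reaches MatSolve_SeqBAIJ_9_NaturalOrdering
      */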

121: PetscErrorCode MatSolve_SeqBAIJ_9_NaturalOrdering(Mat A,Vec bb,Vec xx)
122: {
123:   Mat_SeqBAIJ    *a=(Mat_SeqBAIJ*)A->data;
124:   const PetscInt *ai=a->i,*aj=a->j,*adiag=a->diag,*vi;
125:   PetscInt       i,k,n=a->mbs;
126:   PetscInt       nz,bs=A->rmap->bs,bs2=a->bs2;
127:   const MatScalar   *aa=a->a,*v;
128:   PetscScalar       *x,*s,*t,*ls;
129:   const PetscScalar *b;
130:   __m256d a0,a1,a2,a3,a4,a5,w0,w1,w2,w3,s0,s1,s2,v0,v1,v2,v3;

132:   VecGetArrayRead(bb,&b);
133:   VecGetArray(xx,&x);
134:   t    = a->solve_work;

136:   /* forward solve the lower triangular */
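       /* Forward solve t = L^{-1} b, one 9-entry block at a time: block i starts as b_i and every
          nonzero block L(i,k), k < i, contributes -L(i,k)*t_k via fused multiply-add.  Each block
          lives in two full 256-bit registers plus a third whose first lane holds the 9th entry. */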
137:   PetscArraycpy(t,b,bs); /* copy 1st block of b to t */

139:   for (i=1; i<n; i++) {
140:     v    = aa + bs2*ai[i];
141:     vi   = aj + ai[i];
142:     nz   = ai[i+1] - ai[i];
143:     s    = t + bs*i;
144:     PetscArraycpy(s,b+bs*i,bs); /* copy i_th block of b to t */

147:     s0 = _mm256_loadu_pd(s+0);
148:     s1 = _mm256_loadu_pd(s+4);
149:     s2 = _mm256_maskload_pd(s+8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
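         /* the mask selects only the first 64-bit lane, so only s[8] is read; the upper lanes of
            s2 stay zero and are never stored back */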

151:     for (k=0;k<nz;k++) {
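           /* Subtract the 9x9 block at v times the block of t selected by vi[k]: each group of nine
              consecutive values of v is scaled by one entry of that block and accumulated with fnmadd.
              The loads at offsets 8, 17, 26, ... spill three values into the next group, but those only
              land in the unused upper lanes of s2; the last group (v[72..80]) uses a masked load so
              nothing past the 81-entry block is read. */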

153:       w0 = _mm256_set1_pd((t+bs*vi[k])[0]);
154:       a0 = _mm256_loadu_pd(&v[ 0]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
155:       a1 = _mm256_loadu_pd(&v[ 4]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
156:       a2 = _mm256_loadu_pd(&v[ 8]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

158:       w1 = _mm256_set1_pd((t+bs*vi[k])[1]);
159:       a3 = _mm256_loadu_pd(&v[ 9]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
160:       a4 = _mm256_loadu_pd(&v[13]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
161:       a5 = _mm256_loadu_pd(&v[17]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

163:       w2 = _mm256_set1_pd((t+bs*vi[k])[2]);
164:       a0 = _mm256_loadu_pd(&v[18]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
165:       a1 = _mm256_loadu_pd(&v[22]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
166:       a2 = _mm256_loadu_pd(&v[26]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

168:       w3 = _mm256_set1_pd((t+bs*vi[k])[3]);
169:       a3 = _mm256_loadu_pd(&v[27]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
170:       a4 = _mm256_loadu_pd(&v[31]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
171:       a5 = _mm256_loadu_pd(&v[35]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

173:       w0 = _mm256_set1_pd((t+bs*vi[k])[4]);
174:       a0 = _mm256_loadu_pd(&v[36]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
175:       a1 = _mm256_loadu_pd(&v[40]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
176:       a2 = _mm256_loadu_pd(&v[44]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

178:       w1 = _mm256_set1_pd((t+bs*vi[k])[5]);
179:       a3 = _mm256_loadu_pd(&v[45]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
180:       a4 = _mm256_loadu_pd(&v[49]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
181:       a5 = _mm256_loadu_pd(&v[53]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

183:       w2 = _mm256_set1_pd((t+bs*vi[k])[6]);
184:       a0 = _mm256_loadu_pd(&v[54]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
185:       a1 = _mm256_loadu_pd(&v[58]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
186:       a2 = _mm256_loadu_pd(&v[62]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

188:       w3 = _mm256_set1_pd((t+bs*vi[k])[7]);
189:       a3 = _mm256_loadu_pd(&v[63]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
190:       a4 = _mm256_loadu_pd(&v[67]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
191:       a5 = _mm256_loadu_pd(&v[71]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

193:       w0 = _mm256_set1_pd((t+bs*vi[k])[8]);
194:       a0 = _mm256_loadu_pd(&v[72]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
195:       a1 = _mm256_loadu_pd(&v[76]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
196:       a2 = _mm256_maskload_pd(v+80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
197:       s2 = _mm256_fnmadd_pd(a2,w0,s2);
198:       v += bs2;
199:     }
200:     _mm256_storeu_pd(&s[0], s0);
201:     _mm256_storeu_pd(&s[4], s1);
202:     _mm256_maskstore_pd(&s[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), s2);
203:   }

205:   /* backward solve the upper triangular */
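       /* Backward solve: for block row i, ls starts as t_i, the loop subtracts U(i,k)*t_k for every
          block to the right of the diagonal, and the FMA section after the loop multiplies the result
          by the inverted diagonal block stored by the factorization.  The finished block is written
          back into t (for use by the rows processed next) and copied into x. */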
206:   ls = a->solve_work + A->cmap->n;
207:   for (i=n-1; i>=0; i--) {
208:     v    = aa + bs2*(adiag[i+1]+1);
209:     vi   = aj + adiag[i+1]+1;
210:     nz   = adiag[i] - adiag[i+1]-1;
211:     PetscArraycpy(ls,t+i*bs,bs);

213:     s0 = _mm256_loadu_pd(ls+0);
214:     s1 = _mm256_loadu_pd(ls+4);
215:     s2 = _mm256_maskload_pd(ls+8, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));

217:     for (k=0; k<nz; k++) {

219:       w0 = _mm256_set1_pd((t+bs*vi[k])[0]);
220:       a0 = _mm256_loadu_pd(&v[ 0]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
221:       a1 = _mm256_loadu_pd(&v[ 4]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
222:       a2 = _mm256_loadu_pd(&v[ 8]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

224:       /* v += 9; */
225:       w1 = _mm256_set1_pd((t+bs*vi[k])[1]);
226:       a3 = _mm256_loadu_pd(&v[ 9]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
227:       a4 = _mm256_loadu_pd(&v[13]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
228:       a5 = _mm256_loadu_pd(&v[17]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

230:       /* v += 9; */
231:       w2 = _mm256_set1_pd((t+bs*vi[k])[2]);
232:       a0 = _mm256_loadu_pd(&v[18]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
233:       a1 = _mm256_loadu_pd(&v[22]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
234:       a2 = _mm256_loadu_pd(&v[26]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

236:       /* v += 9; */
237:       w3 = _mm256_set1_pd((t+bs*vi[k])[3]);
238:       a3 = _mm256_loadu_pd(&v[27]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
239:       a4 = _mm256_loadu_pd(&v[31]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
240:       a5 = _mm256_loadu_pd(&v[35]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

242:       /* v += 9; */
243:       w0 = _mm256_set1_pd((t+bs*vi[k])[4]);
244:       a0 = _mm256_loadu_pd(&v[36]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
245:       a1 = _mm256_loadu_pd(&v[40]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
246:       a2 = _mm256_loadu_pd(&v[44]); s2 = _mm256_fnmadd_pd(a2,w0,s2);

248:       /* v += 9; */
249:       w1 = _mm256_set1_pd((t+bs*vi[k])[5]);
250:       a3 = _mm256_loadu_pd(&v[45]); s0 = _mm256_fnmadd_pd(a3,w1,s0);
251:       a4 = _mm256_loadu_pd(&v[49]); s1 = _mm256_fnmadd_pd(a4,w1,s1);
252:       a5 = _mm256_loadu_pd(&v[53]); s2 = _mm256_fnmadd_pd(a5,w1,s2);

254:       /* v += 9; */
255:       w2 = _mm256_set1_pd((t+bs*vi[k])[6]);
256:       a0 = _mm256_loadu_pd(&v[54]); s0 = _mm256_fnmadd_pd(a0,w2,s0);
257:       a1 = _mm256_loadu_pd(&v[58]); s1 = _mm256_fnmadd_pd(a1,w2,s1);
258:       a2 = _mm256_loadu_pd(&v[62]); s2 = _mm256_fnmadd_pd(a2,w2,s2);

260:       /* v += 9; */
261:       w3 = _mm256_set1_pd((t+bs*vi[k])[7]);
262:       a3 = _mm256_loadu_pd(&v[63]); s0 = _mm256_fnmadd_pd(a3,w3,s0);
263:       a4 = _mm256_loadu_pd(&v[67]); s1 = _mm256_fnmadd_pd(a4,w3,s1);
264:       a5 = _mm256_loadu_pd(&v[71]); s2 = _mm256_fnmadd_pd(a5,w3,s2);

266:       /* v += 9; */
267:       w0 = _mm256_set1_pd((t+bs*vi[k])[8]);
268:       a0 = _mm256_loadu_pd(&v[72]); s0 = _mm256_fnmadd_pd(a0,w0,s0);
269:       a1 = _mm256_loadu_pd(&v[76]); s1 = _mm256_fnmadd_pd(a1,w0,s1);
270:       a2 = _mm256_maskload_pd(v+80, _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
271:       s2 = _mm256_fnmadd_pd(a2,w0,s2);
272:       v += bs2;
273:     }

275:     _mm256_storeu_pd(&ls[0], s0); _mm256_storeu_pd(&ls[4], s1); _mm256_maskstore_pd(&ls[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), s2);

277:     w0 = _mm256_setzero_pd(); w1 = _mm256_setzero_pd(); w2 = _mm256_setzero_pd();
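         /* dense 9x9 block multiply w = inv(D_i)*ls: each entry of ls scales nine consecutive stored
            values of the diagonal block at aa+bs2*adiag[i] (inverted during factorization) and is
            accumulated into w0/w1/w2 with fmadd */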

279:     /* first row */
280:     v0 = _mm256_set1_pd(ls[0]);
281:     a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[0]); w0 = _mm256_fmadd_pd(a0,v0,w0);
282:     a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[4]); w1 = _mm256_fmadd_pd(a1,v0,w1);
283:     a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[8]); w2 = _mm256_fmadd_pd(a2,v0,w2);

285:     /* second row */
286:     v1 = _mm256_set1_pd(ls[1]);
287:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[9]); w0 = _mm256_fmadd_pd(a3,v1,w0);
288:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[13]); w1 = _mm256_fmadd_pd(a4,v1,w1);
289:     a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[17]); w2 = _mm256_fmadd_pd(a5,v1,w2);

291:     /* third row */
292:     v2 = _mm256_set1_pd(ls[2]);
293:     a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[18]); w0 = _mm256_fmadd_pd(a0,v2,w0);
294:     a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[22]); w1 = _mm256_fmadd_pd(a1,v2,w1);
295:     a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[26]); w2 = _mm256_fmadd_pd(a2,v2,w2);

297:     /* fourth row */
298:     v3 = _mm256_set1_pd(ls[3]);
299:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[27]); w0 = _mm256_fmadd_pd(a3,v3,w0);
300:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[31]); w1 = _mm256_fmadd_pd(a4,v3,w1);
301:     a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[35]); w2 = _mm256_fmadd_pd(a5,v3,w2);

303:     /* fifth row */
304:     v0 = _mm256_set1_pd(ls[4]);
305:     a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[36]); w0 = _mm256_fmadd_pd(a0,v0,w0);
306:     a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[40]); w1 = _mm256_fmadd_pd(a1,v0,w1);
307:     a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[44]); w2 = _mm256_fmadd_pd(a2,v0,w2);

309:     /* sixth row */
310:     v1 = _mm256_set1_pd(ls[5]);
311:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[45]); w0 = _mm256_fmadd_pd(a3,v1,w0);
312:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[49]); w1 = _mm256_fmadd_pd(a4,v1,w1);
313:     a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[53]); w2 = _mm256_fmadd_pd(a5,v1,w2);

315:     /* seventh row */
316:     v2 = _mm256_set1_pd(ls[6]);
317:     a0 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[54]); w0 = _mm256_fmadd_pd(a0,v2,w0);
318:     a1 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[58]); w1 = _mm256_fmadd_pd(a1,v2,w1);
319:     a2 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[62]); w2 = _mm256_fmadd_pd(a2,v2,w2);

321:     /* eighth row */
322:     v3 = _mm256_set1_pd(ls[7]);
323:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[63]); w0 = _mm256_fmadd_pd(a3,v3,w0);
324:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[67]); w1 = _mm256_fmadd_pd(a4,v3,w1);
325:     a5 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[71]); w2 = _mm256_fmadd_pd(a5,v3,w2);

327:     /* ninth row */
328:     v0 = _mm256_set1_pd(ls[8]);
329:     a3 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[72]); w0 = _mm256_fmadd_pd(a3,v0,w0);
330:     a4 = _mm256_loadu_pd(&(aa+bs2*adiag[i])[76]); w1 = _mm256_fmadd_pd(a4,v0,w1);
331:     a2 = _mm256_maskload_pd((&(aa+bs2*adiag[i])[80]), _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63));
332:     w2 = _mm256_fmadd_pd(a2,v0,w2);

334:     _mm256_storeu_pd(&(t+i*bs)[0], w0); _mm256_storeu_pd(&(t+i*bs)[4], w1); _mm256_maskstore_pd(&(t+i*bs)[8], _mm256_set_epi64x(0LL, 0LL, 0LL, 1LL<<63), w2);

336:     PetscArraycpy(x+i*bs,t+i*bs,bs);
337:   }

339:   VecRestoreArrayRead(bb,&b);
340:   VecRestoreArray(xx,&x);
341:   PetscLogFlops(2.0*(a->bs2)*(a->nz) - A->rmap->bs*A->cmap->n);
342:   return 0;
343: }
344: #endif