Update sgemm kernel 1x4 for C910.
This commit is contained in:
parent
7834c10e2f
commit
a3cac9cca0
|
@ -382,7 +382,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
|
|||
{
|
||||
BLASLONG i,j,k;
|
||||
FLOAT *C0,*C1,*C2,*C3;
|
||||
FLOAT *ptrba,*ptrbb;
|
||||
FLOAT *ptrba,*ptrbb, *tmpc;
|
||||
|
||||
FLOAT loadb0,loadb1,loadb2,loadb3;
|
||||
FLOAT load0,load1,load2,load3,load4,load5,load6,load7;
|
||||
|
@ -392,6 +392,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
|
|||
FLOAT res8,res9,res10,res11;
|
||||
FLOAT res12,res13,res14,res15;
|
||||
|
||||
|
||||
for (j=0; j<bn/4; j+=1){
|
||||
C0 = C;
|
||||
C1 = C0+ldc;
|
||||
|
@ -942,53 +943,109 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
|
|||
}
|
||||
if(bm&1){
|
||||
ptrbb = bb;
|
||||
|
||||
res0 = 0;
|
||||
|
||||
res4 = 0;
|
||||
|
||||
res8 = 0;
|
||||
|
||||
res12 = 0;
|
||||
|
||||
for(k=0; k<bk; k+=1){
|
||||
loadb0 = ptrbb[0];
|
||||
loadb1 = ptrbb[1];
|
||||
//t0 for k
|
||||
//ft0-ft3,ft4-ft7,v8-v15 for B, t1-t3 for PB1-3
|
||||
//v0-v3,v4-v7 for A, t4-t6 for PA1-3
|
||||
//v16-v31 for temp C
|
||||
|
||||
load0 = ptrba[0];
|
||||
|
||||
res0 = res0 + load0 * loadb0;
|
||||
FLOAT tmp[4];
|
||||
tmpc=tmp;
|
||||
//t1-t3 for PB
|
||||
//v0-v4 for A, v8-v11 for B
|
||||
//v16-v19 for C
|
||||
asm volatile(
|
||||
"vsetvli zero, zero, e32,m1 \n\t"
|
||||
"fmv.w.x ft11, zero \n\t"
|
||||
|
||||
"vfmv.v.f v16, ft11 \n\t"
|
||||
"vfmv.v.f v17, ft11 \n\t"
|
||||
"vfmv.v.f v18, ft11 \n\t"
|
||||
"vfmv.v.f v19, ft11 \n\t"
|
||||
//unloop 4
|
||||
|
||||
res4 = res4 + load0 * loadb1;
|
||||
"srli t0, %[BK], 2 \n\t"
|
||||
"blez t0, M1x4_TAIL \n\t"
|
||||
|
||||
loadb2 = ptrbb[2];
|
||||
loadb3 = ptrbb[3];
|
||||
|
||||
res8 = res8 + load0 * loadb2;
|
||||
"addi t1, %[PB], 4*4 \n\t"
|
||||
"addi t2, %[PB], 8*4 \n\t"
|
||||
"addi t3, %[PB], 12*4 \n\t"
|
||||
|
||||
".align 4 \n\t"
|
||||
"M1x4_MAINLOOP: \n\t"
|
||||
|
||||
res12 = res12 + load0 * loadb3;
|
||||
"vle.v v4, (%[PA]) \n\t"
|
||||
"addi %[PA], %[PA], 4*4 \n\t"
|
||||
"vrgather.vi v0, v4, 0 \n\t"
|
||||
|
||||
"vle.v v8, (%[PB]) \n\t"
|
||||
"addi %[PB], %[PB], 16*4 \n\t"
|
||||
"vrgather.vi v1, v4, 1 \n\t"
|
||||
|
||||
"vle.v v9, (t1) \n\t"
|
||||
"addi t1, t1, 16*4 \n\t"
|
||||
"vrgather.vi v2, v4, 2 \n\t"
|
||||
|
||||
"vle.v v10, (t2) \n\t"
|
||||
"addi t2, t2, 16*4 \n\t"
|
||||
"vrgather.vi v3, v4, 3 \n\t"
|
||||
|
||||
"vle.v v11, (t3) \n\t"
|
||||
"addi t3, t3, 16*4 \n\t"
|
||||
|
||||
"vfmacc.vv v16, v8, v0 \n\t"
|
||||
"vfmacc.vv v17, v9, v1 \n\t"
|
||||
"vfmacc.vv v18, v10, v2 \n\t"
|
||||
"vfmacc.vv v19, v11, v3 \n\t"
|
||||
|
||||
"addi t0, t0, -1 \n\t"
|
||||
"bgtz t0, M1x4_MAINLOOP \n\t"
|
||||
|
||||
"M1x4_TAIL: \n\t"
|
||||
"andi t0, %[BK], 3 \n\t"
|
||||
"blez t0, M1x4_SAVERESULT \n\t"
|
||||
|
||||
ptrba += 1;
|
||||
ptrbb += 4;
|
||||
}
|
||||
|
||||
res0 = res0 * alpha;
|
||||
"M1x4_TAILLOOP: \n\t"
|
||||
"flw ft0, (%[PA]) \n\t"
|
||||
"addi %[PA], %[PA], 1*4 \n\t"
|
||||
"vle.v v8, (%[PB]) \n\t"
|
||||
"addi %[PB], %[PB], 4*4 \n\t"
|
||||
"vfmv.v.f v0, ft0 \n\t"
|
||||
"vfmacc.vv v16, v8, v0 \n\t"
|
||||
|
||||
"addi t0, t0, -1 \n\t"
|
||||
"bgtz t0, M1x4_TAILLOOP \n\t"
|
||||
|
||||
"M1x4_SAVERESULT: \n\t"
|
||||
//merge v16-v19
|
||||
"vfadd.vv v16, v16, v17 \n\t"
|
||||
"vfadd.vv v18, v18, v19 \n\t"
|
||||
"vfadd.vv v16, v16, v18 \n\t"
|
||||
|
||||
"vfmv.v.f v8, %[ALPHA] \n\t"
|
||||
"vfmul.vv v16, v8, v16 \n\t"
|
||||
"vse.v v16, (%[TMP_C]) \n\t"
|
||||
"M1x4_END: \n\t"
|
||||
:[TMP_C]"+r"(tmpc),
|
||||
[PA]"+r"(ptrba), [PB]"+r"(ptrbb)
|
||||
:[ALPHA]"f"(alpha), [BK]"r"(bk)
|
||||
:"cc", "t0", "t3","t1","t2",
|
||||
"ft0", "ft11",
|
||||
"v0", "v1", "v2", "v3","v4",
|
||||
"v8", "v9", "v10", "v11",
|
||||
"v16", "v17","v18", "v19"
|
||||
);
|
||||
|
||||
res4 = res4 * alpha;
|
||||
|
||||
res8 = res8 * alpha;
|
||||
|
||||
res12 = res12 * alpha;
|
||||
|
||||
C0[0] += res0;
|
||||
C1[0] += res4;
|
||||
C2[0] += res8;
|
||||
C3[0] += res12;
|
||||
C0[0] += tmp[0];
|
||||
C1[0] += tmp[1];
|
||||
C2[0] += tmp[2];
|
||||
C3[0] += tmp[3];
|
||||
|
||||
/* don't need move c point
|
||||
C0 += 1;
|
||||
C1 += 1;
|
||||
C2 += 1;
|
||||
C3 += 1;
|
||||
*/
|
||||
}
|
||||
|
||||
k = bk<<2;
|
||||
|
|
Loading…
Reference in New Issue