🚀 add more matrix operation in stl/mat.nas

and bp example test
This commit is contained in:
ValKmjolnir 2023-03-01 23:37:13 +08:00
parent cbee5f8705
commit 99434df819
2 changed files with 130 additions and 6 deletions

View File

@ -13,10 +13,9 @@ var mat=func(width,height) {
}
var rand_init=func(a) {
srand();
var ref=a.mat;
forindex(var i;ref) {
ref[i]=rand();
ref[i]=rand()*2-1;
}
}
@ -64,6 +63,23 @@ var sub=func(a,b) {
return res;
}
var hardamard=func(a,b) {
if(a.width!=b.width or a.height!=b.height) {
return nil;
}
var res=mat(a.width,a.height);
var (width,height,ref)=(res.width,res.height,res.mat);
var (aref,bref)=(a.mat,b.mat);
for(var i=0;i<height;i+=1) {
for(var j=0;j<width;j+=1) {
ref[i*width+j]=aref[i*width+j]*bref[i*width+j];
}
}
return res;
}
var neg=func(a) {
var res=mat(a.width,a.height);
var (aref,ref)=(a.mat,res.mat);
@ -73,6 +89,25 @@ var neg=func(a) {
return res;
}
var sum=func(a) {
var res=0;
var aref=a.mat;
forindex(var i;aref) {
res+=aref[i];
}
return res;
}
var mult_num=func(a,c) {
var res=mat(a.width,a.height);
var ref=res.mat;
var aref=a.mat;
forindex(var i;aref) {
ref[i]=aref[i]*c;
}
return res;
}
var trans=func(a) {
var res=mat(a.height,a.width);
var ref=res.mat;
@ -119,12 +154,22 @@ var sigmoid=func(x) {
return 1/(1+t);
}
var diffsigmoid=func(x) {
x=sigmoid(x);
return x*(1-x);
}
var tanh=func(x) {
var t1=math.exp(x);
var t2=math.exp(-x);
return (t1-t2)/(t1+t2);
}
var difftanh=func(x) {
x=tanh(x);
return 1-x*x;
}
var test=func() {
for(var i=0;i<1e4;i+=1) {
var a=mat(4,20);
@ -145,4 +190,83 @@ var test=func() {
}
}
test();
var bp_example=func() {
srand();
var lr=0.01;
var input=[
{width:2,height:1,mat:[0,0]},
{width:2,height:1,mat:[0,1]},
{width:2,height:1,mat:[1,0]},
{width:2,height:1,mat:[1,1]}
];
# last 2 column is useless, only used to make sure bp runs correctly
var expect=[
{width:3,height:1,mat:[0,0,0]},
{width:3,height:1,mat:[1,0,0]},
{width:3,height:1,mat:[1,0,0]},
{width:3,height:1,mat:[0,0,0]}
];
var hidden={
weight:mat(4,2),
bias:mat(4,1),
in:nil,
out:nil,
diff:nil
};
var output={
weight:mat(3,4),
bias:mat(3,1),
in:nil,
out:nil,
diff:nil
};
rand_init(hidden.weight);
rand_init(hidden.bias);
rand_init(output.weight);
rand_init(output.bias);
var epoch=0;
var total=1e6;
while(total>0.01) {
epoch+=1;
if(epoch>1e4) {
println("Training failed after ",epoch," epoch.");
break;
}
total=0;
forindex(var i;input) {
hidden.in=add(mult(input[i],hidden.weight),hidden.bias);
hidden.out=activate(hidden.in,tanh);
output.in=add(mult(hidden.out,output.weight),output.bias);
output.out=activate(output.in,sigmoid);
var error=sub(expect[i],output.out);
output.diff=hardamard(error,activate(output.in,diffsigmoid));
hidden.diff=hardamard(trans(mult(output.weight,trans(output.diff))),activate(hidden.in,difftanh));
output.bias=add(output.bias,output.diff);
hidden.bias=add(hidden.bias,hidden.diff);
output.weight=add(output.weight,mult(trans(hidden.out),output.diff));
hidden.weight=add(hidden.weight,mult(trans(input[i]),hidden.diff));
total+=sum(mult_num(mult(error,trans(error)),0.5));
}
}
if(epoch<=1e4) {
println("Training succeeded after ",epoch," epoch.");
}
forindex(var i;input) {
hidden.in=add(mult(input[i],hidden.weight),hidden.bias);
hidden.out=activate(hidden.in,tanh);
output.in=add(mult(hidden.out,output.weight),output.bias);
output.out=activate(output.in,sigmoid);
println(input[i].mat," : ",output.out.mat);
}
}

View File

@ -33,9 +33,9 @@ var compare=func(){
var filechecksum=func(){
var files=[
"./stl/fg_env.nas", "./stl/file.nas",
"./stl/json.nas",
"./stl/lib.nas", "./stl/list.nas",
"./stl/log.nas", "./stl/module.nas",
"./stl/json.nas", "./stl/lib.nas",
"./stl/list.nas", "./stl/log.nas",
"./stl/mat.nas", "./stl/module.nas",
"./stl/padding.nas", "./stl/process_bar.nas",
"./stl/queue.nas", "./stl/result.nas",
"./stl/sort.nas", "./stl/stack.nas",