author    Jean-Marc Valin <jean-marc.valin@octasic.com>  2011-05-30 13:46:27 -0400
committer Jean-Marc Valin <jean-marc.valin@octasic.com>  2011-05-30 13:46:27 -0400
commit    10b5fe0259e4a195d0e86076c8236a31b820ca7b (patch)
tree      b6a7ac2dbe42b88f614338587ef8ae6b994daa9c
parent    2b68c4efe31cdd77518172793c6aed2c075d3206 (diff)
download  opus-exp_detection.tar.gz

Training now stops when stuck in a minimum (exp_detection)
-rw-r--r--  src/mlp_train.c  29
1 file changed, 19 insertions, 10 deletions
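
For context, the patch below tracks the best RMS error seen so far instead of the last one: after 30 consecutive epochs without improvement it rolls the weights back to the best set, shrinks the per-weight learning rates by 0.8, and counts a retry; after more than 10 such retries it assumes training is stuck in a minimum and stops. A minimal, self-contained sketch of that stopping rule follows; epoch(), rate and the epoch limit are illustrative stand-ins, not the real mlp_train.c API.

/* Sketch of the stop-when-stuck rule this patch adds to the training loop.
   epoch() is a stand-in for one backprop pass over the data and just
   returns a fake RMS value so the example compiles and runs on its own. */
#include <stdio.h>
#include <stdlib.h>

static float epoch(void)
{
    return (float)rand() / (float)RAND_MAX;   /* pretend RMS for this pass */
}

int main(void)
{
    float best_rms = 1e10f;
    float rate = 1.f;                  /* stands in for the per-weight rates */
    int count_worse = 0, count_retries = 0;
    int e;
    for (e = 0; e < 100000; e++)
    {
        float rms = epoch();
        if (rms < best_rms)
        {
            best_rms = rms;            /* new best: the real code saves the weights here */
            count_worse = 0;
            count_retries = 0;
        } else if (++count_worse > 30)
        {                              /* 30 epochs without improvement */
            count_retries++;
            count_worse = 0;
            rate *= .8f;               /* roll back to the best weights and shrink the step */
            if (rate < 1e-15f) rate = 1e-15f;
        }
        if (count_retries > 10)        /* stuck in a minimum: stop training */
            break;
    }
    printf("stopped at epoch %d, best rms %f\n", e, best_rms);
    return 0;
}
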
diff --git a/src/mlp_train.c b/src/mlp_train.c
index 1c1fa28e..a9d548b1 100644
--- a/src/mlp_train.c
+++ b/src/mlp_train.c
@@ -80,7 +80,7 @@ MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int
std = .001;
std = 1/sqrt(inDim*std);
for (k=0;k<topo[1];k++)
- net->weights[0][k*(topo[0]+1)+j+1] = randn(4*std);
+ net->weights[0][k*(topo[0]+1)+j+1] = randn(std);
}
net->in_rate[0] = 1;
for (j=0;j<topo[1];j++)
@@ -223,7 +223,7 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
{
int i, j;
int e;
- float last_rms = 1e10;
+ float best_rms = 1e10;
int inDim, outDim, hiddenDim;
int *topo;
double *W0, *W1, *best_W0, *best_W1;
@@ -241,6 +241,8 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
pthread_t thread[NB_THREADS];
int samplePerPart = nbSamples/NB_THREADS;
int count_worse=0;
+ int count_retries=0;
+
topo = net->topo;
inDim = net->topo[0];
hiddenDim = net->topo[1];
@@ -313,10 +315,10 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
float mean_rate = 0, min_rate = 1e10;
rms = (rms/(outDim*nbSamples));
error_rate = (error_rate/(outDim*nbSamples));
- fprintf (stderr, "%f (%f %f) ", error_rate, rms, last_rms);
- if (rms < last_rms)
+ fprintf (stderr, "%f (%f %f) ", error_rate, rms, best_rms);
+ if (rms < best_rms)
{
- last_rms = rms;
+ best_rms = rms;
for (i=0;i<W0_size;i++)
{
best_W0[i] = W0[i];
@@ -328,10 +330,12 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
best_W1_rate[i] = W1_rate[i];
}
count_worse=0;
- } else if (rms > last_rms) {
+ count_retries=0;
+ } else {
count_worse++;
- if (count_worse>20)
+ if (count_worse>30)
{
+ count_retries++;
count_worse=0;
for (i=0;i<W0_size;i++)
{
@@ -344,13 +348,15 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
for (i=0;i<W1_size;i++)
{
W1[i] = best_W1[i];
- best_W1_rate[i] *= .7;
+ best_W1_rate[i] *= .8;
if (best_W1_rate[i]<1e-15) best_W1_rate[i]=1e-15;
W1_rate[i] = best_W1_rate[i];
W1_grad[i] = 0;
}
}
}
+ if (count_retries>10)
+ break;
for (i=0;i<W0_size;i++)
{
if (W0_oldgrad[i]*W0_grad[i] > 0)
@@ -386,7 +392,10 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
W1[i] += W1_grad[i]*W1_rate[i];
}
mean_rate /= (topo[0]+1)*topo[1] + (topo[1]+1)*topo[2];
- fprintf (stderr, "%g (min %g) %d\n", mean_rate, min_rate, e);
+ fprintf (stderr, "%g %d", mean_rate, e);
+ if (count_retries)
+ fprintf(stderr, " %d", count_retries);
+ fprintf(stderr, "\n");
if (stopped)
break;
}
@@ -403,7 +412,7 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
free(W1_grad);
free(W0_rate);
free(W1_rate);
- return last_rms;
+ return best_rms;
}
int main(int argc, char **argv)