Disclosure of Invention
The invention aims to solve two problems that arise under small-sample conditions: the DRBM network under-fits, and the initial values of the network parameters prevent the network from reaching the global optimum after training.
The idea of the invention is to improve the training-testing algorithm of the network by dividing it into two stages: meta-learning and model learning. In the meta-learning stage, the network parameters are updated using the training tasks, and the updated parameters are taken as the initial values of the network parameters for the model-learning stage, so that the training loss function descends faster and reaches the global optimum more easily. In the model-learning stage, the network parameters are updated using the test task and the network is tested. The algorithm introduces meta-learning to improve the training process of the DRBM, so that gradient descent in the meta-learning stage drives the network parameters to the point of best adaptation and the network can quickly adapt to a new task.
The technical scheme adopted by the invention for solving the technical problems is as follows: a DRBM fast adaptation method based on meta-learning comprises the following steps:
S1, establish the DRBM network structure. The DRBM consists of three layers: a visible layer, a hidden layer and a classification layer, each containing several neurons. Nodes within the same layer are not connected, while nodes in adjacent layers are fully connected. The state of each neuron is a binary value, 1 or 0, where 1 represents activation, 0 represents non-activation, and activation means that the node represented by the neuron processes data. The distribution of the DRBM is determined by the values of the neurons. The visible layer represents the input data: the number of visible-layer nodes is determined by the dimension of the input data, and the values of the visible-layer nodes are the values of each dimension of the input data. The hidden layer captures statistical characteristics of the observed data through optimization, and the number of hidden-layer nodes is adjusted manually according to the data and the task. The classification-layer units judge the class according to the data features extracted by the hidden-layer units, and the number of classification-layer nodes is determined by the number of data classes.
The DRBM network is described by its network parameters. Assume that the number of visible-layer nodes is l, the number of hidden-layer nodes is m, and the number of classification-layer nodes is n. The bias vector of the visible layer is b, a 1-row, l-column vector; the bias vector of the hidden layer is c, a 1-row, m-column vector; the bias vector of the classification layer is d, a 1-row, n-column vector; the weight matrix between the input layer and the hidden layer is W, an l-row, m-column matrix; and the weight matrix between the classification layer and the hidden layer is U, an n-row, m-column matrix. Let the vector θ = (W, U, b, c, d); the purpose of training the DRBM network is to find the best value of θ so that the data class can be predicted by the network.
DRBM is a model determined based on an energy function, which can be defined as:
E(y,x,h) = -hW^T x^T - b x^T - c h^T - d y^T - h U^T y^T (1)
wherein x represents the state vector of the visible layer, a 1-row, l-column vector; h represents the state vector of the hidden-layer units, a 1-row, m-column vector; and y represents the state vector of the classification layer, a 1-row, n-column vector. y is the one-hot representation of the label, that is, exactly one node among all nodes is 1 and the rest are 0. The joint probability distribution of x, h, y is:
p(y,x,h) = exp(-E(y,x,h)) / Z (2)
wherein
Z = Σ_{y,x,h} exp(-E(y,x,h))
is referred to as the partition function.
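For reference, a minimal NumPy sketch of the energy function in equation (1) follows; the function and variable names are illustrative, and the shapes follow the row-vector convention defined above.

```python
import numpy as np

def energy(x, y, h, W, U, b, c, d):
    """Energy E(y, x, h) of the DRBM, equation (1).

    x: 1 x l visible state, y: 1 x n one-hot label, h: 1 x m hidden state;
    W: l x m, U: n x m, b: 1 x l, c: 1 x m, d: 1 x n.
    """
    e = (- h @ W.T @ x.T   # -h W^T x^T
         - b @ x.T         # -b x^T
         - c @ h.T         # -c h^T
         - d @ y.T         # -d y^T
         - h @ U.T @ y.T)  # -h U^T y^T
    return e.item()
```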
S2 meta learning stage:
S2.1, initialize the network parameters:
Initialize the bias vector b of the visible layer as a 1-row, l-column zero matrix, the bias vector c of the hidden layer as a 1-row, m-column zero matrix, and the bias vector d of the classification layer as a 1-row, n-column zero matrix; the corresponding gradients Δb, Δc and Δd as zero matrices of 1 row and l columns, 1 row and m columns and 1 row and n columns, respectively; the weight matrix W between the visible layer and the hidden layer as an l-row, m-column zero matrix and the weight matrix U between the classification layer and the hidden layer as an n-row, m-column zero matrix; and the corresponding gradients ΔW and ΔU as zero matrices of l rows, m columns and n rows, m columns, respectively. The initial value of the network parameters θ is recorded as θ_0. Set the internal learning rate α to 0.01-0.5, the external learning rate β to 0.005-0.05, the momentum learning rate m to 0.5, and the penalty coefficient p to 10^-4.
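A minimal sketch of the initialization in S2.1 (NumPy; l, m, n and the learning-rate values are illustrative choices within the ranges stated above):

```python
import numpy as np

def init_drbm(l, m, n):
    """Step S2.1: all parameters and their gradients start at zero."""
    theta = {
        "W": np.zeros((l, m)),  # visible-hidden weight matrix
        "U": np.zeros((n, m)),  # classification-hidden weight matrix
        "b": np.zeros((1, l)),  # visible-layer bias
        "c": np.zeros((1, m)),  # hidden-layer bias
        "d": np.zeros((1, n)),  # classification-layer bias
    }
    grad = {k: np.zeros_like(v) for k, v in theta.items()}
    return theta, grad

alpha = 0.1       # internal learning rate (0.01-0.5)
beta = 0.01       # external learning rate (0.005-0.05)
momentum_m = 0.5  # momentum learning rate
penalty_p = 1e-4  # penalty coefficient
```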
S2.2, complete the training of one task:
Training in the meta-learning stage takes one task as the basic unit, and each task comprises two parts: a support set and a query set. Training with the support set is called internal learning, and training with the query set is called external learning. The data categories of each task may be the same as or different from those of the other tasks. All tasks of the meta-learning stage together constitute the training tasks. Every sample in the training tasks contains both data and label information. The specific steps are as follows:
S2.2.1 Take the support set as the network input and θ_0 as the initial value of the network parameters, and complete one round of training:
S2.2.1.1 Calculate the probability distribution function of the hidden layer:
p(h|y,x) = sigmoid(x^(0) W + y^(0) U + c) (3)
wherein x^(0) is the input data and y^(0) is the input data label in one-hot form.
S2.2.1.2 Having obtained the hidden-layer probability distribution function, obtain the hidden-layer node values by Gibbs sampling. The specific procedure is as follows: generate m random numbers r_i in [0,1], i ∈ [1, m], where m is the number of hidden-layer nodes; if p(h_i|y,x) > r_i, node h_i takes the value 1, otherwise 0. The resulting hidden-layer node vector is denoted h^(0), and sampling h^(0) from the probability distribution p(h|y,x) is written h^(0) ~ p(h|y,x).
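A sketch of steps S2.2.1.1-S2.2.1.2, i.e. equation (3) followed by Gibbs sampling, assuming the NumPy parameter dictionary of the earlier sketch:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sample_hidden(x, y, theta, rng=None):
    """Compute p(h|y,x) = sigmoid(xW + yU + c) and Gibbs-sample h from it."""
    rng = rng or np.random.default_rng()
    p_h = sigmoid(x @ theta["W"] + y @ theta["U"] + theta["c"])  # 1 x m
    r = rng.random(p_h.shape)            # m random numbers in [0, 1]
    h = (p_h > r).astype(float)          # h_i = 1 if p(h_i|y,x) > r_i, else 0
    return h
```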
S2.2.1.3 Reconstruct the visible layer and the classification layer from the hidden-layer node values, and calculate the probability distribution functions of the visible layer and the classification layer respectively:
p(x|h) = sigmoid(h W^T + b)
p(y|h) = exp(d y^T + h U^T y^T) / Σ_{y'} exp(d y'^T + h U^T y'^T) (4)
wherein y' ranges over all possible one-hot class vectors.
S2.2.1.4 Having obtained the probability distribution functions of the visible layer and the classification layer, use Gibbs sampling, x^(1) ~ p(x|h) and y^(1) ~ p(y|h), to obtain the node values x^(1) and y^(1) of the visible layer and the classification layer.
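A sketch of steps S2.2.1.3-S2.2.1.4; the softmax form of p(y|h) is the standard DRBM reconstruction and is an assumed reading of equation (4). It reuses the sigmoid helper and parameter dictionary from the earlier sketches.

```python
import numpy as np

def sample_visible_and_label(h, theta, rng=None):
    """Reconstruct the visible and classification layers from h and Gibbs-sample them."""
    rng = rng or np.random.default_rng()
    # p(x|h) = sigmoid(h W^T + b), sampled node by node
    p_x = sigmoid(h @ theta["W"].T + theta["b"])         # 1 x l
    x1 = (p_x > rng.random(p_x.shape)).astype(float)
    # p(y|h): softmax over the n class units (assumed form of equation (4))
    logits = h @ theta["U"].T + theta["d"]               # 1 x n
    p_y = np.exp(logits - logits.max())
    p_y = p_y / p_y.sum()
    k = rng.choice(p_y.shape[1], p=p_y.ravel())          # draw one class
    y1 = np.zeros_like(p_y)
    y1[0, k] = 1.0                                       # encode it as one-hot
    return x1, y1
```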
S2.2.1.5 From the visible-layer and classification-layer node values x^(1) and y^(1), calculate the hidden-layer probability distribution function again:
p(h|y,x) = sigmoid(x^(1) W + y^(1) U + c) (5)
and obtain h^(1) ~ p(h|y,x) by Gibbs sampling.
S2.2.1.6 From x^(0), y^(0), x^(1), y^(1), h^(0) and h^(1), compute the update gradient of the network parameters θ:
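Equation (6) is not reproduced here; the sketch below uses the standard CD-1 gradient estimates for a DRBM, which is the form the surrounding steps appear to assume.

```python
def cd1_gradients(x0, y0, h0, x1, y1, h1):
    """Assumed form of equation (6): CD-1 gradients from the positive and negative phases."""
    return {
        "W": x0.T @ h0 - x1.T @ h1,  # l x m
        "U": y0.T @ h0 - y1.T @ h1,  # n x m
        "b": x0 - x1,                # 1 x l
        "c": h0 - h1,                # 1 x m
        "d": y0 - y1,                # 1 x n
    }
```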
S2.2.1.7 For the tasks in the training task set, correct the gradient according to the internal learning rate α, the momentum learning rate m and the penalty coefficient p:
wherein i ∈ [1, ns], and ns is the number of samples in the support set;
S2.2.1.8 The network parameters θ are updated according to the gradient:
the network parameter after inputting the support set data for training and updating is recorded as thetans。
S2.2.2 Take the query set as the network input and θ_ns as the initial value of the network parameters, and complete one round of training:
According to formulas (3), (4), (5) and (6), the probability distribution function calculation, node distribution sampling and network-parameter update-gradient calculation are completed in sequence, and the corrected gradient is calculated according to formula (9):
wherein β is the external learning rate, i ∈ [1, nq], and nq is the number of samples in the query set. Finally, the network parameters are updated according to formula (8), and the resulting network parameters are recorded as θ_nq.
After the training of one task is completed, the network parameters are updated once, and the updating only keeps the external learning part, namely:
θ_(t+1) = θ_t + (θ_nq - θ_ns) (10)
wherein t ∈ [1, nt], t denotes the t-th training task, and nt denotes the number of training tasks. The network parameters θ_(t+1) are stored and taken as the initial values of the network parameters for the next task, which completes the training of one task.
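Formula (10) keeps only the displacement produced by external learning, in the spirit of Reptile-style meta-updates. A sketch of one training task of the meta-learning stage, assuming a helper train_on_samples that runs steps S2.2.1.1-S2.2.1.8 over the given samples with the given learning rate:

```python
def copy_params(theta):
    return {k: v.copy() for k, v in theta.items()}

def meta_train_one_task(theta_t, support, query, alpha, beta):
    """One task of step S2.2, returning theta_(t+1) according to formula (10)."""
    # internal learning on the support set: theta_t -> theta_ns
    theta_ns = train_on_samples(copy_params(theta_t), support, lr=alpha)
    # external learning on the query set: theta_ns -> theta_nq
    theta_nq = train_on_samples(copy_params(theta_ns), query, lr=beta)
    # formula (10): keep only the external-learning displacement
    return {k: theta_t[k] + (theta_nq[k] - theta_ns[k]) for k in theta_t}
```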
S2.3, complete all traversals (epochs):
Each traversal trains several tasks; the number of tasks is determined by the size of the data set, and 20-100 tasks can generally be set. After the training of one task is completed, the updated network parameters are taken as the initial values of the network parameters for the next task, and the process of S2.2 is repeated until every task has been trained once, which constitutes one traversal. After one traversal is completed, the updated network parameters are taken as the initial values of the network parameters for the next traversal, until all traversals are completed. The finally obtained network parameters are recorded as θ_nt.
The meta-learning stage usually needs multiple traversals; the number of traversals depends on the convergence speed of the network and is usually set to 50.
S3 model learning phase:
Similar to the meta-learning stage, the task of the model-learning stage also comprises a support set and a query set, but the model-learning stage typically has only one task, called the test task, whose data classes usually differ from those of the training tasks.
S3.1 Take the support set as the network input and θ_nt as the initial value of the network parameters, and complete one round of training:
According to formulas (3), (4), (5) and (6), the probability distribution function calculation, node distribution sampling and network-parameter update-gradient calculation are completed in sequence, and the corrected gradient is then calculated according to formula (11):
wherein i ∈ [1, ts], and ts is the number of samples in the support set of the test task; finally, the network-parameter update is completed according to formula (8).
S3.2, completing all traversals:
Each traversal completes one round of training on the support set of the test task. After one traversal is completed, the updated network parameters are taken as the initial values of the network parameters for the next traversal, until all traversals are completed; the finally obtained network parameters are recorded as θ_t.
The model-learning stage generally needs 50-150 traversals; the number of traversals depends on the convergence speed of the network.
S3.3 Input the data in the query set into the network with θ_t as the network parameters, and calculate the prediction probability of each category i in turn:
prediction(i) = repeat(d(i), tq) + log(exp(x^(0) · W + T(i) · U + c) + 1) (12)
wherein T is a tq-row, nc-column matrix, tq is the number of query-set samples, and nc is the total number of classes; T(i) denotes the matrix whose i-th column is all 1 and whose remaining columns are all 0; d(i) denotes the bias vector d with all columns other than the i-th set to 0; and repeat(d(i), tq) denotes repeating the vector d(i) tq times to form a matrix of tq rows and nc columns.
After the prediction probabilities of all nc categories are calculated, the column containing the maximum value is taken as the class prediction result, which completes the target classification.
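A sketch of the class prediction of S3.3; the softplus term of equation (12) is summed over the hidden units, which is an assumed reading based on the standard DRBM classification rule.

```python
import numpy as np

def predict_classes(X, theta):
    """Score each class for every query-set sample and return the predicted class indices."""
    tq = X.shape[0]                       # number of query-set samples
    nc = theta["d"].shape[1]              # total number of classes
    scores = np.zeros((tq, nc))
    for i in range(nc):
        Ti = np.zeros((tq, nc))
        Ti[:, i] = 1.0                    # T(i): i-th column all 1, others 0
        softplus = np.log1p(np.exp(X @ theta["W"] + Ti @ theta["U"] + theta["c"]))
        scores[:, i] = theta["d"][0, i] + softplus.sum(axis=1)
    return scores.argmax(axis=1)          # column of the maximum value = predicted class
```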
Through relevant experiments, the beneficial effects obtained by the invention are as follows:
(1) after training in the meta-learning stage, the initial values of the network parameters are closer to the point of fastest convergence; only a small number of samples are needed to fine-tune the network so that it fits the test samples well, and after fine-tuning the network reaches the global optimum more easily;
(2) the feature-expression capability of the network is enhanced, and the network is less prone to under-fitting when trained with a variety of tasks;
(3) the rapid convergence of the network parameters achieved by the algorithm can be exploited under small-sample conditions to improve the recognition accuracy of the network;
(4) the algorithm can be used not only as a training-testing method (when the classes of the test task differ from those of the training tasks) but also as a pre-training algorithm (when the classes of the test task are the same as those of the training tasks).
Detailed Description
The invention is further illustrated with reference to the accompanying drawings:
example 1 of the present invention demonstrates the complete recognition procedure of the proposed algorithm for HRRP data and the comparison of the recognition results with the conventional algorithm. Example 2 demonstrates the recognition of MNIST data sets by the proposed algorithm in comparison to conventional algorithms.
Fig. 1 shows the network structure of the present invention; the network is divided into a visible layer (v layer), a label layer (y layer) and a hidden layer (h layer). v = (v_1, v_2, …, v_l), h = (h_1, h_2, …, h_m) and y ∈ {0,1}^n are the state vectors of the visible layer, the hidden layer and the label layer respectively, and the label-layer vector y is in one-hot form. In this example, the numbers of visible-layer, hidden-layer and label-layer nodes are 201, 130 and 3, respectively.
Fig. 2 is an algorithm flow chart showing a standard meta-learning based fast adaptation DRBM algorithm flow.
FIG. 3 is a flowchart of the identification process for HRRP data according to the present invention; compared with FIG. 2, data-processing stages are added. The reason is that the DRBM network can only accept binary input data, so data preprocessing is carried out first after the HRRP data are obtained, in the following four steps: the first step is data sorting, i.e., data at different pitch angles and azimuth angles are selected according to the experiment and then arranged randomly; the second step is range normalization to overcome the distance sensitivity of the signal, i.e., normalization with respect to the sample with the maximum range in the data set; the third step is mean normalization to enhance the network's ability to learn different features. The significance of mean normalization of HRRP data lies in two aspects: 1) after the mean is subtracted, the remaining part can be intuitively understood as the part in which each sample differs; 2) during the updating of the network parameters, the algorithm oscillates less and converges more easily. The fourth step is data binarization: the amplitude is converted into a probability and sampled to obtain binarized HRRP data.
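The four preprocessing steps might be sketched as follows (NumPy); the exact normalization conventions are assumptions based on the description above.

```python
import numpy as np

def preprocess_hrrp(data, labels, rng=None):
    """Shuffle, range-normalize, mean-normalize and binarize HRRP data.

    data: (num_samples, 201) array of HRRP amplitudes; labels: (num_samples,) class ids.
    """
    rng = rng or np.random.default_rng()
    order = rng.permutation(len(data))                        # step 1: random ordering
    data, labels = data[order], labels[order]
    data = data / np.abs(data).max()                          # step 2: normalize by the largest amplitude in the set
    data = data - data.mean(axis=0)                           # step 3: subtract the per-dimension mean
    prob = (data - data.min()) / (data.max() - data.min())    # map amplitudes into [0, 1]
    binary = (rng.random(prob.shape) < prob).astype(float)    # step 4: sample to obtain binary HRRP
    return binary, labels
```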
HRRP simulation data output by three-dimensional airplane model electromagnetic simulation software designed by Saibo corporation are used in example 1, the types of the simulated airplanes are F-35, F-117 and P-51, and specific parameters of the airplane are shown in figure 4.
The radar simulation band is the X band. The frequency range is 9.5 GHz-10.5 GHz with a step size of 5 MHz. The polarization mode is vertical polarization. The target pitch angle is 0-10 degrees with a step of 0.1 degrees, and the azimuth angle is 0-90 degrees with a step of 0.1 degrees. The data set therefore covers 201 frequency points, 101 pitch angles and 901 azimuth angles, i.e., 101 × 901 = 91001 samples, each with 201 dimensions.
Fig. 5 illustrates the change in HRRP data during preprocessing, where the x-axis represents the data dimension and the y-axis represents the signal amplitude. (a) The HRRP data before preprocessing, (b) the normalized HRRP data, and (c) the binarized HRRP data.
FIG. 6 shows a data set of 8100 HRRP samples, where the x-axis represents the data dimension, the y-axis represents the sample number, and the z-axis represents the signal amplitude; samples 1 to 2700 are F-35 aircraft HRRPs, samples 2701 to 5400 are F-117 aircraft HRRPs, and samples 5401 to 8100 are P-51 aircraft HRRPs. The processed HRRP data follow a 0-1 distribution whose probability matches the original amplitude intensity.
In example 1, for each type of airplane we selected simulation data with azimuth angles of 0° to 10° and 80° to 90° and pitch angles of 3° to 5°, all with a step of 0.1°, i.e., (101+101) × 21 = 4242 samples per airplane type and 12726 samples in total. The samples whose azimuth angle is a multiple of 0.2° are used as the test set, 6363 samples in total; the remaining samples form the training set, also 6363 samples in total.
After the data processing is completed, we will go into the training and testing phase of the algorithm, and in example 1 we will perform two sets of experiments:
Experiment 1: train the DRBM using the conventional algorithm; randomly extract n samples of each of the three types of targets from the training set for training, and randomly extract 2000 samples of each of the three types of targets from the test set for testing;
Experiment 2: train and test the DRBM using the algorithm of the present invention; randomly extract n samples of each of the three types of targets from the training set for training, where within each task the support set accounts for 1/4 of the samples and the query set accounts for 3/4; randomly extract n samples of each of the three types of targets from the remaining data as the support set of the test task and train on them; then randomly extract 2000 samples of each of the three types of targets as the query set and test;
the number of samples n was in the range of [20,400], and each set of experiments was performed 50 times (epoch is 50).
Experiment 2 the specific training and testing procedure was:
S1, meta-learning stage:
Initialize the parameters: the visible-layer bias vector b = (b_1, b_2, …, b_201), the hidden-layer bias vector c = (c_1, c_2, …, c_130), the label-layer bias vector d = (d_1, d_2, d_3), the weight matrix between the visible layer and the hidden layer W = (w_{i,j}) ∈ R^{201×130}, and the weight matrix between the label layer and the hidden layer U = (u_{i,j}) ∈ R^{3×130} are each initialized as zero matrices of the corresponding dimensions, and the gradients corresponding to the network parameters are zero vectors of the corresponding dimensions. The internal learning rate α = 0.1, the external learning rate β = 0.001, the momentum learning rate m = 0.5, and the penalty coefficient p = 10^-4.
From the n samples of each of the three types of targets extracted from the training set, 1/4 are randomly assigned to the support set and 3/4 to the query set, and 50 traversals of training are completed in sequence according to steps S1.2 and S1.3, giving the trained initial value θ_0 of the network parameters.
S2 model learning phase:
will be selected from
Randomly extracting n samples of three types of targets from the residual data as a support set, training, finishing fine adjustment of a network parameter theta according to the steps of S2.1 and S2.2, and finishing trainingThe resultant network parameter is recorded as θ
s。
From the test set, 2000 samples of each of the three types of targets are randomly drawn as the query set and used for testing.
Fig. 7 shows the results of the two experiments, where the circles and squares are the average classification accuracy over 20 repetitions of the experiment for each value of n, and the band at each node is the classification-accuracy interval over the 50 experiments. The following conclusions can be drawn from the figure:
1) in experiment 2, the present invention was used as a pre-training method because the training set had the same data type as the test set. It can be seen that when n takes different values, the classification accuracy of experiment 2 is higher than that of experiment 1, which means that the network parameter initial value obtained by meta-learning makes the network more easily reach global optimum than the randomized initial value;
2) under the condition of small samples (n is less than or equal to 200), the method has obvious improvement on the classification performance, which shows that the initial value of the network parameter obtained by meta-learning is faster to converge than the initial value of randomization;
3) when n takes different values, the classification accuracy intervals of experiment 2 are all smaller than those of experiment 1, which shows that the model obtained by training of the invention has stronger stability.
Experiments were performed in example 2 using the handwritten-digit (MNIST) data set. The MNIST data set consists of 70,000 pictures and corresponding labels, of which 60,000 are used for training the neural network and 10,000 for testing it. Each picture is a 28 × 28-pixel image of a handwritten digit from 0 to 9. The pictures are white digits on a black background; black is represented by 0 and white by a floating-point number between 0 and 1, with values closer to 1 being whiter. The MNIST data set provides a label for each picture in one-hot form, i.e., the label vector is a one-dimensional array of length 10. Samples of the MNIST data set are shown in fig. 8.
The 784 pixels are combined into a one-dimensional array of length 784, which is the input data fed into the neural network. For the RBM network, the input data can only be binary, so the MNIST data set needs to be preprocessed: random binarization is performed according to the black-and-white intensity of each pixel, i.e., each pixel's white-intensity value is converted into a probability and sampled, so that the result follows a binary distribution.
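A sketch of the random binarization described above, where each pixel's grey-scale value in [0, 1] is treated as the probability that the corresponding binary pixel equals 1:

```python
import numpy as np

def binarize_mnist(images, rng=None):
    """Randomly binarize MNIST images whose pixel values lie in [0, 1]."""
    rng = rng or np.random.default_rng()
    flat = images.reshape(len(images), 784)          # 28 x 28 pixels -> length-784 vector
    return (rng.random(flat.shape) < flat).astype(float)
```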
We performed three sets of experiments in total:
Experiment 1: train the DRBM using the conventional algorithm with n samples of each digit from 0 to 9, and test with 5000 samples of each of the digits 8 and 9;
Experiment 2: train the DRBM using the conventional algorithm with n samples of each of the digits 8 and 9, and test with 5000 samples of each of the digits 8 and 9;
Experiment 3: train and test the DRBM using the proposed algorithm; the training task set of the training stage contains n samples of each digit from 0 to 7, where within each task the support set accounts for 1/4 of the samples and the query set for 3/4; in the testing stage, n samples of each of the digits 8 and 9 in the support set are used for parameter fine-tuning, and 5000 samples of each of the digits 8 and 9 in the query set are used for testing.
The specific training and testing steps of example 2 are similar to those of example 1, except that the number of network nodes and the learning rates are set differently. The number of samples n per class ranges over [5, 1000], each experiment is repeated 20 times, and the classification-accuracy intervals are plotted in fig. 9:
in the graph, each node is a classification accuracy average value obtained by repeating 20 times of experiments when n takes different values, and a band on each node is a classification accuracy interval of 20 experiments.
The following conclusions can be drawn from the figure:
under small sample conditions (n ≦ 100):
Comparing experiments 1 and 2, when the training data contain more classes than the test data, the classification accuracy of the network on the test data decreases. The reason is that the limited network structure restricts the feature-expression capability, and the more task classes are trained, the more likely the network is to fall into under-fitting, so the network's recognition capability for a single task decreases.
Comparing experiments 1 and 3, although the sample classes and numbers participating in network training are the same, the classification accuracy of experiment 3 is higher than that of experiment 1 for every value of n, especially under small-sample conditions. This is because, in the training phase, the proposed algorithm tends to find the "fastest convergence point" rather than approach the optimal point, so good initial parameters can be learned even when the digits used for training differ from the test-data classes. Moreover, for every value of n the classification-accuracy interval of experiment 3 is smaller than that of experiment 1, which shows that the model trained by the proposed algorithm is more stable.
Comparing experiments 2 and 3, the classification accuracy of experiment 3 is higher than that of experiment 2 for every value of n, which shows that the algorithm does not reduce the feature-expression capability of the network when many training classes are used; it also shows that the network-parameter initial values learned in the training stage of experiment 3 balance the learning ability across different tasks, so the network can fit the training data well with only a small amount of training.
When the number of samples is sufficient (n = 1000), the classification accuracy of experiment 3 is higher than that of the other two experiments, which indicates that the meta-learning method brings the network parameters to the global optimum.
In conclusion, the method can effectively improve the DRBM classification accuracy under the condition of small samples, and has higher engineering application value.