ニューラルネットを遺伝的アルゴリズムで最適化する
ニューラルネットワークとニューロエボリューションの勉強.
手始めとして,シンプルな2入力1出力,入力層と出力層だけしかないニューラルネットワークでGAにより結合重みの最適化を行う.課題は,AND,OR,XOR.ちなみにXORはこのネットワークでは絶対に分離不可能.出力層のニューロンには,常に1を与えるニューロンからも入力を受ける.これに結合重みをかけた値が出力ニューロンの閾値となる.
GAの適応度は,教師値と出力値の2乗誤差の逆数.進化手法には,前世代で高い適応度を得た個体を一定数残し,それ以外はルーレット選択により選択し,突然変異として一定範囲の値を一様乱数で重みに加える.
モデルのイメージ.
(モデル図は省略)
結果は,次のようになりANDとORは最適化できたが,当然だがXORはできない.
AND
Step:0 Error:0.6515803720055237 ・・・ Step:999 Error:1.6564560091306466E-18 input1 input2 supervised output 1.0 1.0 1.0 0.9999999991473791 1.0 0.0 0.0 7.241012645995045E-10 0.0 1.0 0.0 5.971705217988995E-10 0.0 0.0 0.0 3.686834260067064E-28
w11=42.12, w12=41.93, w13=-63.17
OR
Step:0 Error:0.5157082607317045 ・・・ Step:999 Error:2.1032194929622328E-29 input1 input2 supervised output 1.0 1.0 1.0 1.0 1.0 0.0 1.0 0.9999999999999982 0.0 1.0 1.0 0.9999999999999982 0.0 0.0 0.0 3.6848775115804825E-15
w11=67.24, w12=67.26, w13=-33.23
XOR
Step:0 Error:1.007839290787107 ・・・ Step:999 Error:1.00000023344069 input1 input2 supervised output 1.0 1.0 0.0 0.49979318374942117 1.0 0.0 1.0 0.5002616706037233 0.0 1.0 1.0 0.4998447076165615 0.0 0.0 0.0 0.5003131944605926
w11=-2.06, w12=-0.00, w13=0.00
次は勾配法で同じことをする.その後に中間層も入れてGAで最適化をする.
以下,ソース.
package neuroevolution;

/**
 * A minimal two-input, one-output neural network (no hidden layer) whose
 * connection weights are optimized by a genetic algorithm.
 *
 * <p>Each individual in the population is one full weight matrix. Fitness is
 * the inverse of the summed squared error over the training patterns.
 * Evolution keeps the top {@code ELITE_RATE} fraction unchanged and fills the
 * rest via roulette-wheel selection followed by uniform-noise mutation. A
 * constant-1.0 bias input is appended to every pattern; its weight acts as the
 * output neuron's threshold.
 */
public class SimpleNeuro {

    /** Training patterns: each entry is { {input values}, {supervised output} }. */
    public static double[][][] OR_PATTERN = {
            {{1.0, 1.0}, {1.0}}, {{1.0, 0.0}, {1.0}}, {{0.0, 1.0}, {1.0}}, {{0.0, 0.0}, {0.0}}};
    public static double[][][] AND_PATTERN = {
            {{1.0, 1.0}, {1.0}}, {{1.0, 0.0}, {0.0}}, {{0.0, 1.0}, {0.0}}, {{0.0, 0.0}, {0.0}}};
    public static double[][][] XOR_PATTERN = {
            {{1.0, 1.0}, {0.0}}, {{1.0, 0.0}, {1.0}}, {{0.0, 1.0}, {1.0}}, {{0.0, 0.0}, {0.0}}};

    public static void main(String[] args) {
        SimpleNeuro simpleneuro = new SimpleNeuro();
        simpleneuro.learn(AND_PATTERN);
    }

    static int MAX_STEP = 1000;        // number of GA generations
    static int POPULATION_SIZE = 100;  // individuals per generation
    static double ELITE_RATE = 0.2;    // fraction copied unchanged into the next generation
    static double MUTATION_RATE = 0.5; // per-weight probability of mutation
    static double MUTATION_SIZE = 0.3; // noise is uniform in [-MUTATION_SIZE, +MUTATION_SIZE)
    static int INPUT_NUM = 2;
    static int OUTPUT_NUM = 1;

    /** Population of weight matrices: [individual][input incl. bias][output]. */
    double[][][] weights;

    /** Initializes every weight of every individual uniformly in [-1.0, 1.0). */
    public SimpleNeuro() {
        weights = new double[POPULATION_SIZE][INPUT_NUM + 1][OUTPUT_NUM];
        for (int p = 0; p < POPULATION_SIZE; p++) {
            for (int i = 0; i < INPUT_NUM + 1; i++) {
                for (int o = 0; o < OUTPUT_NUM; o++) {
                    weights[p][i][o] = Math.random() * 2.0 - 1.0;
                }
            }
        }
    }

    /**
     * Runs the GA for {@code MAX_STEP} generations, printing the best error of
     * each generation, then prints the details of the final best individual.
     *
     * @param patterns training patterns ({ {inputs}, {supervised outputs} })
     */
    public void learn(double[][][] patterns) {
        for (int step = 0; step < MAX_STEP; step++) {
            System.out.println("Step:" + step);
            // evaluation
            double[] errors = evaluate(patterns);
            // sort the population by ascending error (best individual first)
            sortErrorAndWeights(errors);
            System.out.println("Error:" + errors[0]);

            double[][][] newWeights = new double[POPULATION_SIZE][INPUT_NUM + 1][OUTPUT_NUM];
            int currentIdx = 0;
            // Elitism: copy the best individuals unchanged.
            // BUGFIX: cast the whole product; the original "(int) POPULATION_SIZE * ELITE_RATE"
            // cast only the (already int) left operand.
            int eliteNum = (int) (POPULATION_SIZE * ELITE_RATE);
            for (; currentIdx < eliteNum; currentIdx++) {
                copyWeights(weights[currentIdx], newWeights[currentIdx]);
            }

            // Roulette-wheel selection on fitness = 1 / error.
            double[] fitness = new double[POPULATION_SIZE];
            double sumFitness = 0.0;
            for (int p = 0; p < POPULATION_SIZE; p++) {
                // BUGFIX: epsilon guards against division by zero when an individual
                // reaches exactly zero error, which would otherwise propagate
                // Infinity/NaN into the selection probabilities.
                fitness[p] = 1.0 / (errors[p] + 1.0e-12);
                sumFitness += fitness[p];
            }
            double[] probFitness = new double[POPULATION_SIZE];
            for (int p = 0; p < POPULATION_SIZE; p++) {
                probFitness[p] = fitness[p] / sumFitness;
            }
            while (currentIdx < POPULATION_SIZE) {
                double roulette = Math.random();
                double sum = 0.0;
                // BUGFIX: default to the last index so floating-point rounding
                // (cumulative sum < roulette) cannot run the index out of bounds.
                int selected = POPULATION_SIZE - 1;
                for (int p = 0; p < POPULATION_SIZE; p++) {
                    sum += probFitness[p];
                    if (roulette < sum) {
                        selected = p;
                        break;
                    }
                }
                copyWeights(weights[selected], newWeights[currentIdx]);
                addMutation(newWeights[currentIdx]);
                currentIdx++;
            }
            weights = newWeights;
        }
        // show the best NN
        showDetailOfBestWeights(patterns);
    }

    /**
     * Prints, for every pattern, the inputs, supervised value and actual output
     * of the best individual, followed by its error and weight matrix.
     */
    private void showDetailOfBestWeights(double[][][] patterns) {
        double[] errors = evaluate(patterns);
        sortErrorAndWeights(errors);
        for (double[][] pattern : patterns) {
            double[] inputValue = new double[INPUT_NUM + 1];
            for (int i = 0; i < INPUT_NUM; i++) {
                inputValue[i] = pattern[0][i];
            }
            inputValue[INPUT_NUM] = 1.0; // constant bias input
            double[] outputValue = calculateOutput(inputValue, weights[0]);
            for (int i = 0; i < INPUT_NUM; i++) {
                System.out.print(pattern[0][i] + "\t");
            }
            for (int o = 0; o < OUTPUT_NUM; o++) {
                System.out.print(pattern[1][o] + "\t");
            }
            for (int o = 0; o < OUTPUT_NUM; o++) {
                System.out.print(outputValue[o] + "\t");
            }
            System.out.println();
        }
        System.out.println("Error:" + errors[0]);
        System.out.println("Connection weights of Best Neural Network(row: input, column: output)");
        for (int o = 0; o < OUTPUT_NUM; o++) {
            for (int i = 0; i < INPUT_NUM + 1; i++) {
                System.out.print(weights[0][i][o] + "\t");
            }
            System.out.println();
        }
    }

    /**
     * Computes each individual's summed squared error over all patterns.
     *
     * @return error per individual, indexed like {@code weights}
     */
    private double[] evaluate(double[][][] patterns) {
        double[] errors = new double[POPULATION_SIZE];
        for (int p = 0; p < POPULATION_SIZE; p++) {
            errors[p] = 0.0;
            for (double[][] pattern : patterns) {
                double[] inputValue = new double[INPUT_NUM + 1];
                for (int i = 0; i < INPUT_NUM; i++) {
                    inputValue[i] = pattern[0][i];
                }
                inputValue[INPUT_NUM] = 1.0; // constant bias input
                double[] outputValue = calculateOutput(inputValue, weights[p]);
                for (int o = 0; o < OUTPUT_NUM; o++) {
                    double diff = pattern[1][o] - outputValue[o];
                    errors[p] += diff * diff;
                }
            }
        }
        return errors;
    }

    /** Selection-sorts {@code errors} ascending and reorders {@code weights} in lockstep. */
    private void sortErrorAndWeights(double[] errors) {
        for (int p1 = 0; p1 < POPULATION_SIZE - 1; p1++) {
            int minIdx = p1;
            for (int p2 = p1 + 1; p2 < POPULATION_SIZE; p2++) {
                if (errors[minIdx] > errors[p2]) {
                    minIdx = p2;
                }
            }
            double tempError = errors[p1];
            errors[p1] = errors[minIdx];
            errors[minIdx] = tempError;
            // Swap the matrix references directly; the original allocated and
            // element-copied a whole temporary matrix for every swap.
            double[][] tempWeights = weights[p1];
            weights[p1] = weights[minIdx];
            weights[minIdx] = tempWeights;
        }
    }

    /** Element-wise copy of one individual's weight matrix into another. */
    private void copyWeights(double[][] source, double[][] target) {
        for (int i = 0; i < INPUT_NUM + 1; i++) {
            for (int o = 0; o < OUTPUT_NUM; o++) {
                target[i][o] = source[i][o];
            }
        }
    }

    /**
     * With probability {@code MUTATION_RATE} per weight, adds uniform noise in
     * [-MUTATION_SIZE, +MUTATION_SIZE) to the given weight matrix in place.
     */
    private void addMutation(double[][] weights) {
        for (int i = 0; i < INPUT_NUM + 1; i++) {
            for (int o = 0; o < OUTPUT_NUM; o++) {
                if (Math.random() < MUTATION_RATE) {
                    weights[i][o] += Math.random() * MUTATION_SIZE * 2 - MUTATION_SIZE;
                }
            }
        }
    }

    /** Forward pass: weighted sum of inputs (incl. bias) through a sigmoid, per output neuron. */
    private double[] calculateOutput(double[] inputValue, double[][] weights) {
        double[] outputValue = new double[OUTPUT_NUM];
        for (int o = 0; o < OUTPUT_NUM; o++) {
            outputValue[o] = 0.0;
            for (int i = 0; i < INPUT_NUM + 1; i++) {
                outputValue[o] += inputValue[i] * weights[i][o];
            }
            outputValue[o] = sigmoidFunction(outputValue[o]);
        }
        return outputValue;
    }

    /** Logistic sigmoid 1 / (1 + e^-x); Math.exp replaces the original Math.pow(Math.E, -x). */
    private double sigmoidFunction(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }
}