ニューラルネットを遺伝的アルゴリズムで最適化する

ニューラルネットワークとニューロエボリューションの勉強.
手始めとして,2入力1出力で入力層と出力層だけからなるシンプルなニューラルネットワークを対象に,GAで結合重みの最適化を行う.課題は,AND,OR,XORの3つ.ちなみにXORは線形分離不可能なので,中間層のないこのネットワークでは原理的に学習できない.出力層のニューロンは,常に1を出力するニューロンからも入力を受ける.これに結合重みをかけた値が出力ニューロンの閾値となる.
GAの適応度は,教師値と出力値の2乗誤差(全パターンの総和)の逆数.世代交代では,前世代で高い適応度を得た個体を一定数そのまま残し,それ以外の個体はルーレット選択により選択したうえで,突然変異として一定範囲の一様乱数を重みに加える.

モデルのイメージ.
f:id:axxasusi:20140611194645p:plain

結果は次のとおり.ANDとORは最適化できたが,当然ながらXORは最適化できなかった.

AND
Step:0
Error:0.6515803720055237
・・・
Step:999
Error:1.6564560091306466E-18
input1	input2	supervised	output
1.0	1.0	1.0	0.9999999991473791	
1.0	0.0	0.0	7.241012645995045E-10	
0.0	1.0	0.0	5.971705217988995E-10	
0.0	0.0	0.0	3.686834260067064E-28	

w11=42.12, w12=41.93, w13=-63.17

OR
Step:0
Error:0.5157082607317045
・・・
Step:999
Error:2.1032194929622328E-29
input1	input2	supervised	output
1.0	1.0	1.0	1.0	
1.0	0.0	1.0	0.9999999999999982	
0.0	1.0	1.0	0.9999999999999982	
0.0	0.0	0.0	3.6848775115804825E-15	

w11=67.24, w12=67.26, w13=-33.23

XOR
Step:0
Error:1.007839290787107
・・・
Step:999
Error:1.00000023344069
input1	input2	supervised	output
1.0	1.0	0.0	0.49979318374942117	
1.0	0.0	1.0	0.5002616706037233	
0.0	1.0	1.0	0.4998447076165615	
0.0	0.0	0.0	0.5003131944605926	

w11=-2.06E-4, w12=-1.87E-3, w13=1.25E-3

次は勾配法で同じことをする.その後に中間層も入れてGAで最適化をする.

以下,ソース.

package neuroevolution;

/**
 * Single-layer neural network whose connection weights are optimised by a
 * genetic algorithm (GA).
 *
 * <p>Each individual of the population is one weight matrix of shape
 * {@code [INPUT_NUM + 1][OUTPUT_NUM]}; the extra input row is a constant 1.0
 * input, so its weight acts as the bias of the output neuron. Every output is
 * passed through a logistic sigmoid.
 *
 * <p>Per generation the population is evaluated (sum of squared errors over
 * all training patterns), sorted best-first, the best individuals are copied
 * unchanged and the remaining slots are filled by fitness-proportionate
 * (roulette) selection followed by mutation.
 */
public class SimpleNeuro {

    /** Training patterns for logical OR; each entry is {@code {inputs, expectedOutputs}}. */
    public static double[][][] OR_PATTERN = {{{1.0, 1.0}, {1.0}},
                                             {{1.0, 0.0}, {1.0}},
                                             {{0.0, 1.0}, {1.0}},
                                             {{0.0, 0.0}, {0.0}}};
    /** Training patterns for logical AND; each entry is {@code {inputs, expectedOutputs}}. */
    public static double[][][] AND_PATTERN = {{{1.0, 1.0}, {1.0}},
                                              {{1.0, 0.0}, {0.0}},
                                              {{0.0, 1.0}, {0.0}},
                                              {{0.0, 0.0}, {0.0}}};
    /** Training patterns for logical XOR; each entry is {@code {inputs, expectedOutputs}}. */
    public static double[][][] XOR_PATTERN = {{{1.0, 1.0}, {0.0}},
                                              {{1.0, 0.0}, {1.0}},
                                              {{0.0, 1.0}, {1.0}},
                                              {{0.0, 0.0}, {0.0}}};

    /** Runs the GA on the AND patterns and prints the progress and the result. */
    public static void main(String[] args) {
        SimpleNeuro simpleneuro = new SimpleNeuro();
        simpleneuro.learn(AND_PATTERN);
    }

    /** Number of generations executed by {@link #learn(double[][][])}. */
    static int MAX_STEP = 1000;
    /** Number of individuals (weight matrices) in the population. */
    static int POPULATION_SIZE = 100;

    /** Fraction of the population copied unchanged into the next generation. */
    static double ELITE_RATE = 0.2;
    /** Probability with which each single weight of an offspring is mutated. */
    static double MUTATION_RATE = 0.5;
    /** Half-width of the uniform interval a mutation adds to a weight. */
    static double MUTATION_SIZE = 0.3;

    /** Number of real inputs (the constant bias input is not counted). */
    static int INPUT_NUM = 2;
    /** Number of output neurons. */
    static int OUTPUT_NUM = 1;

    /**
     * Lower bound applied to an error before it is inverted into a fitness.
     * Keeps the fitness finite when an individual reaches an error of exactly
     * 0.0, while leaving enough head-room that summing the fitness of a whole
     * population does not overflow.
     */
    private static final double MIN_ERROR = 1.0e-300;

    /** Population of weight matrices, indexed {@code [individual][input][output]}. */
    double[][][] weights;

    /** Creates a population whose weights are drawn uniformly from [-1, 1). */
    public SimpleNeuro() {
        weights = new double[POPULATION_SIZE][INPUT_NUM + 1][OUTPUT_NUM];
        for (double[][] individual : weights) {
            for (double[] row : individual) {
                for (int o = 0; o < OUTPUT_NUM; o++) {
                    row[o] = Math.random() * 2.0 - 1.0;
                }
            }
        }
    }

    /**
     * Evolves the population for {@code MAX_STEP} generations against the given
     * patterns, printing the step number and the best error of every
     * generation, and finally prints the details of the best individual.
     *
     * @param patterns training set; each entry is {@code {inputs, expectedOutputs}}
     */
    public void learn(double[][][] patterns) {
        for (int step = 0; step < MAX_STEP; step++) {
            System.out.println("Step:" + step);
            double[] errors = evaluate(patterns);

            // Best individual first, so the elites are simply the leading entries.
            sortErrorAndWeights(errors);
            System.out.println("Error:" + errors[0]);

            weights = breedNextGeneration(errors);
        }

        // show the best NN
        showDetailOfBestWeights(patterns);
    }

    /**
     * Builds the next generation from the current, already sorted population.
     *
     * @param errors error of every individual, in the same (sorted) order as {@code weights}
     * @return a freshly allocated population of the same size
     */
    private double[][][] breedNextGeneration(double[] errors) {
        double[][][] next = new double[POPULATION_SIZE][INPUT_NUM + 1][OUTPUT_NUM];

        // The cast must cover the whole product; the count is clamped so that
        // an ELITE_RATE outside [0, 1] cannot index past the population.
        int eliteCount = (int) (POPULATION_SIZE * ELITE_RATE);
        eliteCount = Math.max(0, Math.min(POPULATION_SIZE, eliteCount));
        for (int e = 0; e < eliteCount; e++) {
            copyWeights(weights[e], next[e]);
        }

        // Running sum of the fitness values (fitness = 1 / error). The error is
        // bounded below so that a perfect individual does not produce an
        // infinite fitness, which would turn every selection probability into NaN.
        double[] cumulativeFitness = new double[POPULATION_SIZE];
        double total = 0.0;
        for (int p = 0; p < POPULATION_SIZE; p++) {
            total += 1.0 / Math.max(errors[p], MIN_ERROR);
            cumulativeFitness[p] = total;
        }

        for (int child = eliteCount; child < POPULATION_SIZE; child++) {
            int parent = spinRoulette(cumulativeFitness);
            copyWeights(weights[parent], next[child]);
            addMutation(next[child]);
        }
        return next;
    }

    /**
     * Picks an index with probability proportional to its fitness.
     *
     * @param cumulativeFitness running sum of the fitness values (non-empty)
     * @return the selected index; always a valid index of the array
     */
    private int spinRoulette(double[] cumulativeFitness) {
        int last = cumulativeFitness.length - 1;
        double pick = Math.random() * cumulativeFitness[last];
        for (int p = 0; p < last; p++) {
            if (pick < cumulativeFitness[p]) {
                return p;
            }
        }
        // Reached when the pick lies in the last slice, and also when rounding
        // (or a NaN) makes every comparison fail: never run off the array.
        return last;
    }

    /**
     * Prints, for the best individual, one line per pattern (inputs, expected
     * outputs, actual outputs), followed by its error and its weights.
     *
     * @param patterns training set; each entry is {@code {inputs, expectedOutputs}}
     */
    private void showDetailOfBestWeights(double[][][] patterns) {
        double[] errors = evaluate(patterns);
        sortErrorAndWeights(errors);
        for (double[][] pattern : patterns) {
            double[] outputValue = calculateOutput(withBias(pattern[0]), weights[0]);
            for (int i = 0; i < INPUT_NUM; i++) {
                System.out.print(pattern[0][i] + "\t");
            }
            for (int o = 0; o < OUTPUT_NUM; o++) {
                System.out.print(pattern[1][o] + "\t");
            }
            for (int o = 0; o < OUTPUT_NUM; o++) {
                System.out.print(outputValue[o] + "\t");
            }
            System.out.println();
        }
        System.out.println("Error:" + errors[0]);
        System.out.println("Connection weights of Best Neural Network(row: input, column: output)");
        for (int o = 0; o < OUTPUT_NUM; o++) {
            for (int i = 0; i < INPUT_NUM + 1; i++) {
                System.out.print(weights[0][i][o] + "\t");
            }
            System.out.println();
        }
    }

    /**
     * Computes the sum of squared errors over all patterns for every individual.
     *
     * @param patterns training set; each entry is {@code {inputs, expectedOutputs}}
     * @return one error value per individual, in population order
     */
    private double[] evaluate(double[][][] patterns) {
        double[] errors = new double[POPULATION_SIZE];
        for (int p = 0; p < POPULATION_SIZE; p++) {
            double squaredError = 0.0;
            for (double[][] pattern : patterns) {
                double[] outputValue = calculateOutput(withBias(pattern[0]), weights[p]);
                for (int o = 0; o < OUTPUT_NUM; o++) {
                    double diff = pattern[1][o] - outputValue[o];
                    squaredError += diff * diff;
                }
            }
            errors[p] = squaredError;
        }
        return errors;
    }

    /**
     * Returns a copy of the first {@code INPUT_NUM} values of {@code input}
     * with the constant bias input 1.0 appended.
     */
    private double[] withBias(double[] input) {
        double[] extended = new double[INPUT_NUM + 1];
        System.arraycopy(input, 0, extended, 0, INPUT_NUM);
        extended[INPUT_NUM] = 1.0;
        return extended;
    }

    /**
     * Sorts {@code errors} ascending (selection sort) and reorders
     * {@code weights} in lock-step, so that index 0 is the best individual.
     * Entries with equal error keep their relative position at the front.
     *
     * @param errors error per individual; reordered in place
     */
    private void sortErrorAndWeights(double[] errors) {
        for (int p1 = 0; p1 < POPULATION_SIZE - 1; p1++) {
            int minIdx = p1;
            for (int p2 = p1 + 1; p2 < POPULATION_SIZE; p2++) {
                if (errors[minIdx] > errors[p2]) {
                    minIdx = p2;
                }
            }
            if (minIdx == p1) {
                continue;
            }
            double tempError = errors[p1];
            errors[p1] = errors[minIdx];
            errors[minIdx] = tempError;
            // Every individual owns its own matrix, so swapping the references
            // is equivalent to (and cheaper than) copying the contents.
            double[][] tempWeights = weights[p1];
            weights[p1] = weights[minIdx];
            weights[minIdx] = tempWeights;
        }
    }

    /** Copies every weight of {@code source} into the already allocated {@code target}. */
    private void copyWeights(double[][] source, double[][] target) {
        for (int i = 0; i < INPUT_NUM + 1; i++) {
            System.arraycopy(source[i], 0, target[i], 0, OUTPUT_NUM);
        }
    }

    /**
     * Mutates the given weight matrix in place: each weight is, with
     * probability {@code MUTATION_RATE}, shifted by a value drawn uniformly
     * from [-MUTATION_SIZE, MUTATION_SIZE).
     */
    private void addMutation(double[][] weights) {
        for (int i = 0; i < INPUT_NUM + 1; i++) {
            for (int o = 0; o < OUTPUT_NUM; o++) {
                if (Math.random() < MUTATION_RATE) {
                    weights[i][o] += Math.random() * MUTATION_SIZE * 2 - MUTATION_SIZE;
                }
            }
        }
    }

    /**
     * Feeds one input vector (bias input included) through the network.
     *
     * @param inputValue {@code INPUT_NUM + 1} input values, the last one being the bias input
     * @param weights    weight matrix indexed {@code [input][output]}
     * @return the sigmoid-activated value of every output neuron
     */
    private double[] calculateOutput(double[] inputValue, double[][] weights) {
        double[] outputValue = new double[OUTPUT_NUM];
        for (int o = 0; o < OUTPUT_NUM; o++) {
            double net = 0.0;
            for (int i = 0; i < INPUT_NUM + 1; i++) {
                net += inputValue[i] * weights[i][o];
            }
            outputValue[o] = sigmoidFunction(net);
        }
        return outputValue;
    }

    /** Logistic sigmoid {@code 1 / (1 + e^-x)}. */
    private double sigmoidFunction(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }
}