/**
 *
 * This program is to simulate back propagation learning of the NOT-XOR logic gate.
 * My test produced the following output:
 *
 * !xor(0.0,0.0) = 0.96676975
 * !xor(0.0,1.0) = 0.034594968
 * !xor(1.0,0.0) = 0.034695894
 * !xor(1.0,1.0) = 0.95619357
 *
 * The program can be made to learn an XOR logic gate by changing the training samples.
 * The code was written for readability rather than for production use.
 *
 * www.tech-algorithm.com
 *
 */
/**
 * A tiny feed-forward network (2 inputs, 3 sigmoid hidden units, 1 sigmoid
 * output) trained with back propagation on the NOT-XOR truth table.
 *
 * Constructing an instance does everything: it draws random initial weights,
 * trains for a fixed number of epochs, and prints the network's output for
 * each training sample to standard output.
 */
public class BackPropagation {

    /** Scale passed to {@link #getRandom(float)} when drawing initial weights. */
    private static final float WEIGHT_RANGE = 0.1f;

    /** Training rate applied to every weight update. */
    private static final float LEARNING_RATE = 0.35f;

    /** Number of passes over the sample set; this is the only stopping criterion. */
    private static final int EPOCHS = 6000;

    /** Training set for the NOT-XOR gate; each row is { A, B, expected output }. */
    private static final float[][] SAMPLES = {
        { 0.0f, 0.0f, 1.0f },
        { 0.0f, 1.0f, 0.0f },
        { 1.0f, 0.0f, 0.0f },
        { 1.0f, 1.0f, 1.0f }
    };

    // Input -> hidden weights: w1..w3 carry input A, w4..w6 carry input B,
    // feeding hidden units 1..3 respectively.
    private float w1, w2, w3, w4, w5, w6;

    // Hidden -> output weights for hidden units 1..3.
    private float w7, w8, w9;

    // Bias weights: wbias1..wbias3 for the hidden units, wbias4 for the output.
    private float wbias1, wbias2, wbias3, wbias4;

    /**
     * Logistic activation function.
     *
     * @param x the weighted input sum
     * @return 1/(1 + exp(-x))
     */
    private float sigmoid(float x) {
        return (float) (1.0 / (1.0 + Math.exp(-x)));
    }

    /**
     * Draws a random number scaled by the given magnitude.
     *
     * @param x the magnitude of the range
     * @return (Math.random()*2 - 1) * x, i.e. a value between -x and x
     */
    private float getRandom(float x) {
        return (float) (((Math.random() * 2.0) - 1.0) * x);
    }

    /**
     * Initializes the weights, trains the network on the NOT-XOR samples and
     * prints one prediction line per sample.
     */
    public BackPropagation() {
        // Draw order matches the declaration order w1..w9, wbias1..wbias4.
        w1 = getRandom(WEIGHT_RANGE);
        w2 = getRandom(WEIGHT_RANGE);
        w3 = getRandom(WEIGHT_RANGE);
        w4 = getRandom(WEIGHT_RANGE);
        w5 = getRandom(WEIGHT_RANGE);
        w6 = getRandom(WEIGHT_RANGE);
        w7 = getRandom(WEIGHT_RANGE);
        w8 = getRandom(WEIGHT_RANGE);
        w9 = getRandom(WEIGHT_RANGE);
        wbias1 = getRandom(WEIGHT_RANGE);
        wbias2 = getRandom(WEIGHT_RANGE);
        wbias3 = getRandom(WEIGHT_RANGE);
        wbias4 = getRandom(WEIGHT_RANGE);

        train();
        report();
    }

    /**
     * Forward propagation for one input pair using the current weights.
     *
     * @param a first input
     * @param b second input
     * @return { z1, z2, z3, o1 }: the three hidden activations followed by the
     *         output activation
     */
    private float[] forward(float a, float b) {
        // hidden layer activations (the bias input is the constant 1.0)
        float z1 = sigmoid((1.0f * wbias1) + (a * w1) + (b * w4));
        float z2 = sigmoid((1.0f * wbias2) + (a * w2) + (b * w5));
        float z3 = sigmoid((1.0f * wbias3) + (a * w3) + (b * w6));

        // output activation
        float o1 = sigmoid((1.0f * wbias4) + (z1 * w7) + (z2 * w8) + (z3 * w9));

        return new float[] { z1, z2, z3, o1 };
    }

    /**
     * Presents a single sample and updates every weight once.
     *
     * @param a      first input
     * @param b      second input
     * @param target expected output for (a, b)
     */
    private void trainSample(float a, float b, float target) {
        float[] act = forward(a, b);
        float z1 = act[0];
        float z2 = act[1];
        float z3 = act[2];
        float o1 = act[3];

        // error term of the output unit
        float outDelta = o1 * (1.0f - o1) * (target - o1);

        // error terms of the hidden units; computed before w7..w9 are
        // modified so that they use the weights from the forward pass
        float d1 = z1 * (1.0f - z1) * (outDelta * w7);
        float d2 = z2 * (1.0f - z2) * (outDelta * w8);
        float d3 = z3 * (1.0f - z3) * (outDelta * w9);

        // weight update: hidden -> output
        w7 = w7 + LEARNING_RATE * outDelta * z1;
        w8 = w8 + LEARNING_RATE * outDelta * z2;
        w9 = w9 + LEARNING_RATE * outDelta * z3;
        wbias4 = wbias4 + LEARNING_RATE * outDelta * 1.0f;

        // weight update: input -> hidden
        w1 = w1 + LEARNING_RATE * d1 * a;
        w2 = w2 + LEARNING_RATE * d2 * a;
        w3 = w3 + LEARNING_RATE * d3 * a;
        w4 = w4 + LEARNING_RATE * d1 * b;
        w5 = w5 + LEARNING_RATE * d2 * b;
        w6 = w6 + LEARNING_RATE * d3 * b;
        wbias1 = wbias1 + LEARNING_RATE * d1 * 1.0f;
        wbias2 = wbias2 + LEARNING_RATE * d2 * 1.0f;
        wbias3 = wbias3 + LEARNING_RATE * d3 * 1.0f;
    }

    /**
     * Runs {@link #EPOCHS} epochs in serial mode: within each epoch the samples
     * are presented one by one, in table order, with a weight update after each.
     */
    private void train() {
        for (int epoch = 0; epoch < EPOCHS; epoch++) {
            for (float[] sample : SAMPLES) {
                trainSample(sample[0], sample[1], sample[2]);
            }
        }
    }

    /**
     * Test mode: prints the network's prediction for the inputs of every
     * training sample, one line per sample.
     */
    private void report() {
        for (float[] sample : SAMPLES) {
            float a = sample[0];
            float b = sample[1];

            // the last element of the forward pass is the prediction
            float o1 = forward(a, b)[3];

            System.out.println("!xor(" + a + "," + b + ") = " + o1);
        }
    }

    /**
     * Program entry point; trains a network and prints its predictions.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        new BackPropagation();
    }

}

