module gmle.salsa.framework;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;

/**
 * Asynchronous downhill-simplex minimizer.
 *
 * Each vertex p[i] of the simplex is a parameter vector whose objective value
 * y[i] is obtained by sending an evaluate message to an MLEvaluator. Lower
 * values are better: when the search stops, the best vertex is p[lowest].
 *
 * One iteration is a chain of continuations:
 *   step -> augment(ALPHA) -> collect -> reflect
 *        -> [augment(GAMMA) | augment(BETA) -> collect -> (contract)] -> step
 * The isReflect / isGrow / isContract flags tell collect() which phase the
 * returned evaluation belongs to.
 */
public behavior Simplex {
    /** Factor used by augment() for the reflection move. */
    double ALPHA = -1.0;
    /** Factor used by augment() for the one-dimensional contraction move. */
    double BETA = 0.5;
    /** Factor used by augment() for the expansion move. */
    double GAMMA = 2.0;
    /** Convergence cutoff applied to both rtol and ptol in step(); note 10e-8 == 1e-7. */
    double ftol = 10e-8;
    /** Added to the rtol denominator so that y_hi == y_lo == 0 does not yield NaN. */
    double TINY = 1.0e-10;

    /**
     * Variables for stopping conditions.
     */
    int current_iteration = 0;
    int maximum_iterations = 10000;
    long start_time;

    /** Objective value of each vertex; y[i] belongs to p[i]. */
    double[] y;
    /** Column sums of p: psum[j] == sum over i of p[i][j]. */
    double[] psum;
    /** Simplex vertices; p[i] is the parameter vector of vertex i. */
    double[][] p;
    /** Number of simplex vertices (p.length). */
    int number_parameters;
    /** Dimension of each parameter vector (p[0].length). */
    int parameter_size;
    /** Indices of the worst, second-worst and best vertices, set by step(). */
    int highest, next_highest, lowest;
    /** ytry: value of the latest trial point; ysave: y[highest] saved before a contraction. */
    double ytry, ysave;
    /** Trial point built by the most recent augment(). */
    double[] ptry;

    // Phase flags consumed by collect(); at most one is meaningful at a time.
    boolean isGrow = false;
    boolean isContract = false;
    boolean isReflect = false;

    MLEvaluator evaluator;

    /** Optional log of every evaluation; null when set_evaluation_output was not called. */
    BufferedWriter evaluation_file;
    int evaluations_done = 0;

    public Simplex(String evaluatorClass, String theaters_file) {
        this.evaluator = new MLEvaluator(evaluatorClass, theaters_file);
    }

    public Simplex(String evaluatorClass, int number_actors) {
        this.evaluator = new MLEvaluator(evaluatorClass, number_actors);
    }

    /** Forwards the initialization parameters to the evaluator and resolves when it does. */
    public void initialize(Object[] init_params) {
        evaluator<-initialize(init_params) @ currentContinuation;
    }

    /**
     * Starts a minimization from the given simplex.
     *
     * @param initial_parameters one row per vertex; all rows must have the same length.
     *                           The array is used (and mutated) in place.
     */
    public void simplex(double[][] initial_parameters) {
        this.start_time = System.currentTimeMillis();
        // Reset per-run state so the actor can be reused for another run.
        current_iteration = 0;
        y = null;

        p = initial_parameters;
        parameter_size = p[0].length;
        psum = new double[parameter_size];
        number_parameters = p.length;

        track_evaluate(p) @ token y_set = set_y(token);
        calculate_psums() @ step() : waitfor(y_set);
    }

    /**
     * Stores evaluation results into y.
     *
     * If y is unset or results has one entry per vertex, y becomes a copy of
     * results. Otherwise results is taken to hold the values of every vertex
     * except lowest, in vertex order.
     */
    public void set_y(double[] results) {
        if (y == null || y.length == results.length) {
            y = new double[results.length];
            for (int i = 0; i < y.length; i++) y[i] = results[i];
        } else {
            int current = 0;
            for (int i = 0; i < y.length; i++) {
                if (i != lowest) {
                    y[i] = results[current];
                    current++;
                }
            }
        }
    }

    /** Recomputes psum as the column sums of p. */
    public void calculate_psums() {
        for (int j = 0; j < p[0].length; j++) {
            double sum = 0;
            for (int i = 0; i < p.length; i++) {
                sum += p[i][j];
            }
            psum[j] = sum;
        }
    }

    /**
     * Ranks the vertices and either finishes or starts the next iteration.
     *
     * Stops when the iteration budget is used up, or when both the fractional
     * spread of objective values (rtol) and the mean absolute coordinate gap
     * between the worst and best vertices (ptol) are below ftol. Otherwise it
     * starts a reflection of the worst vertex.
     */
    public void step() {
        highest = 0;
        lowest = 0;
        for (int i = 1; i < y.length; i++) {
            if (y[i] < y[lowest]) lowest = i;
            if (y[i] > y[highest]) highest = i;
        }

        // next_highest: largest value that differs from y[highest]; stays equal to
        // highest when every vertex has the same value.
        next_highest = highest;
        boolean set_lower = false;
        for (int i = 0; i < y.length; i++) {
            if (y[i] == y[highest]) continue;
            if (!set_lower || y[i] > y[next_highest]) {
                next_highest = i;
                set_lower = true;
            }
        }

        System.err.println("Stepping. highest (" + highest + "): " + y[highest] + ", lowest (" + lowest + "): " + y[lowest] + ", next highest (" + next_highest + "): " + y[next_highest]);

        double rtol = 2.0 * Math.abs(y[highest] - y[lowest])
                / (Math.abs(y[highest]) + Math.abs(y[lowest]) + TINY);
        double ptol = 0.0;
        for (int i = 0; i < parameter_size; i++) ptol += Math.abs(p[highest][i] - p[lowest][i]);
        ptol /= parameter_size;

        if (current_iteration >= maximum_iterations || (rtol < ftol && ptol < ftol)) {
            System.err.println("current iteration: " + current_iteration + ", maximum_iterations: " + maximum_iterations);
            System.err.println("rtol < ftol? " + rtol + " < " + ftol);
            System.err.println("ptol < ftol? " + ptol + " < " + ftol);
            finish();
        } else {
            // Count the iteration here; without this the iteration cap never triggers.
            current_iteration++;
            isReflect = true;
            self<-augment(new Double(ALPHA));
        }
    }

    /**
     * Builds a trial point by scaling the worst vertex through the centroid of
     * the others by factor, evaluates it, and passes the value to collect().
     */
    public void augment(double factor) {
        System.err.println("Augmenting by: " + factor);
        double factor1 = (1.0 - factor) / parameter_size;
        double factor2 = factor1 - factor;
        ptry = new double[parameter_size];
        for (int i = 0; i < parameter_size; i++) ptry[i] = psum[i] * factor1 - p[highest][i] * factor2;
        track_evaluate(ptry) @ collect(token);
    }

    /**
     * Receives the value of the latest trial point.
     *
     * If it improves on the worst vertex, it replaces that vertex, with psum
     * updated incrementally. Then the flow continues according to the active
     * phase flag.
     */
    public void collect(double value) {
        System.err.println("Collected: " + value);
        ytry = value;
        if (ytry < y[highest]) {
            y[highest] = ytry;
            for (int j = 0; j < parameter_size; j++) {
                psum[j] += ptry[j] - p[highest][j];
                p[highest][j] = ptry[j];
            }
        }

        if (isReflect) {
            isReflect = false;
            reflect();
        } else if (isGrow) {
            isGrow = false;
            step();
        } else if (isContract) {
            // Clear the flag on both outcomes so it cannot leak into a later iteration.
            isContract = false;
            if (ytry >= ysave) {
                contract();
            } else {
                step();
            }
        }
    }

    /**
     * Shrinks every vertex halfway toward the best one, re-evaluates the
     * simplex, and continues with step().
     */
    public void contract() {
        System.err.println("Contracting.");
        for (int i = 0; i < p.length; i++) {
            if (i == lowest) continue;
            for (int j = 0; j < parameter_size; j++) {
                p[i][j] = 0.5 * (p[i][j] + p[lowest][j]);
            }
        }
        // The shrunken simplex needs fresh column sums before the next augment().
        calculate_psums();
        track_evaluate(p) @ set_y(token) @ step();
    }

    /**
     * Decides the follow-up to a reflection:
     * - expand (GAMMA) if the trial point is at least as good as the best vertex;
     * - contract (BETA) if it is no better than the second-worst vertex;
     * - otherwise go straight to the next step.
     */
    public void reflect() {
        System.err.println("Reflecting, ytry: " + ytry + ", ysave: " + ysave + ", lowest y[" + lowest + "]: " + y[lowest] + ", next highest y[" + next_highest + "]: " + y[next_highest]);
        if (ytry <= y[lowest]) {
            isGrow = true;
            self<-augment(new Double(GAMMA));
        } else if (ytry >= y[next_highest]) {
            isContract = true;
            ysave = y[highest];
            self<-augment(new Double(BETA));
        } else {
            step();
        }
    }

    /** Reports elapsed time, iteration count and the best vertex found to stderr. */
    public void finish() {
        System.err.println("Finished in " + ((System.currentTimeMillis() - start_time) / 1000.0) + " seconds.");
        System.err.println("\tcurrent iteration: " + current_iteration);
        System.err.println("\tbest evaluation: " + y[lowest]);
        System.err.println("\tbest parameters: ");
        for (int i = 0; i < p[lowest].length; i++) {
            System.err.println("\t\t" + p[lowest][i]);
        }
    }

    /**
     * Opens (truncating) a file that will receive one line per evaluation.
     * Terminates the process with a non-zero status if the file cannot be opened.
     */
    void set_evaluation_output(String filename) {
        try {
            this.evaluation_file = new BufferedWriter(new FileWriter(new File(filename)));
        } catch (Exception e) {
            System.err.println("Unable to open evaluation output file: " + filename);
            System.err.println(e);
            e.printStackTrace();
            // Non-zero: this is a fatal error, not a normal exit.
            System.exit(1);
        }
    }

    /** Evaluates every row of parameters concurrently and resolves to their values in row order. */
    double[] track_evaluate(double[][] parameters) {
        join {
            for (int i = 0; i < parameters.length; i++) {
                track_evaluate(parameters[i]);
            }
        } @ combine(token) @ currentContinuation;
    }

    /** Unboxes the Double results gathered by a join block. */
    double[] combine(Object[] evaluate_results) {
        double[] dresults = new double[evaluate_results.length];
        for (int i = 0; i < evaluate_results.length; i++) dresults[i] = ((Double) evaluate_results[i]).doubleValue();
        return dresults;
    }

    /** Evaluates one parameter vector, logs it, and resolves to its value. */
    double track_evaluate(double[] parameters) {
        token result = evaluator<-evaluate(parameters);
        write_result(result, parameters) @ currentContinuation;
    }

    /**
     * Appends "n : result : p0 p1 ..." to the evaluation file when one is open,
     * then returns result unchanged. Terminates the process with a non-zero
     * status if the write fails.
     */
    double write_result(double result, double[] parameters) {
        if (evaluation_file != null) {
            evaluations_done++;
            try {
                evaluation_file.write(evaluations_done + " : " + result + " :");
                for (int i = 0; i < parameters.length; i++) {
                    evaluation_file.write(" " + parameters[i]);
                }
                evaluation_file.write("\n");
                evaluation_file.flush();
            } catch (Exception e) {
                System.err.println("Unable to write to evaluation output file");
                System.err.println(e);
                e.printStackTrace();
                // Non-zero: this is a fatal error, not a normal exit.
                System.exit(1);
            }
        }
        return result;
    }
}