From a174c2341b6f642c9abfe0b07364956a6e1d9d6e Mon Sep 17 00:00:00 2001 From: Thomas Guillod Date: Thu, 2 Apr 2020 14:34:48 +0200 Subject: [PATCH] AnnManager example Example for ANN (matlab and python), lsq, and ga --- .../ann_example/{ => ann_data}/get_ann_data.m | 6 +- .../ann_example/ann_data/get_ann_manager.m | 52 +++++++ .../ann_example/ann_data/get_ann_param.m | 132 ++++++++++++++++++ source_ann/ann_example/get_ann_param.m | 70 ---------- source_ann/ann_example/run_ann_example.m | 72 +++------- source_ann/ann_example/run_ann_server.py | 2 - .../@MatlabPythonClient/MatlabPythonClient.m | 18 +-- .../+ann_engine/AnnEnginePythonAnn.m | 2 +- .../ann_matlab/@AnnManager/get_idx_split.m | 7 +- source_data/get_fem_ann_data_train.m | 7 +- 10 files changed, 226 insertions(+), 142 deletions(-) rename source_ann/ann_example/{ => ann_data}/get_ann_data.m (67%) create mode 100644 source_ann/ann_example/ann_data/get_ann_manager.m create mode 100644 source_ann/ann_example/ann_data/get_ann_param.m delete mode 100644 source_ann/ann_example/get_ann_param.m diff --git a/source_ann/ann_example/get_ann_data.m b/source_ann/ann_example/ann_data/get_ann_data.m similarity index 67% rename from source_ann/ann_example/get_ann_data.m rename to source_ann/ann_example/ann_data/get_ann_data.m index bdd5913..3f6ca9f 100644 --- a/source_ann/ann_example/get_ann_data.m +++ b/source_ann/ann_example/ann_data/get_ann_data.m @@ -2,13 +2,13 @@ n_sol = 10000; -inp.x_1 = 6.0+3.0.*rand(1, n_sol); +inp.x_1 = 7.0+3.0.*rand(1, n_sol); inp.x_2 = 1.0+5.0.*rand(1, n_sol); out_ref.y_1 = inp.x_1+inp.x_2+0.1.*rand(1, n_sol); out_ref.y_2 = inp.x_1-inp.x_2+0.1.*rand(1, n_sol); -out_nrm.y_1 = 11.0.*ones(1, n_sol); -out_nrm.y_2 = 4.0.*ones(1, n_sol); +out_nrm.y_1 = 12.0.*ones(1, n_sol); +out_nrm.y_2 = 5.0.*ones(1, n_sol); end \ No newline at end of file diff --git a/source_ann/ann_example/ann_data/get_ann_manager.m b/source_ann/ann_example/ann_data/get_ann_manager.m new file mode 100644 index 0000000..e4b5f8f --- /dev/null +++ b/source_ann/ann_example/ann_data/get_ann_manager.m @@ -0,0 +1,52 @@ +function get_ann_manager(ann_type) + +% name +fprintf('################## master_train : %s\n', ann_type) + +% data +[ann_input, tag_train] = get_ann_param(ann_type); +[n_sol, inp, out_ref, out_nrm] = get_ann_data(); + +% test class +fprintf('constructor\n') +obj = AnnManager(ann_input); + +fprintf('train\n') +obj.train(tag_train, n_sol, inp, out_ref, out_nrm); + +fprintf('get_fom\n') +fom = obj.get_fom(); +assert(isstruct(fom), 'invalid fom') + +fprintf('disp\n') +obj.disp(); + +fprintf('dump\n') +[ann_input, ann_data] = obj.dump(); + +fprintf('delete\n') +obj.delete(); + +fprintf('predict\n') +predict(ann_input, ann_data, n_sol, inp, out_nrm) + +fprintf('################## master_train : %s\n', ann_type) + +end + +function predict(ann_input, ann_data, n_sol, inp, out_nrm) + +obj = AnnManager(ann_input); +obj.load(ann_data); + +[is_valid_tmp, out_nrm_tmp] = obj.predict_nrm(n_sol, inp, out_nrm); +assert(islogical(is_valid_tmp), 'invalid fom') +assert(isstruct(out_nrm_tmp), 'invalid fom') + +[is_valid_tmp, out_nrm_tmp] = obj.predict_ann(n_sol, inp, out_nrm); +assert(islogical(is_valid_tmp), 'invalid fom') +assert(isstruct(out_nrm_tmp), 'invalid fom') + +obj.delete(); + +end diff --git a/source_ann/ann_example/ann_data/get_ann_param.m b/source_ann/ann_example/ann_data/get_ann_param.m new file mode 100644 index 0000000..a3d2f65 --- /dev/null +++ b/source_ann/ann_example/ann_data/get_ann_param.m @@ -0,0 +1,132 @@ +function [ann_input, tag_train] = get_ann_param(ann_type) + +% var_inp +var_inp = {}; +var_inp{end+1} = struct('name', 'x_1', 'var_trf', 'lin', 'var_norm', 'min_max', 'min', 0.99.*7.0, 'max', 1.01.*10.0); +var_inp{end+1} = struct('name', 'x_2', 'var_trf', 'lin', 'var_norm', 'min_max', 'min', 0.99.*1.0, 'max', 1.01.*6.0); + +% var_out +var_out = {}; +var_out{end+1} = struct('name', 'y_1', 'var_trf', 'lin', 'var_norm', 'min_max', 'use_nrm', true, 'var_err', 'rel'); +var_out{end+1} = struct('name', 'y_2', 'var_trf', 'lin', 'var_norm', 'min_max', 'use_nrm', true, 'var_err', 'rel'); + +% split_train_test +split_train_test.ratio_train = 0.5; +split_train_test.n_train_min = 5; +split_train_test.n_test_min = 5; +split_train_test.type = 'no_overlap'; + +% split the variable +split_var = false; + +% ann_info +switch ann_type + case 'matlab_ann' + ann_info.type = ann_type; + ann_info.fct_model = @fct_model; + ann_info.fct_train = @fct_train; + case 'python_ann' + ann_info.type = ann_type; + ann_info.hostname = 'localhost'; + ann_info.port = 10000; + ann_info.timeout = 240; + case 'matlab_lsq' + ann_info.type = ann_type; + ann_info.options = struct(... + 'Display', 'off',... + 'FunctionTolerance', 1e-6,... + 'StepTolerance', 1e-6,... + 'MaxIterations', 1e3,... + 'MaxFunctionEvaluations', 10e3); + ann_info.x_value = struct(... + 'x0', [0.0 0.0 0.0 0.0 0.0 0.0],... + 'ub', [+20.0 +20.0 +20.0 +20.0 +20.0 +20.0],... + 'lb', [-20.0 -20.0 -20.0 -20.0 -20.0 -20.0]); + ann_info.fct_fit = @fct_fit; + ann_info.fct_err = @fct_err_vec; + case 'matlab_ga' + ann_info.type = ann_type; + ann_info.options = struct(... + 'Display', 'off',... + 'TolFun', 1e-6,... + 'ConstraintTolerance', 1e-3,... + 'Generations', 40,... + 'PopulationSize', 1000); + ann_info.x_value = struct(... + 'n', 6,... + 'ub', [+20.0 +20.0 +20.0 +20.0 +20.0 +20.0],... + 'lb', [-20.0 -20.0 -20.0 -20.0 -20.0 -20.0]); + ann_info.fct_fit = @fct_fit; + ann_info.fct_err = @fct_err_sum; + otherwise + error('invalid data') +end + +% assign +ann_input.var_inp = var_inp; +ann_input.var_out = var_out; +ann_input.split_train_test = split_train_test; +ann_input.split_var = split_var; +ann_input.ann_info = ann_info; + +% tag_train +tag_train = 'none'; + +end + +function model = fct_model(tag_train, n_sol, n_inp, n_out) + +assert(ischar(tag_train), 'invalid output') +assert(isfinite(n_sol), 'invalid input') +assert(isfinite(n_inp), 'invalid input') +assert(isfinite(n_out), 'invalid output') + +model = fitnet(4); +model.trainParam.min_grad = 1e-8; +model.trainParam.epochs = 300; +model.trainParam.max_fail = 25; +model.divideParam.trainRatio = 0.8; +model.divideParam.valRatio = 0.2; +model.divideParam.testRatio = 0.0; + +end + +function [model, history] = fct_train(tag_train, model, inp, out) + +assert(ischar(tag_train), 'invalid output') +[model, history] = train(model, inp, out); + +end + +function out_mat_fit = fct_fit(tag_train, x, inp_mat) + +assert(ischar(tag_train), 'invalid output'); + +x_1 = inp_mat(1, :); +x_2 = inp_mat(2, :); + +y_1 = x(1)+x(2).*x_1+x(3).*x_2; +y_2 = x(4)+x(5).*x_1+x(6).*x_2; + +out_mat_fit = [y_1 ; y_2]; + +end + +function err_vec = fct_err_vec(tag_train, x, inp_mat, out_mat_ref) + +assert(ischar(tag_train), 'invalid output') + +out_mat_fit = fct_fit(tag_train, x, inp_mat); +err_vec = out_mat_ref-out_mat_fit; +err_vec = err_vec(:); + +end + +function err = fct_err_sum(tag_train, x, inp, out) + +assert(ischar(tag_train), 'invalid output') + +err_vec = fct_err_vec(tag_train, x, inp, out); +err = sum(err_vec.^2); + +end diff --git a/source_ann/ann_example/get_ann_param.m b/source_ann/ann_example/get_ann_param.m deleted file mode 100644 index 80b16d0..0000000 --- a/source_ann/ann_example/get_ann_param.m +++ /dev/null @@ -1,70 +0,0 @@ -function [ann_input, tag_train] = get_ann_param(ann_type) - -% var_inp -var_inp = {}; -var_inp{end+1} = struct('name', 'x_1', 'var_trf', 'lin', 'var_norm', 'min_max', 'min', 0.99.*6.0, 'max', 1.01.*9.0); -var_inp{end+1} = struct('name', 'x_2', 'var_trf', 'lin', 'var_norm', 'min_max', 'min', 0.99.*1.0, 'max', 1.01.*6.0); - -% var_out -var_out = {}; -var_out{end+1} = struct('name', 'y_1', 'var_trf', 'lin', 'var_norm', 'min_max', 'use_nrm', true, 'var_err', 'rel'); -var_out{end+1} = struct('name', 'y_2', 'var_trf', 'lin', 'var_norm', 'min_max', 'use_nrm', true, 'var_err', 'rel'); - -% split_train_test -split_train_test.ratio_train = 0.5; -split_train_test.n_min = 4; -split_train_test.type = 'no_overlap'; - -% split the variable -split_var = false; - -% ann_info -switch ann_type - case 'matlab_ann' - ann_info.type = 'matlab_ann'; - ann_info.fct_model = @fct_model; - ann_info.fct_train = @fct_train; - case 'python_ann' - ann_info.type = 'python_ann'; - ann_info.hostname = 'localhost'; - ann_info.port = 10000; - ann_info.timeout = 240; - otherwise - error('invalid data') -end - -% assign -ann_input.var_inp = var_inp; -ann_input.var_out = var_out; -ann_input.split_train_test = split_train_test; -ann_input.split_var = split_var; -ann_input.ann_info = ann_info; - -% tag_train -tag_train = 'none'; - -end - -function model = fct_model(tag_train, n_sol, n_inp, n_out) - -assert(ischar(tag_train), 'invalid output') -assert(isfinite(n_sol), 'invalid input') -assert(isfinite(n_inp), 'invalid input') -assert(isfinite(n_out), 'invalid output') - -model = fitnet(8); -model.trainParam.min_grad = 1e-8; -model.trainParam.epochs = 300; -model.trainParam.max_fail = 25; -model.divideParam.trainRatio = 0.8; -model.divideParam.valRatio = 0.2; -model.divideParam.testRatio = 0.0; - -end - -function [model, history] = fct_train(tag_train, model, inp, out) - -assert(ischar(tag_train), 'invalid output') -[model, history] = train(model, inp, out); - -end \ No newline at end of file diff --git a/source_ann/ann_example/run_ann_example.m b/source_ann/ann_example/run_ann_example.m index 922ce9f..88a0e5d 100644 --- a/source_ann/ann_example/run_ann_example.m +++ b/source_ann/ann_example/run_ann_example.m @@ -1,61 +1,35 @@ function run_ann_example() addpath('../ann_matlab'); +addpath('ann_data'); +close('all') % master_train -master_train('matlab_ann') - +fprintf('AnnManager Example\n') +fprintf(' 1 - ANN regression with MATLAB Deep Learning\n') +fprintf(' 2 - ANN regression with Python Keras and TensorFlow\n') +fprintf(' 3 - MATLAB regression with nonlinear least-squares\n') +fprintf(' 4 - MATLAB regression with genetic algorithm\n') +idx = input('Enter your choice >> '); + +choice = {'matlab_ann', 'python_ann', 'matlab_lsq', 'matlab_ga'}; +choice = get_choice(choice, idx); + +if isempty(choice) + fprintf('Invalid input\n') +else + fprintf('\n') + get_ann_manager(choice) end -function master_train(ann_type) - -% name -fprintf('################## master_train\n') - -% data -[ann_input, tag_train] = get_ann_param(ann_type); -[n_sol, inp, out_ref, out_nrm] = get_ann_data(); - -% test class -fprintf('constructor\n') -obj = AnnManager(ann_input); - -fprintf('train\n') -obj.train(tag_train, n_sol, inp, out_ref, out_nrm); - -fprintf('get_fom\n') -fom = obj.get_fom(); -assert(isstruct(fom), 'invalid fom') - -fprintf('disp\n') -obj.disp(); - -fprintf('dump\n') -[ann_input, ann_data] = obj.dump(); - -fprintf('delete\n') -obj.delete(); - -fprintf('predict\n') -predict(ann_input, ann_data, n_sol, inp, out_nrm) - -fprintf('################## master_train\n') - end -function predict(ann_input, ann_data, n_sol, inp, out_nrm) +function choice = get_choice(choice, idx) -obj = AnnManager(ann_input); -obj.load(ann_data); - -[is_valid_tmp, out_nrm_tmp] = obj.predict_nrm(n_sol, inp, out_nrm); -assert(islogical(is_valid_tmp), 'invalid fom') -assert(isstruct(out_nrm_tmp), 'invalid fom') - -[is_valid_tmp, out_nrm_tmp] = obj.predict_ann(n_sol, inp, out_nrm); -assert(islogical(is_valid_tmp), 'invalid fom') -assert(isstruct(out_nrm_tmp), 'invalid fom') - -obj.delete(); +if isnumeric(idx)&&(length(idx)==1)&&(idx>=1)&&(idx<=length(choice)) + choice = choice{idx}; +else + choice = []; +end end diff --git a/source_ann/ann_example/run_ann_server.py b/source_ann/ann_example/run_ann_server.py index 5a98a8a..e81d866 100644 --- a/source_ann/ann_example/run_ann_server.py +++ b/source_ann/ann_example/run_ann_server.py @@ -17,8 +17,6 @@ def fct_model(tag_train, n_sol, n_inp, n_out): keras.layers.Dense(64, activation='relu'), keras.layers.Dense(64, activation='relu'), keras.layers.Dense(64, activation='relu'), - keras.layers.Dense(64, activation='relu'), - keras.layers.Dense(64, activation='relu'), keras.layers.Dense(activation='linear', units=n_out), ]) diff --git a/source_ann/ann_matlab/+ann_engine/@MatlabPythonClient/MatlabPythonClient.m b/source_ann/ann_matlab/+ann_engine/@MatlabPythonClient/MatlabPythonClient.m index e621758..76f24cc 100644 --- a/source_ann/ann_matlab/+ann_engine/@MatlabPythonClient/MatlabPythonClient.m +++ b/source_ann/ann_matlab/+ann_engine/@MatlabPythonClient/MatlabPythonClient.m @@ -1,10 +1,5 @@ classdef MatlabPythonClient < handle %% properties - properties (SetAccess = immutable, GetAccess = private) - hostname - port - timeout - end properties (SetAccess = private, GetAccess = private) tcp end @@ -12,10 +7,11 @@ %% init methods (Access = public) function self = MatlabPythonClient(hostname, port, timeout) - self.hostname = hostname; - self.port = port; - self.timeout = timeout; - self.tcp = tcpclient(self.hostname, self.port, 'Timeout', timeout); + try + self.tcp = tcpclient(hostname, port, 'Timeout', timeout); + catch + error('Connection failure: Python server : %s / %d', hostname, port) + end end function data_out = run(self, data_inp) @@ -27,7 +23,7 @@ methods (Access = private) function send(self, data) % dump the data - byte = MatlabPythonClient.get_serialize(data); + byte = self.get_serialize(data); % get the length n = length(byte); @@ -45,7 +41,7 @@ function send(self, data) % load the data byte = self.tcp.read(n); - data = MatlabPythonClient.get_deserialize(byte); + data = self.get_deserialize(byte); end end diff --git a/source_ann/ann_matlab/+ann_engine/AnnEnginePythonAnn.m b/source_ann/ann_matlab/+ann_engine/AnnEnginePythonAnn.m index be7f9ca..e17e151 100644 --- a/source_ann/ann_matlab/+ann_engine/AnnEnginePythonAnn.m +++ b/source_ann/ann_matlab/+ann_engine/AnnEnginePythonAnn.m @@ -16,7 +16,7 @@ self.hostname = hostname; self.port = port; self.timeout = timeout; - self.client_obj = mat_py_bridge.MatlabPythonClient(hostname, port, timeout); + self.client_obj = ann_engine.MatlabPythonClient(hostname, port, timeout); end function [model, history] = train(self, tag_train, inp, out) diff --git a/source_ann/ann_matlab/@AnnManager/get_idx_split.m b/source_ann/ann_matlab/@AnnManager/get_idx_split.m index f08ce1c..9013dd4 100644 --- a/source_ann/ann_matlab/@AnnManager/get_idx_split.m +++ b/source_ann/ann_matlab/@AnnManager/get_idx_split.m @@ -3,9 +3,6 @@ function get_idx_split(self) % init generator rng('shuffle'); -% check size -assert(self.n_sol>=self.split_train_test.n_min, 'invalid number of solutions') - % get size n_train = round(self.n_sol.*self.split_train_test.ratio_train); @@ -24,4 +21,8 @@ function get_idx_split(self) error('invalid type') end +% check size +assert(nnz(self.idx_test)>=self.split_train_test.n_test_min, 'invalid number of solutions') +assert(nnz(self.idx_train)>=self.split_train_test.n_train_min, 'invalid number of solutions') + end diff --git a/source_data/get_fem_ann_data_train.m b/source_data/get_fem_ann_data_train.m index 0e99d9f..927481d 100644 --- a/source_data/get_fem_ann_data_train.m +++ b/source_data/get_fem_ann_data_train.m @@ -37,7 +37,8 @@ % split_train_test split_train_test.ratio_train = 0.5; -split_train_test.n_min = 4; +split_train_test.n_train_min = 5; +split_train_test.n_test_min = 5; split_train_test.type = 'no_overlap'; % split the variable @@ -46,11 +47,11 @@ % ann_info switch ann_type case 'matlab_ann' - ann_info.type = 'matlab_ann'; + ann_info.type = ann_type; ann_info.fct_model = @fct_model; ann_info.fct_train = @fct_train; case 'python_ann' - ann_info.type = 'python_ann'; + ann_info.type = ann_type; ann_info.hostname = 'localhost'; ann_info.port = 10000; ann_info.timeout = 240;