Skip to content

Commit

Permalink
AnnManager example
Browse files Browse the repository at this point in the history
Example for ANN regression (MATLAB and Python), nonlinear least-squares, and genetic algorithm
  • Loading branch information
otvam committed Apr 2, 2020
1 parent 24f4623 commit a174c23
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

n_sol = 10000;

inp.x_1 = 6.0+3.0.*rand(1, n_sol);
inp.x_1 = 7.0+3.0.*rand(1, n_sol);
inp.x_2 = 1.0+5.0.*rand(1, n_sol);

out_ref.y_1 = inp.x_1+inp.x_2+0.1.*rand(1, n_sol);
out_ref.y_2 = inp.x_1-inp.x_2+0.1.*rand(1, n_sol);

out_nrm.y_1 = 11.0.*ones(1, n_sol);
out_nrm.y_2 = 4.0.*ones(1, n_sol);
out_nrm.y_1 = 12.0.*ones(1, n_sol);
out_nrm.y_2 = 5.0.*ones(1, n_sol);

end
52 changes: 52 additions & 0 deletions source_ann/ann_example/ann_data/get_ann_manager.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
function get_ann_manager(ann_type)
% Run a complete demo of the AnnManager class for a given regression type.
%
%    Train a regression object, query it, serialize it, and check that a
%    reloaded instance can still make predictions.
%
%    Parameters:
%        ann_type (str): regression method, as accepted by get_ann_param
%                        ('matlab_ann', 'python_ann', 'matlab_lsq', 'matlab_ga')

% banner (fixed: the old code still printed the pre-rename name 'master_train')
fprintf('################## get_ann_manager : %s\n', ann_type)

% get the regression parameters and the training dataset
[ann_input, tag_train] = get_ann_param(ann_type);
[n_sol, inp, out_ref, out_nrm] = get_ann_data();

% create the regression object
fprintf('constructor\n')
obj = AnnManager(ann_input);

% train with the dataset
fprintf('train\n')
obj.train(tag_train, n_sol, inp, out_ref, out_nrm);

% query the figures of merit
fprintf('get_fom\n')
fom = obj.get_fom();
assert(isstruct(fom), 'invalid fom')

% display the object
fprintf('disp\n')
obj.disp();

% serialize the object state
fprintf('dump\n')
[ann_input, ann_data] = obj.dump();

% destroy the object
fprintf('delete\n')
obj.delete();

% reload the serialized state into a fresh object and predict
fprintf('predict\n')
predict(ann_input, ann_data, n_sol, inp, out_nrm)

% closing banner (kept in sync with the function name)
fprintf('################## get_ann_manager : %s\n', ann_type)

end

function predict(ann_input, ann_data, n_sol, inp, out_nrm)
% Rebuild an AnnManager from dumped data and exercise both prediction paths.
%
%    Parameters:
%        ann_input (struct): constructor parameters (from AnnManager.dump)
%        ann_data (struct): serialized trained state (from AnnManager.dump)
%        n_sol (int): number of samples
%        inp (struct): input samples
%        out_nrm (struct): normalization output samples

% restore the trained object from the serialized state
obj = AnnManager(ann_input);
obj.load(ann_data);

% prediction with the normalization data only
% (fixed: the assert messages wrongly said 'invalid fom', a copy-paste
% from the get_fom check in the caller)
[is_valid_tmp, out_nrm_tmp] = obj.predict_nrm(n_sol, inp, out_nrm);
assert(islogical(is_valid_tmp), 'invalid validity flag')
assert(isstruct(out_nrm_tmp), 'invalid prediction output')

% prediction with the trained regression
[is_valid_tmp, out_nrm_tmp] = obj.predict_ann(n_sol, inp, out_nrm);
assert(islogical(is_valid_tmp), 'invalid validity flag')
assert(isstruct(out_nrm_tmp), 'invalid prediction output')

% destroy the object
obj.delete();

end
132 changes: 132 additions & 0 deletions source_ann/ann_example/ann_data/get_ann_param.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
function [ann_input, tag_train] = get_ann_param(ann_type)
% Build the AnnManager constructor parameters for a given regression type.
%
%    Parameters:
%        ann_type (str): 'matlab_ann', 'python_ann', 'matlab_lsq', or 'matlab_ga'
%
%    Returns:
%        ann_input (struct): constructor parameters for AnnManager
%        tag_train (str): tag passed through to the training functions

% input variables: linear transform, min/max normalization,
% bounds slightly padded around the sampling range of get_ann_data
var_inp = {};
var_inp{end+1} = struct('name', 'x_1', 'var_trf', 'lin', 'var_norm', 'min_max', 'min', 0.99.*7.0, 'max', 1.01.*10.0);
var_inp{end+1} = struct('name', 'x_2', 'var_trf', 'lin', 'var_norm', 'min_max', 'min', 0.99.*1.0, 'max', 1.01.*6.0);

% output variables: linear transform, min/max normalization,
% normalization data used, relative error metric
var_out = {};
var_out{end+1} = struct('name', 'y_1', 'var_trf', 'lin', 'var_norm', 'min_max', 'use_nrm', true, 'var_err', 'rel');
var_out{end+1} = struct('name', 'y_2', 'var_trf', 'lin', 'var_norm', 'min_max', 'use_nrm', true, 'var_err', 'rel');

% train/test split: disjoint halves with a minimum size on both sides
split_train_test = struct(...
    'ratio_train', 0.5,...
    'n_train_min', 5,...
    'n_test_min', 5,...
    'type', 'no_overlap');

% do not split the variables
split_var = false;

% fitting-parameter vectors shared by the lsq and ga solvers
% (six coefficients: two affine fits of two inputs)
x0_fit = [0.0 0.0 0.0 0.0 0.0 0.0];
ub_fit = [+20.0 +20.0 +20.0 +20.0 +20.0 +20.0];
lb_fit = [-20.0 -20.0 -20.0 -20.0 -20.0 -20.0];

% solver-specific parameters
ann_info.type = ann_type;
switch ann_type
    case 'matlab_ann'
        % MATLAB Deep Learning: model and training callbacks
        ann_info.fct_model = @fct_model;
        ann_info.fct_train = @fct_train;
    case 'python_ann'
        % Python Keras/TensorFlow server connection
        ann_info.hostname = 'localhost';
        ann_info.port = 10000;
        ann_info.timeout = 240;
    case 'matlab_lsq'
        % nonlinear least-squares fit of the analytical model
        ann_info.options = struct(...
            'Display', 'off',...
            'FunctionTolerance', 1e-6,...
            'StepTolerance', 1e-6,...
            'MaxIterations', 1e3,...
            'MaxFunctionEvaluations', 10e3);
        ann_info.x_value = struct('x0', x0_fit, 'ub', ub_fit, 'lb', lb_fit);
        ann_info.fct_fit = @fct_fit;
        ann_info.fct_err = @fct_err_vec;
    case 'matlab_ga'
        % genetic-algorithm fit of the analytical model
        ann_info.options = struct(...
            'Display', 'off',...
            'TolFun', 1e-6,...
            'ConstraintTolerance', 1e-3,...
            'Generations', 40,...
            'PopulationSize', 1000);
        ann_info.x_value = struct('n', 6, 'ub', ub_fit, 'lb', lb_fit);
        ann_info.fct_fit = @fct_fit;
        ann_info.fct_err = @fct_err_sum;
    otherwise
        error('invalid data')
end

% assemble the constructor parameters
ann_input = struct();
ann_input.var_inp = var_inp;
ann_input.var_out = var_out;
ann_input.split_train_test = split_train_test;
ann_input.split_var = split_var;
ann_input.ann_info = ann_info;

% no special training tag for this example
tag_train = 'none';

end

function net = fct_model(tag_train, n_sol, n_inp, n_out)
% Create the MATLAB fitting network used by the 'matlab_ann' regression.
%
%    Parameters:
%        tag_train (str): training tag (unused, validated only)
%        n_sol (int): number of samples (unused, validated only)
%        n_inp (int): number of inputs (unused, validated only)
%        n_out (int): number of outputs (unused, validated only)
%
%    Returns:
%        net (network): untrained shallow fitting network

% validate the arguments
assert(ischar(tag_train), 'invalid output')
assert(isfinite(n_sol), 'invalid input')
assert(isfinite(n_inp), 'invalid input')
assert(isfinite(n_out), 'invalid output')

% shallow network with a single hidden layer of four neurons
net = fitnet(4);

% data division: train/validation only, no internal test set
% (the train/test split is handled by AnnManager itself)
net.divideParam.trainRatio = 0.8;
net.divideParam.valRatio = 0.2;
net.divideParam.testRatio = 0.0;

% stopping criteria
net.trainParam.min_grad = 1e-8;
net.trainParam.epochs = 300;
net.trainParam.max_fail = 25;

end

function [model, history] = fct_train(tag_train, model, inp, out)
% Train the MATLAB network for the 'matlab_ann' regression.
%
%    Parameters:
%        tag_train (str): training tag (unused, validated only)
%        model (network): untrained network (from fct_model)
%        inp (matrix): input samples (one column per sample)
%        out (matrix): output samples (one column per sample)
%
%    Returns:
%        model (network): trained network
%        history (struct): training record returned by train

assert(ischar(tag_train), 'invalid output')
[model, history] = train(model, inp, out);

end

function out_mat_fit = fct_fit(tag_train, x, inp_mat)
% Evaluate the analytical fitting model (two affine functions of two inputs).
%
%    Parameters:
%        tag_train (str): training tag (unused, validated only)
%        x (vector): six fitting coefficients
%        inp_mat (matrix): input samples (2 x n, one column per sample)
%
%    Returns:
%        out_mat_fit (matrix): fitted outputs (2 x n, one column per sample)

% validate the tag
assert(ischar(tag_train), 'invalid output');

% extract the two input rows
inp_row_1 = inp_mat(1, :);
inp_row_2 = inp_mat(2, :);

% affine model: offset plus one linear term per input
fit_row_1 = x(1)+x(2).*inp_row_1+x(3).*inp_row_2;
fit_row_2 = x(4)+x(5).*inp_row_1+x(6).*inp_row_2;

% stack the two fitted outputs
out_mat_fit = [fit_row_1 ; fit_row_2];

end

function err_vec = fct_err_vec(tag_train, x, inp_mat, out_mat_ref)
% Compute the residual vector between the reference data and the fit.
%
%    Parameters:
%        tag_train (str): training tag (unused, validated only)
%        x (vector): six fitting coefficients
%        inp_mat (matrix): input samples (one column per sample)
%        out_mat_ref (matrix): reference outputs (one column per sample)
%
%    Returns:
%        err_vec (vector): flattened residuals (reference minus fit)

% validate the tag
assert(ischar(tag_train), 'invalid output')

% residual matrix, flattened column-wise for lsqnonlin
err_mat = out_mat_ref-fct_fit(tag_train, x, inp_mat);
err_vec = err_mat(:);

end

function err = fct_err_sum(tag_train, x, inp, out)
% Compute the scalar sum-of-squares error for the genetic algorithm.
%
%    Parameters:
%        tag_train (str): training tag (unused, validated only)
%        x (vector): six fitting coefficients
%        inp (matrix): input samples (one column per sample)
%        out (matrix): reference outputs (one column per sample)
%
%    Returns:
%        err (float): sum of the squared residuals

% validate the tag
assert(ischar(tag_train), 'invalid output')

% squared residuals, reduced to a single scalar objective
residual = fct_err_vec(tag_train, x, inp, out);
err = sum(residual.*residual);

end
70 changes: 0 additions & 70 deletions source_ann/ann_example/get_ann_param.m

This file was deleted.

72 changes: 23 additions & 49 deletions source_ann/ann_example/run_ann_example.m
Original file line number Diff line number Diff line change
@@ -1,61 +1,35 @@
function run_ann_example()

addpath('../ann_matlab');
addpath('ann_data');
close('all')

% master_train
master_train('matlab_ann')

fprintf('AnnManager Example\n')
fprintf(' 1 - ANN regression with MATLAB Deep Learning\n')
fprintf(' 2 - ANN regression with Python Keras and TensorFlow\n')
fprintf(' 3 - MATLAB regression with nonlinear least-squares\n')
fprintf(' 4 - MATLAB regression with genetic algorithm\n')
idx = input('Enter your choice >> ');

choice = {'matlab_ann', 'python_ann', 'matlab_lsq', 'matlab_ga'};
choice = get_choice(choice, idx);

if isempty(choice)
fprintf('Invalid input\n')
else
fprintf('\n')
get_ann_manager(choice)
end

function master_train(ann_type)

% name
fprintf('################## master_train\n')

% data
[ann_input, tag_train] = get_ann_param(ann_type);
[n_sol, inp, out_ref, out_nrm] = get_ann_data();

% test class
fprintf('constructor\n')
obj = AnnManager(ann_input);

fprintf('train\n')
obj.train(tag_train, n_sol, inp, out_ref, out_nrm);

fprintf('get_fom\n')
fom = obj.get_fom();
assert(isstruct(fom), 'invalid fom')

fprintf('disp\n')
obj.disp();

fprintf('dump\n')
[ann_input, ann_data] = obj.dump();

fprintf('delete\n')
obj.delete();

fprintf('predict\n')
predict(ann_input, ann_data, n_sol, inp, out_nrm)

fprintf('################## master_train\n')

end

function predict(ann_input, ann_data, n_sol, inp, out_nrm)
function choice = get_choice(choice, idx)

obj = AnnManager(ann_input);
obj.load(ann_data);

[is_valid_tmp, out_nrm_tmp] = obj.predict_nrm(n_sol, inp, out_nrm);
assert(islogical(is_valid_tmp), 'invalid fom')
assert(isstruct(out_nrm_tmp), 'invalid fom')

[is_valid_tmp, out_nrm_tmp] = obj.predict_ann(n_sol, inp, out_nrm);
assert(islogical(is_valid_tmp), 'invalid fom')
assert(isstruct(out_nrm_tmp), 'invalid fom')

obj.delete();
if isnumeric(idx)&&(length(idx)==1)&&(idx>=1)&&(idx<=length(choice))
choice = choice{idx};
else
choice = [];
end

end
2 changes: 0 additions & 2 deletions source_ann/ann_example/run_ann_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ def fct_model(tag_train, n_sol, n_inp, n_out):
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(activation='linear', units=n_out),
])

Expand Down
Loading

0 comments on commit a174c23

Please sign in to comment.