-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Example for ANN (matlab and python), lsq, and ga
- Loading branch information
Showing
10 changed files
with
226 additions
and
142 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
function get_ann_manager(ann_type) | ||
|
||
% name | ||
fprintf('################## master_train : %s\n', ann_type) | ||
|
||
% data | ||
[ann_input, tag_train] = get_ann_param(ann_type); | ||
[n_sol, inp, out_ref, out_nrm] = get_ann_data(); | ||
|
||
% test class | ||
fprintf('constructor\n') | ||
obj = AnnManager(ann_input); | ||
|
||
fprintf('train\n') | ||
obj.train(tag_train, n_sol, inp, out_ref, out_nrm); | ||
|
||
fprintf('get_fom\n') | ||
fom = obj.get_fom(); | ||
assert(isstruct(fom), 'invalid fom') | ||
|
||
fprintf('disp\n') | ||
obj.disp(); | ||
|
||
fprintf('dump\n') | ||
[ann_input, ann_data] = obj.dump(); | ||
|
||
fprintf('delete\n') | ||
obj.delete(); | ||
|
||
fprintf('predict\n') | ||
predict(ann_input, ann_data, n_sol, inp, out_nrm) | ||
|
||
fprintf('################## master_train : %s\n', ann_type) | ||
|
||
end | ||
|
||
function predict(ann_input, ann_data, n_sol, inp, out_nrm) | ||
|
||
obj = AnnManager(ann_input); | ||
obj.load(ann_data); | ||
|
||
[is_valid_tmp, out_nrm_tmp] = obj.predict_nrm(n_sol, inp, out_nrm); | ||
assert(islogical(is_valid_tmp), 'invalid fom') | ||
assert(isstruct(out_nrm_tmp), 'invalid fom') | ||
|
||
[is_valid_tmp, out_nrm_tmp] = obj.predict_ann(n_sol, inp, out_nrm); | ||
assert(islogical(is_valid_tmp), 'invalid fom') | ||
assert(isstruct(out_nrm_tmp), 'invalid fom') | ||
|
||
obj.delete(); | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
function [ann_input, tag_train] = get_ann_param(ann_type) | ||
|
||
% var_inp | ||
var_inp = {}; | ||
var_inp{end+1} = struct('name', 'x_1', 'var_trf', 'lin', 'var_norm', 'min_max', 'min', 0.99.*7.0, 'max', 1.01.*10.0); | ||
var_inp{end+1} = struct('name', 'x_2', 'var_trf', 'lin', 'var_norm', 'min_max', 'min', 0.99.*1.0, 'max', 1.01.*6.0); | ||
|
||
% var_out | ||
var_out = {}; | ||
var_out{end+1} = struct('name', 'y_1', 'var_trf', 'lin', 'var_norm', 'min_max', 'use_nrm', true, 'var_err', 'rel'); | ||
var_out{end+1} = struct('name', 'y_2', 'var_trf', 'lin', 'var_norm', 'min_max', 'use_nrm', true, 'var_err', 'rel'); | ||
|
||
% split_train_test | ||
split_train_test.ratio_train = 0.5; | ||
split_train_test.n_train_min = 5; | ||
split_train_test.n_test_min = 5; | ||
split_train_test.type = 'no_overlap'; | ||
|
||
% split the variable | ||
split_var = false; | ||
|
||
% ann_info | ||
switch ann_type | ||
case 'matlab_ann' | ||
ann_info.type = ann_type; | ||
ann_info.fct_model = @fct_model; | ||
ann_info.fct_train = @fct_train; | ||
case 'python_ann' | ||
ann_info.type = ann_type; | ||
ann_info.hostname = 'localhost'; | ||
ann_info.port = 10000; | ||
ann_info.timeout = 240; | ||
case 'matlab_lsq' | ||
ann_info.type = ann_type; | ||
ann_info.options = struct(... | ||
'Display', 'off',... | ||
'FunctionTolerance', 1e-6,... | ||
'StepTolerance', 1e-6,... | ||
'MaxIterations', 1e3,... | ||
'MaxFunctionEvaluations', 10e3); | ||
ann_info.x_value = struct(... | ||
'x0', [0.0 0.0 0.0 0.0 0.0 0.0],... | ||
'ub', [+20.0 +20.0 +20.0 +20.0 +20.0 +20.0],... | ||
'lb', [-20.0 -20.0 -20.0 -20.0 -20.0 -20.0]); | ||
ann_info.fct_fit = @fct_fit; | ||
ann_info.fct_err = @fct_err_vec; | ||
case 'matlab_ga' | ||
ann_info.type = ann_type; | ||
ann_info.options = struct(... | ||
'Display', 'off',... | ||
'TolFun', 1e-6,... | ||
'ConstraintTolerance', 1e-3,... | ||
'Generations', 40,... | ||
'PopulationSize', 1000); | ||
ann_info.x_value = struct(... | ||
'n', 6,... | ||
'ub', [+20.0 +20.0 +20.0 +20.0 +20.0 +20.0],... | ||
'lb', [-20.0 -20.0 -20.0 -20.0 -20.0 -20.0]); | ||
ann_info.fct_fit = @fct_fit; | ||
ann_info.fct_err = @fct_err_sum; | ||
otherwise | ||
error('invalid data') | ||
end | ||
|
||
% assign | ||
ann_input.var_inp = var_inp; | ||
ann_input.var_out = var_out; | ||
ann_input.split_train_test = split_train_test; | ||
ann_input.split_var = split_var; | ||
ann_input.ann_info = ann_info; | ||
|
||
% tag_train | ||
tag_train = 'none'; | ||
|
||
end | ||
|
||
function model = fct_model(tag_train, n_sol, n_inp, n_out) | ||
|
||
assert(ischar(tag_train), 'invalid output') | ||
assert(isfinite(n_sol), 'invalid input') | ||
assert(isfinite(n_inp), 'invalid input') | ||
assert(isfinite(n_out), 'invalid output') | ||
|
||
model = fitnet(4); | ||
model.trainParam.min_grad = 1e-8; | ||
model.trainParam.epochs = 300; | ||
model.trainParam.max_fail = 25; | ||
model.divideParam.trainRatio = 0.8; | ||
model.divideParam.valRatio = 0.2; | ||
model.divideParam.testRatio = 0.0; | ||
|
||
end | ||
|
||
function [model, history] = fct_train(tag_train, model, inp, out) | ||
|
||
assert(ischar(tag_train), 'invalid output') | ||
[model, history] = train(model, inp, out); | ||
|
||
end | ||
|
||
function out_mat_fit = fct_fit(tag_train, x, inp_mat) | ||
|
||
assert(ischar(tag_train), 'invalid output'); | ||
|
||
x_1 = inp_mat(1, :); | ||
x_2 = inp_mat(2, :); | ||
|
||
y_1 = x(1)+x(2).*x_1+x(3).*x_2; | ||
y_2 = x(4)+x(5).*x_1+x(6).*x_2; | ||
|
||
out_mat_fit = [y_1 ; y_2]; | ||
|
||
end | ||
|
||
function err_vec = fct_err_vec(tag_train, x, inp_mat, out_mat_ref) | ||
|
||
assert(ischar(tag_train), 'invalid output') | ||
|
||
out_mat_fit = fct_fit(tag_train, x, inp_mat); | ||
err_vec = out_mat_ref-out_mat_fit; | ||
err_vec = err_vec(:); | ||
|
||
end | ||
|
||
function err = fct_err_sum(tag_train, x, inp, out) | ||
|
||
assert(ischar(tag_train), 'invalid output') | ||
|
||
err_vec = fct_err_vec(tag_train, x, inp, out); | ||
err = sum(err_vec.^2); | ||
|
||
end |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,35 @@ | ||
function run_ann_example() | ||
|
||
addpath('../ann_matlab'); | ||
addpath('ann_data'); | ||
close('all') | ||
|
||
% master_train | ||
master_train('matlab_ann') | ||
|
||
fprintf('AnnManager Example\n') | ||
fprintf(' 1 - ANN regression with MATLAB Deep Learning\n') | ||
fprintf(' 2 - ANN regression with Python Keras and TensorFlow\n') | ||
fprintf(' 3 - MATLAB regression with nonlinear least-squares\n') | ||
fprintf(' 4 - MATLAB regression with genetic algorithm\n') | ||
idx = input('Enter your choice >> '); | ||
|
||
choice = {'matlab_ann', 'python_ann', 'matlab_lsq', 'matlab_ga'}; | ||
choice = get_choice(choice, idx); | ||
|
||
if isempty(choice) | ||
fprintf('Invalid input\n') | ||
else | ||
fprintf('\n') | ||
get_ann_manager(choice) | ||
end | ||
|
||
function master_train(ann_type) | ||
|
||
% name | ||
fprintf('################## master_train\n') | ||
|
||
% data | ||
[ann_input, tag_train] = get_ann_param(ann_type); | ||
[n_sol, inp, out_ref, out_nrm] = get_ann_data(); | ||
|
||
% test class | ||
fprintf('constructor\n') | ||
obj = AnnManager(ann_input); | ||
|
||
fprintf('train\n') | ||
obj.train(tag_train, n_sol, inp, out_ref, out_nrm); | ||
|
||
fprintf('get_fom\n') | ||
fom = obj.get_fom(); | ||
assert(isstruct(fom), 'invalid fom') | ||
|
||
fprintf('disp\n') | ||
obj.disp(); | ||
|
||
fprintf('dump\n') | ||
[ann_input, ann_data] = obj.dump(); | ||
|
||
fprintf('delete\n') | ||
obj.delete(); | ||
|
||
fprintf('predict\n') | ||
predict(ann_input, ann_data, n_sol, inp, out_nrm) | ||
|
||
fprintf('################## master_train\n') | ||
|
||
end | ||
|
||
function predict(ann_input, ann_data, n_sol, inp, out_nrm) | ||
function choice = get_choice(choice, idx) | ||
|
||
obj = AnnManager(ann_input); | ||
obj.load(ann_data); | ||
|
||
[is_valid_tmp, out_nrm_tmp] = obj.predict_nrm(n_sol, inp, out_nrm); | ||
assert(islogical(is_valid_tmp), 'invalid fom') | ||
assert(isstruct(out_nrm_tmp), 'invalid fom') | ||
|
||
[is_valid_tmp, out_nrm_tmp] = obj.predict_ann(n_sol, inp, out_nrm); | ||
assert(islogical(is_valid_tmp), 'invalid fom') | ||
assert(isstruct(out_nrm_tmp), 'invalid fom') | ||
|
||
obj.delete(); | ||
if isnumeric(idx)&&(length(idx)==1)&&(idx>=1)&&(idx<=length(choice)) | ||
choice = choice{idx}; | ||
else | ||
choice = []; | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.