Commit 0af01b36 authored by Alexander Henkel

work on personalization evolution

parent 8d8277f4
@@ -94,6 +94,11 @@ arg_parser.add_argument('-fc', '--force_collection',
                         nargs='+',
                         help='specify collections to force retraining')
+arg_parser.add_argument('-e', '--evaluate',
+                        dest='evaluate',
+                        action='store_true',
+                        help='just redo predictions')
+
 
 args = arg_parser.parse_args()
 
 model_evaluation = ModelEvaluation()
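
The hunk above wires in the new evaluation-only switch. As a minimal, runnable sketch of the same argparse pattern (the run_training and redo_predictions stubs here are hypothetical stand-ins for the script's training and prediction steps, not names from the repository):

    import argparse

    def run_training():
        print('training...')        # stand-in for the full training pass

    def redo_predictions():
        print('predicting...')      # stand-in for the prediction step

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-e', '--evaluate',
                            dest='evaluate',
                            action='store_true',  # False unless -e/--evaluate is given
                            help='just redo predictions')
    args = arg_parser.parse_args()

    if not args.evaluate:
        run_training()      # skipped in evaluation-only runs
    redo_predictions()      # runs in both modes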
@@ -196,13 +201,15 @@ def iterate_over_collection(collection_conf, epochs):
         personalizer = Personalizer()
         personalizer.initialize(args.base_models_prefix + collection_conf['base_model'])
         for i, dataset in enumerate(collection):
-            dataset.apply_pseudo_label_generators(pseudo_model_settings[pseudo_model_setting])
             iterate_model_name = gen_name(collection_conf['name'], pseudo_model_setting, i, epochs)
-            model_info = {'generators': pseudo_label_generators_to_int(pseudo_model_settings[pseudo_model_setting]),
-                          'iteration': i, 'collection_name': collection_conf['name'], 'epochs': epochs}
-            trainings_manager.add_model_information(iterate_model_name, model_info)
-            is_new_model = train_model([dataset, ], iterate_model_name, personalizer, epochs,
-                                       use_regularization=use_regularization, collection_name=collection_conf['name'])
+            is_new_model = False
+            if not args.evaluate:
+                dataset.apply_pseudo_label_generators(pseudo_model_settings[pseudo_model_setting])
+                model_info = {'generators': pseudo_label_generators_to_int(pseudo_model_settings[pseudo_model_setting]),
+                              'iteration': i, 'collection_name': collection_conf['name'], 'epochs': epochs}
+                trainings_manager.add_model_information(iterate_model_name, model_info)
+                is_new_model = train_model([dataset, ], iterate_model_name, personalizer, epochs,
+                                           use_regularization=use_regularization, collection_name=collection_conf['name'])
             model_evaluation.add_model(models_directory + iterate_model_name)
             model_evaluation.assign_collection_to_model(get_test_collection_by_name(collection_conf['name']),
                                                         iterate_model_name)
@@ -213,12 +220,13 @@ def iterate_over_collection(collection_conf, epochs):
                 model_evaluation.clear_evaluations_of_model(iterate_model_name)
             single_model_name = gen_name(collection_conf['name'], pseudo_model_setting, dataset.name[:16], epochs)
-            model_info = {'generators': pseudo_label_generators_to_int(pseudo_model_settings[pseudo_model_setting]),
-                          'iteration': 'single', 'collection_name': collection_conf['name'], 'based_on_iteration': i,
-                          'epochs': epochs}
-            trainings_manager.add_model_information(single_model_name, model_info)
-            is_new_model = train_model([dataset, ], single_model_name, personalizer, epochs, is_iterative=False,
-                                       use_regularization=use_regularization)
+            if not args.evaluate:
+                model_info = {'generators': pseudo_label_generators_to_int(pseudo_model_settings[pseudo_model_setting]),
+                              'iteration': 'single', 'collection_name': collection_conf['name'], 'based_on_iteration': i,
+                              'epochs': epochs}
+                trainings_manager.add_model_information(single_model_name, model_info)
+                is_new_model = train_model([dataset, ], single_model_name, personalizer, epochs, is_iterative=False,
+                                           use_regularization=use_regularization)
             model_evaluation.add_model(models_directory + single_model_name)
             model_evaluation.assign_collection_to_model(get_test_collection_by_name(collection_conf['name']),
                                                         single_model_name)
@@ -228,7 +236,8 @@ def iterate_over_collection(collection_conf, epochs):
             personalizer = Personalizer()
             personalizer.initialize(models_directory + iterate_model_name)
 
-    trainings_manager.db_update()
+    if not args.evaluate:
+        trainings_manager.db_update()
 
 
 def start_training():
@@ -243,8 +252,8 @@ def start_training():
     for collection_conf in training_config['collection_configs']['recorded_collections'].values():
         iterate_over_collection(collection_conf, epochs)
 
-        model_evaluation.do_predictions()
-        model_evaluation.save_predictions(args.PredictionsDB)
+    model_evaluation.do_predictions()
+    model_evaluation.save_predictions(args.PredictionsDB)
@@ -254,14 +263,16 @@ if __name__ == '__main__':
     synthetic_dataset_manager = DatasetManager(args.SyntheticDataset)
     recorded_dataset_manager = DatasetManager(args.RecordedDataset)
     trainings_manager = TrainingsManager(args.TrainingDB)
-    model_evaluation.load_predictions(args.PredictionsDB)
+    if not args.evaluate:
+        model_evaluation.load_predictions(args.PredictionsDB)
 
     models_directory = args.models_dir + '/'
 
     with open(args.ConfigurationFile, 'r') as config_file:
         training_config = yaml.safe_load(config_file)
     print(training_config)
 
-    if not args.skip or training_run_name not in trainings_manager.database['training_runs']:
+    if not args.evaluate and (not args.skip or training_run_name not in trainings_manager.database['training_runs']):
         print('new training run')
         trainings_manager.database['training_runs'][training_run_name] = []
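
Taken together, the changes in this file gate every training side effect (pseudo-label generation, model bookkeeping, train_model, db_update) behind not args.evaluate, while the prediction path always runs; skipping load_predictions in evaluate mode means do_predictions recomputes results instead of reusing the cached database, matching the flag's help text. A condensed sketch of that control flow, with a stub class standing in for the script's ModelEvaluation:

    class EvalStub:
        # stand-in for ModelEvaluation; method names mirror the calls in the diff
        def load_predictions(self, db): print('load cached predictions from', db)
        def do_predictions(self): print('compute predictions')
        def save_predictions(self, db): print('save predictions to', db)

    def run(evaluate, predictions_db='predictions.db'):
        model_evaluation = EvalStub()
        if not evaluate:
            model_evaluation.load_predictions(predictions_db)  # normal run reuses the cache
        # ... training would happen here, also gated behind `not evaluate` ...
        model_evaluation.do_predictions()                      # evaluate mode recomputes from scratch
        model_evaluation.save_predictions(predictions_db)

    run(evaluate=True)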
@@ -96,6 +96,8 @@ observed_models = []
 model_predictions = dict()
 
+global target_pseudo_model_settings
+target_pseudo_model_settings = list(pseudo_model_settings.items())
 
 
 def gen_name(base_name, additional_info):
     model_name = f'{training_run_name}_{base_name}'
@@ -173,8 +175,8 @@ def train_model(model_name):
     if not skip_model and (not skip_existing or model_name not in trainings_manager.database['training_runs'][
             training_run_name] or force_model):
         print('train:', model_name)
-        personalizer.incremental_learn_series_pseudo(collection,
-                                                     save_model_as=models_directory + model_name, epochs=100)
+        # personalizer.incremental_learn_series_pseudo(collection,
+        #                                              save_model_as=models_directory + model_name, epochs=100)
         if model_name not in trainings_manager.database['training_runs'][training_run_name]:
             trainings_manager.database['training_runs'][training_run_name].append(model_name)
     else:
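
With the incremental_learn_series_pseudo call commented out, train_model becomes a dry run: it still prints the model name and registers it in the training-run database, but no weights are updated, presumably a temporary state while the evaluation path is being worked on.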
@@ -182,9 +184,8 @@ def train_model(model_name):
 
 def train_on_collection(additional_info=None):
-    target_pseudo_model_settings = list(pseudo_model_settings.items())
-    if additional_info is not None:
-        target_pseudo_model_settings = [(key, value) for key,value in pseudo_model_settings.items() if key in thesis_filters]
+    # if additional_info is not None:
+    #     target_pseudo_model_settings = [(key, value) for key,value in pseudo_model_settings.items() if key in thesis_filters]
     if additional_info is None:
         additional_info = dict()
@@ -204,6 +205,7 @@ def train_on_collection(additional_info=None):
         dataset.generate_feedback_areas(prediction=model_predictions[dataset.name])
     for pseudo_model, model_settings in target_pseudo_model_settings:
+        print('Pseudo filter:', pseudo_model)
         model_name = gen_name(pseudo_model, additional_info)
         if len(observed_models) > 0 and pseudo_model not in observed_models:
             print('skip', model_name)
@@ -233,6 +235,17 @@ def start_training():
     if 'random' in training_run_name:
         target_randoms = list()
+
+        global target_pseudo_model_settings
+        # target_pseudo_model_selection = ['allnoise_correctedhwgt', 'allnoise_correctedscore',
+        #                                  'allnoise_correctbydeepconvfilter', 'allnoise_correctbyfcndaefilter',
+        #                                  'allnoise_correctbyconvlstmfilter', 'allnoise_correctbyconvlstm2filter',
+        #                                  'allnoise_correctbyconvlstm3filter', 'alldeepconv_correctbyconvlstm3filter',
+        #                                  'alldeepconv_correctbyconvlstm2filter6', 'alldeepconv_correctbyconvlstm3filter6']
+        target_pseudo_model_selection = ['alldeepconv_correctbyconvlstm3filter6']
+        target_pseudo_model_settings = [(key, value) for key, value in pseudo_model_settings.items() if
+                                        key in target_pseudo_model_selection]
+
         for i in np.arange(0.2, 1.05, 0.1):
             i = np.around(i, decimals=3)
             target_randoms.append((i, 1.00))
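
The new selection narrows the random-reliability sweep from the full pseudo_model_settings dictionary down to a single pseudo-label filter. The key-based filtering can be sketched in isolation (the dictionary contents here are illustrative placeholders, not the real settings objects):

    pseudo_model_settings = {
        'allnoise_correctedhwgt': ...,                   # placeholder settings
        'alldeepconv_correctbyconvlstm3filter6': ...,
    }
    target_pseudo_model_selection = ['alldeepconv_correctbyconvlstm3filter6']
    target_pseudo_model_settings = [(key, value)
                                    for key, value in pseudo_model_settings.items()
                                    if key in target_pseudo_model_selection]
    print(target_pseudo_model_settings)  # only the ConvLSTM3 filter variant remains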
@@ -268,6 +281,7 @@ def start_training():
                 #                                  evaluation_reliability_no=target_random[0],
                 #                                  evaluation_reliability_yes=target_random[1],
                 #                                  clear_just_covered=False)
+                print(dataset.indicators[1])
                 if dataset.name not in indicator_backups:
                     indicator_backups[dataset.name] = (dataset.indicators[0].copy(), dataset.indicators[1].copy())
                 else: