Commit 12b73e04 authored by Alexander Henkel

cluster script for participant personalization

parent c23bf037
@@ -3,6 +3,7 @@ import copy
from typing import List, Dict
import numpy as np
from personalization_tools.helpers import Evaluation
from personalization_tools.personalizer import Personalizer
from personalization_tools.dataset_manager import DatasetManager
from personalization_tools.trainings_manager import TrainingsManager
@@ -78,6 +79,19 @@ def randomize_collection(target_values: Dict[str, float]):
return randomized_collection
def calc_average_evaluation(evaluate_dataset):
    # Average the Evaluation metrics of each evaluated dataset against the ground
    # truth of the corresponding dataset in the global collection.
    average_evaluation = [0, 0, 0, 0, 0]
    for i, dataset in enumerate(collection):
        y_true = dataset.y_win
        evaluation = Evaluation(y_true, evaluate_dataset[i].y_win, use_soft=True).get_values()[1:]
        for j, metric_val in enumerate(evaluation):
            average_evaluation[j] += metric_val
    for i in range(len(average_evaluation)):
        average_evaluation[i] /= len(collection)
    return average_evaluation
def smooth_data(input_collection, target_collection, smooth_value):
smoothen_data = []
for i, dataset in enumerate(input_collection):
@@ -184,20 +198,16 @@ def train_on_random_smoothed(smooth_collections):
def train_on_values(randomized_collection, smooth_collections, noise_values):
if not skip_existing:
trainings_manager.database['training_runs'][training_run_name] = []
else:
if training_run_name not in trainings_manager.database['training_runs']:
trainings_manager.database['training_runs'][training_run_name] = []
for smooth_val in smooth_values:
model_name = f'{training_run_name}_rn{str(noise_values["rand_noise"])[2:5]}rhw{str(noise_values["rand_hw"])[2:5]}sv{str(smooth_val)[2:5]}.pt'
print(model_name)
if skip_existing and model_name in trainings_manager.database['training_runs'][training_run_name]:
continue
personalizer.incremental_learn_series_gt(smooth_collections[smooth_val],
save_model_as=models_directory + model_name)
trainings_manager.database['training_runs'][training_run_name].append(model_name)
model_info = {'smooth_value': smooth_val}
average_evaluation = calc_average_evaluation(smooth_collections[smooth_val])
model_info = {'smooth_value': smooth_val, 'evaluation': average_evaluation}
model_info.update(noise_values)
trainings_manager.add_model_information(model_name, model_info)
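The model file name encodes the randomization and smoothing parameters by slicing the digits after the decimal point. A minimal sketch of the naming scheme, using assumed example values (the run name and parameter values below are hypothetical):

# Illustration of the model naming scheme with assumed example values.
noise_values = {'rand_noise': 0.25, 'rand_hw': 0.1}
smooth_val = 0.005
training_run_name = 'example_run'  # hypothetical run name
model_name = (f'{training_run_name}_rn{str(noise_values["rand_noise"])[2:5]}'
              f'rhw{str(noise_values["rand_hw"])[2:5]}sv{str(smooth_val)[2:5]}.pt')
print(model_name)  # -> example_run_rn25rhw1sv005.pt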
@@ -255,5 +265,12 @@ if __name__ == '__main__':
models_directory = args.models_dir + '/'
collection = list(dataset_manager.filter_by_category(collection_name).values())
if not skip_existing:
trainings_manager.database['training_runs'][training_run_name] = []
else:
if training_run_name not in trainings_manager.database['training_runs']:
trainings_manager.database['training_runs'][training_run_name] = []
start_training()
time.sleep(5)
@@ -144,8 +144,9 @@ pseudo_model_settings: Dict[str, List[PseudoLabelGenerator]] = {'all': [],
common_filters = ['allnoise_correctedhwgt', 'allnoise_correctedscore', 'allnoise_correctedscore_flatten',
'allnoise_correctbydeepconvfilter',
'allnoise_correctbyfcndaefilter', 'allnoise_correctbyconvlstmfilter',
'allnoise_correctbyconvlstm2filter',
'allnoise_correctbyconvlstm3filter', 'alldeepconv_correctbyconvlstm3filter',
'alldeepconv_correctbyconvlstm3filter4']
'alldeepconv_correctbyconvlstm3filter4', 'alldeepconv_correctbyconvlstm2filter6']
thesis_filters = ['all', 'high', 'noneut', 'all_corrected_noise', 'scope_corrected_noise', 'all_corrected',
'noneut_corrected', 'allnoise_correctedhwgt', 'allnoise_correctedscore',
...
import argparse
import os
from typing import List, Union, Dict
import numpy as np
import yaml
from personalization_tools.dataset import RecordedDataset
from personalization_tools.globals import Indicators
from personalization_tools.helpers import generate_predictions
from personalization_tools.personalizer import Personalizer
from personalization_tools.sensor_recorder_data_reader import SensorRecorderDataReader
from personalization_tools.trainings_manager import TrainingsManager
arg_parser = argparse.ArgumentParser(description='Run personalization of a participant')
arg_parser.add_argument('ConfigurationFile',
metavar='config_file',
type=str,
help='path to yaml where training configurations are set')
args = arg_parser.parse_args()
config = dict()
trainings_manager: Union[None, TrainingsManager] = None
record_reader: Union[None, SensorRecorderDataReader] = None
predictions: Dict[str, np.ndarray] = dict()
def load_collection(recording_names: List[str]) -> List[RecordedDataset]:
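    # Load the named recordings and cache a base-model prediction for each dataset.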
collection = record_reader.get_collection(recording_names)
for dataset in collection:
if dataset.name not in predictions:
prediction = generate_predictions([dataset, ], config['base_model'])[dataset.name]
predictions[dataset.name] = prediction
return collection
def clean_collection(collection: List[RecordedDataset]) -> List[RecordedDataset]:
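    # Drop datasets whose indicators contain no HAND_WASH flag.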
for dataset in collection[:]:
indicators = dataset.get_indicators()
num_hw_flags = np.count_nonzero(indicators[1][:, 1] == Indicators.HAND_WASH)
if num_hw_flags == 0:
collection.remove(dataset)
return collection
def training_step(train_dataset: RecordedDataset, based_model: str, personalization: dict, target_model_name: str):
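    # Fine-tune based_model on a single recording with pseudo labels, unless a model
    # for this run is already registered, and record the result in the trainings database.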
personalizer = Personalizer()
personalizer.initialize(based_model)
if personalization.get('use_l2_sp', False):
personalizer.learner_pipeline.set_initial_model(config['base_model'])
target_model_path = os.path.join(config['models_directory'], target_model_name)
if config.get('enforce', False) or target_model_name not in trainings_manager.database['training_runs'][personalization['name']]:
personalizer.incremental_learn_series_pseudo([train_dataset, ], save_model_as=target_model_path,
epochs=100,
use_regularization=personalization.get('use_l2_sp', False),
freeze_feature_layers=not personalization.get('use_l2_sp', False))
information = {'based_model': based_model}
trainings_manager.add_model_information(target_model_name, information)
trainings_manager.database['training_runs'][personalization['name']].append(target_model_name)
def run_training(train_collection: List[RecordedDataset], test_collection: List[RecordedDataset], personalization: dict):
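    # Generate feedback areas from the cached predictions, then train one personalized model per training recording.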
if config.get('enforce', False) or personalization['name'] not in trainings_manager.database['training_runs']:
trainings_manager.database['training_runs'][personalization['name']] = []
based_model = config['base_model']
for i, dataset in enumerate(train_collection):
dataset.generate_feedback_areas(prediction=predictions[dataset.name])
model_name = f'{personalization["name"]}_it{i}.pt'
training_step(dataset, based_model, personalization, model_name)
def start_personalization():
for personalization in config['collection_configs']['personalizations']:
test_collection = load_collection(personalization['test_sets'])
train_collection = load_collection(personalization['train_sets'])
train_collection = clean_collection(train_collection)
run_training(train_collection, test_collection, personalization)
if __name__ == '__main__':
with open(args.ConfigurationFile, 'r') as config_file:
config = yaml.safe_load(config_file)
trainings_manager = TrainingsManager(config['training_db'])
record_reader = SensorRecorderDataReader(config['recordings_directory'])
start_personalization()
\ No newline at end of file
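For reference, the new script only reads a handful of keys from the YAML configuration. A minimal sketch of the expected structure, written as the Python dict that yaml.safe_load would return; all paths, recording names, and run names below are placeholders:

# Hypothetical configuration; keys match what the script accesses, values are placeholders.
example_config = {
    'base_model': 'models/base.pt',                   # used for predictions and as the training start
    'training_db': 'participant01_trainings',         # passed to TrainingsManager
    'recordings_directory': 'recordings/participant01',
    'models_directory': 'models/participant01',
    'enforce': False,                                  # optional: retrain even if a model is already registered
    'collection_configs': {
        'personalizations': [
            {
                'name': 'participant01_l2sp',          # training run name
                'use_l2_sp': True,                     # optional, defaults to False
                'train_sets': ['recording_01', 'recording_02'],
                'test_sets': ['recording_03'],
            },
        ],
    },
}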