Commit 9de9e63a authored by Alexander Henkel's avatar Alexander Henkel
Browse files

mean of randomized models

parent b0a21aaa
......@@ -51,10 +51,10 @@ arg_parser.add_argument('-s', '--skip',
args = arg_parser.parse_args()
smooth_values = [0, 0.1, 0.2, 0.3, 0.35, 0.4, 0.45, 0.49]
smooth_values = [0, 0.1, 0.2, 0.3, 0.4, 0.49]
smooth_selects = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.98]
false_smooth_randoms = [0.001, 0.005, 0.01, 0.015, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]
randomize_values = [0, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.3]
randomize_values = [0, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.3]
personalizer = Personalizer()
dataset_manager = DatasetManager()
......@@ -190,21 +190,18 @@ def train_on_values(randomized_collection, smooth_collections, noise_values):
if training_run_name not in trainings_manager.database['training_runs']:
trainings_manager.database['training_runs'][training_run_name] = []
if training_run_name + '_all' not in trainings_manager.database['training_runs']:
trainings_manager.database['training_runs'][training_run_name + '_all'] = []
for smooth_val in smooth_values:
model_name = f'{training_run_name}_rn{str(noise_values["rand_noise"])[2:5]}rhw{str(noise_values["rand_hw"])[2:5]}sv{str(smooth_val)[2:5]}.pt'
if skip_existing and model_name in trainings_manager.database['training_runs'][training_run_name + '_all']:
if skip_existing and model_name in trainings_manager.database['training_runs'][training_run_name]:
continue
personalizer.incremental_learn_series_gt(smooth_collections[smooth_val],
save_model_as=models_directory + model_name)
trainings_manager.database['training_runs'][training_run_name + '_all'].append(model_name)
trainings_manager.database['training_runs'][training_run_name].append(model_name)
model_info = {'smooth_value': smooth_val}
model_info.update(noise_values)
trainings_manager.add_model_information(model_name, model_info)
trainings_manager.db_update()
def start_training():
......
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
from typing import Dict, List
from typing import Dict, List, Union
import yaml
import os
......@@ -124,7 +124,7 @@ class EvaluationManager:
if do_save:
self.model_evaluation.save_predictions(self.predictions_db)
def get_collection_config(self, config_name):
def get_collection_config(self, config_name) -> Union[dict, None]:
for collection_set in self.config['collection_configs']:
for collection_config in self.config['collection_configs'][collection_set].values():
if collection_config['name'] == config_name:
......@@ -134,27 +134,62 @@ class EvaluationManager:
return collection_config
return None
def get_dataset_manager_of_collection(self, collection_name):
def get_dataset_manager_of_collection(self, collection_name) -> DatasetManager:
collection_config = self.get_collection_config(collection_name)
return self.dataset_managers[self.get_config_entry('dataset_db', collection_config)]
def get_training_manager_of_collection(self, collection_name):
def get_training_manager_of_collection(self, collection_name) -> TrainingsManager:
collection_config = self.get_collection_config(collection_name)
return self.training_managers[self.get_config_entry('training_db', collection_config)]
def get_test_sets_of_collection(self, collection_name):
def get_test_sets_of_collection(self, collection_name) -> Dict[str, Dataset]:
collection_config = self.get_collection_config(collection_name)
dataset_manager = self.dataset_managers[self.get_config_entry('dataset_db', collection_config)]
return dataset_manager.filter_by_category(collection_config['test_collection_name'])
def get_train_sets_of_collection(self, collection_name):
def get_train_sets_of_collection(self, collection_name) -> Dict[str, Dataset]:
collection_config = self.get_collection_config(collection_name)
dataset_manager = self.dataset_managers[self.get_config_entry('dataset_db', collection_config)]
return dataset_manager.filter_by_category(collection_config['train_collection_name'])
def get_run_of_collection(self, collection_name):
def get_run_of_collection(self, collection_name) -> List[str]:
collection_config = self.get_collection_config(collection_name)
training_manager = self.training_managers[self.get_config_entry('training_db', collection_config)]
return training_manager.get_all_training_runs()[collection_config['training_run_name']]
def get_collections(self) -> List[str]:
collections = []
for collection_set in self.config['collection_configs']:
for collection_config in self.config['collection_configs'][collection_set].values():
collections.append(collection_config['name'])
return collections
def check(self, short=True):
print('Collections:', self.test_collections.keys())
print('Dataset managers: ', self.dataset_managers.keys())
print('Training managers: ', self.training_managers.keys())
if not short:
print('Loaded models:', self.model_evaluation.models.keys())
def inspect(self, config_file_path):
with open(config_file_path, 'r') as config_file:
config = yaml.safe_load(config_file)
for collection_set in config['collection_configs']:
for collection_config in config['collection_configs'][collection_set].values():
print(collection_config['name'])
dataset_db = self.get_config_entry('dataset_db', collection_config)
training_db = self.get_config_entry('training_db', collection_config)
base_model_prefix = self.get_config_entry('model_prefix', collection_config, './')
models_directory = self.get_config_entry('models_directory', collection_config, './')
dataset_manager = DatasetManager(dataset_db)
training_manager = TrainingsManager(training_db)
test_collection = list(dataset_manager.filter_by_category(
collection_config['test_collection_name']).values())
train_collection = list(dataset_manager.filter_by_category(
collection_config['train_collection_name']).values())
print('Training sets:', train_collection)
print('Test sets:', test_collection)
print('Runs:', training_manager.get_all_training_runs().keys())
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment