Commit c23bf037 authored by Alexander Henkel's avatar Alexander Henkel
Browse files

removed shift in running mean

parent 9de9e63a
This source diff could not be displayed because it is too large. You can view the blob instead.
%% Cell type:code id:0b75467c tags: %% Cell type:code id:0b75467c tags:
   
``` python ``` python
%load_ext autoreload %load_ext autoreload
%autoreload 2 %autoreload 2
   
%matplotlib notebook %matplotlib notebook
``` ```
   
%% Cell type:code id:86bbf5d1 tags: %% Cell type:code id:86bbf5d1 tags:
   
``` python ``` python
import sys import sys
import os import os
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from IPython.display import display, Markdown from IPython.display import display, Markdown
import copy import copy
``` ```
   
%% Cell type:code id:e16a72d9 tags: %% Cell type:code id:e16a72d9 tags:
   
``` python ``` python
module_path = os.path.abspath(os.path.join('..')) module_path = os.path.abspath(os.path.join('..'))
os.chdir(module_path) os.chdir(module_path)
if module_path not in sys.path: if module_path not in sys.path:
sys.path.append(module_path) sys.path.append(module_path)
``` ```
   
%% Cell type:code id:45c87fbb tags: %% Cell type:code id:45c87fbb tags:
   
``` python ``` python
from personalization_tools.model_evaluation import ModelEvaluation from personalization_tools.model_evaluation import ModelEvaluation
from personalization_tools.dataset_manager import DatasetManager from personalization_tools.dataset_manager import DatasetManager
from personalization_tools.trainings_manager import TrainingsManager from personalization_tools.trainings_manager import TrainingsManager
from personalization_tools.helpers import * from personalization_tools.helpers import *
from personalization_tools.evaluation_manager import EvaluationManager from personalization_tools.evaluation_manager import EvaluationManager
``` ```
   
%% Cell type:code id:b02ac845 tags: %% Cell type:code id:b02ac845 tags:
   
``` python ``` python
evaluation_config_file = './data/cluster/smoothed_noise/evaluation_config.yaml' evaluation_config_file = './data/cluster/smoothed_noise/evaluation_config.yaml'
``` ```
   
%% Cell type:code id:739c8c21 tags: %% Cell type:code id:739c8c21 tags:
   
``` python ``` python
evaluation_manager = EvaluationManager() evaluation_manager = EvaluationManager()
evaluation_manager.load_config(evaluation_config_file) evaluation_manager.load_config(evaluation_config_file)
``` ```
   
%% Output %% Output
   
load config load config
   
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-8-5dfb4021c1c2> in <module>
1 evaluation_manager = EvaluationManager()
----> 2 evaluation_manager.load_config(evaluation_config_file)
~/gitrepos/uni/handwashing_personalizer/src/personalization_tools/evaluation_manager.py in load_config(self, config_file_path)
105
106 self.personalized_models[collection_config['name']] = list()
--> 107 for model in self.training_managers[training_db].get_all_training_runs()[collection_config['training_run_name']]:
108 # print('add model', model, 'assign', test_collection)
109 # print('assign', train_collection)
KeyError: '02_run1'
%% Cell type:code id:aac81c8a tags: %% Cell type:code id:aac81c8a tags:
   
``` python ``` python
evaluation_manager.do_predictions() evaluation_manager.do_predictions()
``` ```
   
%% Output %% Output
   
Do Predictions for: Do Predictions for:
Use device: cuda Use device: cuda
run: ('01_run1_rnrhwsv.pt', 01_generated_0) ... 0.7670552730560303 seconds run: ('01_run1_rnrhw001sv.pt', 01_generated_0) ... 0.8886113166809082 seconds
run: ('01_run1_rnrhwsv.pt', 01_generated_1) ... 0.7229199409484863 seconds run: ('01_run1_rnrhw001sv.pt', 01_generated_1) ... 0.782813549041748 seconds
run: ('01_run1_rnrhwsv1.pt', 01_generated_0) ... 0.7369263172149658 seconds run: ('01_run1_rnrhw001sv1.pt', 01_generated_0) ... 0.8040542602539062 seconds
run: ('01_run1_rnrhwsv1.pt', 01_generated_1) ... 0.705453634262085 seconds run: ('01_run1_rnrhw001sv1.pt', 01_generated_1) ... 0.7916336059570312 seconds
run: ('01_run1_rnrhwsv2.pt', 01_generated_0) ... 0.7358417510986328 seconds run: ('01_run1_rnrhw001sv2.pt', 01_generated_0) ... 0.8033273220062256 seconds
run: ('01_run1_rnrhwsv2.pt', 01_generated_1) ... 0.714977502822876 seconds run: ('01_run1_rnrhw001sv2.pt', 01_generated_1) ... 0.760101318359375 seconds
run: ('01_run1_rnrhwsv3.pt', 01_generated_0) ... 0.7406353950500488 seconds run: ('01_run1_rnrhw001sv3.pt', 01_generated_0) ... 0.818626880645752 seconds
run: ('01_run1_rnrhwsv3.pt', 01_generated_1) ... 0.707059383392334 seconds run: ('01_run1_rnrhw001sv3.pt', 01_generated_1) ... 0.7752153873443604 seconds
run: ('01_run1_rnrhwsv4.pt', 01_generated_0) ... 0.7322509288787842 seconds run: ('01_run1_rnrhw001sv4.pt', 01_generated_0) ... 0.7885987758636475 seconds
run: ('01_run1_rnrhwsv4.pt', 01_generated_1) ... 0.7054550647735596 seconds run: ('01_run1_rnrhw001sv4.pt', 01_generated_1) ... 0.7428030967712402 seconds
run: ('01_run1_rnrhwsv49.pt', 01_generated_0) ... 0.7342422008514404 seconds run: ('01_run1_rnrhw001sv49.pt', 01_generated_0) ... 0.7783234119415283 seconds
run: ('01_run1_rnrhwsv49.pt', 01_generated_1) ... 0.70754075050354 seconds run: ('01_run1_rnrhw001sv49.pt', 01_generated_1) ... 0.761707067489624 seconds
run: ('01_run1_rnrhw002sv.pt', 01_generated_0) ... 0.7300739288330078 seconds run: ('01_run1_rn001rhwsv.pt', 01_generated_0) ... 0.7711310386657715 seconds
run: ('01_run1_rnrhw002sv.pt', 01_generated_1) ... 0.7094612121582031 seconds run: ('01_run1_rn001rhwsv.pt', 01_generated_1) ... 0.8299789428710938 seconds
run: ('01_run1_rnrhw002sv1.pt', 01_generated_0) ... 0.734208345413208 seconds run: ('01_run1_rn001rhwsv1.pt', 01_generated_0) ... 0.8163266181945801 seconds
run: ('01_run1_rnrhw002sv1.pt', 01_generated_1) ... 0.7204663753509521 seconds run: ('01_run1_rn001rhwsv1.pt', 01_generated_1) ... 0.7869091033935547 seconds
run: ('01_run1_rnrhw002sv2.pt', 01_generated_0) ... 0.7356076240539551 seconds run: ('01_run1_rn001rhwsv2.pt', 01_generated_0) ... 0.7828412055969238 seconds
run: ('01_run1_rnrhw002sv2.pt', 01_generated_1) ... 0.7067642211914062 seconds run: ('01_run1_rn001rhwsv2.pt', 01_generated_1) ... 0.7577722072601318 seconds
run: ('01_run1_rnrhw002sv3.pt', 01_generated_0) ... 0.7475216388702393 seconds run: ('01_run1_rn001rhwsv3.pt', 01_generated_0) ... 0.7916994094848633 seconds
run: ('01_run1_rnrhw002sv3.pt', 01_generated_1) ... 0.7140254974365234 seconds run: ('01_run1_rn001rhwsv3.pt', 01_generated_1) ... 0.7819817066192627 seconds
run: ('01_run1_rnrhw002sv4.pt', 01_generated_0) ... 0.7531940937042236 seconds run: ('01_run1_rn001rhwsv4.pt', 01_generated_0) ... 0.7870006561279297 seconds
run: ('01_run1_rnrhw002sv4.pt', 01_generated_1) ... 0.7667331695556641 seconds run: ('01_run1_rn001rhwsv4.pt', 01_generated_1) ... 0.7712128162384033 seconds
run: ('01_run1_rnrhw002sv49.pt', 01_generated_0) ... 0.757077693939209 seconds run: ('01_run1_rn001rhwsv49.pt', 01_generated_0) ... 0.8201324939727783 seconds
run: ('01_run1_rnrhw002sv49.pt', 01_generated_1) ... 0.7193083763122559 seconds run: ('01_run1_rn001rhwsv49.pt', 01_generated_1) ... 0.821786642074585 seconds
run: ('01_run1_rn002rhwsv.pt', 01_generated_0) ... 0.735968828201294 seconds
run: ('01_run1_rn002rhwsv.pt', 01_generated_1) ... 0.7115938663482666 seconds
run: ('01_run1_rn002rhwsv1.pt', 01_generated_0) ... 0.7701573371887207 seconds
run: ('01_run1_rn002rhwsv1.pt', 01_generated_1) ... 0.7119736671447754 seconds
run: ('01_run1_rn002rhwsv2.pt', 01_generated_0) ... 0.7381885051727295 seconds
run: ('01_run1_rn002rhwsv2.pt', 01_generated_1) ... 0.7239189147949219 seconds
run: ('01_run1_rn002rhwsv3.pt', 01_generated_0) ... 0.774456262588501 seconds
run: ('01_run1_rn002rhwsv3.pt', 01_generated_1) ... 0.7135612964630127 seconds
run: ('01_run1_rn002rhwsv4.pt', 01_generated_0) ... 0.7720589637756348 seconds
run: ('01_run1_rn002rhwsv4.pt', 01_generated_1) ... 0.7112200260162354 seconds
run: ('01_run1_rn002rhwsv49.pt', 01_generated_0) ... 0.7363040447235107 seconds
run: ('01_run1_rn002rhwsv49.pt', 01_generated_1) ... 0.7125518321990967 seconds
run: ('01_run1_rnrhw005sv.pt', 01_generated_0) ... 0.7371799945831299 seconds
run: ('01_run1_rnrhw005sv.pt', 01_generated_1) ... 0.717721700668335 seconds
run: ('01_run1_rnrhw005sv1.pt', 01_generated_0) ... 0.7349185943603516 seconds
run: ('01_run1_rnrhw005sv1.pt', 01_generated_1) ... 0.7174410820007324 seconds
run: ('01_run1_rnrhw005sv2.pt', 01_generated_0) ... 0.7584567070007324 seconds
run: ('01_run1_rnrhw005sv2.pt', 01_generated_1) ... 0.7136719226837158 seconds
run: ('01_run1_rnrhw005sv3.pt', 01_generated_0) ... 0.7422447204589844 seconds
run: ('01_run1_rnrhw005sv3.pt', 01_generated_1) ... 0.710848331451416 seconds
run: ('01_run1_rnrhw005sv4.pt', 01_generated_0) ... 0.7362849712371826 seconds
run: ('01_run1_rnrhw005sv4.pt', 01_generated_1) ... 0.7083594799041748 seconds
run: ('01_run1_rnrhw005sv49.pt', 01_generated_0) ... 0.7379121780395508 seconds
run: ('01_run1_rnrhw005sv49.pt', 01_generated_1) ... 0.7134616374969482 seconds
run: ('01_run1_rn005rhwsv.pt', 01_generated_0) ... 0.7370283603668213 seconds
run: ('01_run1_rn005rhwsv.pt', 01_generated_1) ... 0.7082211971282959 seconds
run: ('01_run1_rn005rhwsv1.pt', 01_generated_0) ... 0.7370593547821045 seconds
run: ('01_run1_rn005rhwsv1.pt', 01_generated_1) ... 0.7085096836090088 seconds
run: ('01_run1_rn005rhwsv2.pt', 01_generated_0) ... 0.7400667667388916 seconds
run: ('01_run1_rn005rhwsv2.pt', 01_generated_1) ... 0.7100934982299805 seconds
run: ('01_run1_rn005rhwsv3.pt', 01_generated_0) ... 0.7403748035430908 seconds
run: ('01_run1_rn005rhwsv3.pt', 01_generated_1) ... 0.7468082904815674 seconds
run: ('01_run1_rn005rhwsv4.pt', 01_generated_0) ... 0.739534854888916 seconds
run: ('01_run1_rn005rhwsv4.pt', 01_generated_1) ... 0.7143855094909668 seconds
run: ('01_run1_rn005rhwsv49.pt', 01_generated_0) ... 0.7708370685577393 seconds
run: ('01_run1_rn005rhwsv49.pt', 01_generated_1) ... 0.7121317386627197 seconds
run: ('01_run1_rnrhw01sv.pt', 01_generated_0) ... 0.734727144241333 seconds
run: ('01_run1_rnrhw01sv.pt', 01_generated_1) ... 0.7094616889953613 seconds
run: ('01_run1_rnrhw01sv1.pt', 01_generated_0) ... 0.7458362579345703 seconds
run: ('01_run1_rnrhw01sv1.pt', 01_generated_1) ... 0.7123079299926758 seconds
run: ('01_run1_rnrhw01sv2.pt', 01_generated_0) ... 0.7504129409790039 seconds
run: ('01_run1_rnrhw01sv2.pt', 01_generated_1) ... 0.7590973377227783 seconds
run: ('01_run1_rnrhw01sv3.pt', 01_generated_0) ... 0.7786257266998291 seconds
run: ('01_run1_rnrhw01sv3.pt', 01_generated_1) ... 0.7246241569519043 seconds
run: ('01_run1_rnrhw01sv4.pt', 01_generated_0) ... 0.8344666957855225 seconds
run: ('01_run1_rnrhw01sv4.pt', 01_generated_1) ... 0.7993919849395752 seconds
run: ('01_run1_rnrhw01sv49.pt', 01_generated_0) ... 0.8080456256866455 seconds
run: ('01_run1_rnrhw01sv49.pt', 01_generated_1) ... 0.7756578922271729 seconds
run: ('01_run1_rn01rhwsv.pt', 01_generated_0) ... 0.806372880935669 seconds
run: ('01_run1_rn01rhwsv.pt', 01_generated_1) ... 0.7736902236938477 seconds
run: ('01_run1_rn01rhwsv1.pt', 01_generated_0) ... 0.8084828853607178 seconds
run: ('01_run1_rn01rhwsv1.pt', 01_generated_1) ... 0.7793536186218262 seconds
run: ('01_run1_rn01rhwsv2.pt', 01_generated_0) ... 0.7942161560058594 seconds
run: ('01_run1_rn01rhwsv2.pt', 01_generated_1) ... 0.7698228359222412 seconds
run: ('01_run1_rn01rhwsv3.pt', 01_generated_0) ... 0.794614315032959 seconds
run: ('01_run1_rn01rhwsv3.pt', 01_generated_1) ... 0.7676661014556885 seconds
run: ('01_run1_rn01rhwsv4.pt', 01_generated_0) ... 0.7977406978607178 seconds
run: ('01_run1_rn01rhwsv4.pt', 01_generated_1) ... 0.7661490440368652 seconds
run: ('01_run1_rn01rhwsv49.pt', 01_generated_0) ... 0.7928433418273926 seconds
run: ('01_run1_rn01rhwsv49.pt', 01_generated_1) ... 0.7632999420166016 seconds
saved predictions saved predictions
   
%% Cell type:code id:6686eea2 tags: %% Cell type:code id:6686eea2 tags:
   
``` python ``` python
target_collection = 'synthetic_01' target_collection = 'synthetic_01'
collection_config = evaluation_manager.get_collection_config(target_collection) collection_config = evaluation_manager.get_collection_config(target_collection)
collection = evaluation_manager.get_test_sets_of_collection(target_collection) collection = evaluation_manager.get_test_sets_of_collection(target_collection)
``` ```
   
%% Cell type:code id:c95a8cd0 tags: %% Cell type:code id:c95a8cd0 tags:
   
``` python ``` python
def plot_smoothing(collection, noise_target, noise_value,target_score): def plot_smoothing(collection, noise_target, noise_value,target_score):
training_manager = evaluation_manager.get_training_manager_of_collection(target_collection) training_manager = evaluation_manager.get_training_manager_of_collection(target_collection)
base_model = collection_config['base_model'] base_model = collection_config['base_model']
base_inc_model = collection_config['base_inc_model'] base_inc_model = collection_config['base_inc_model']
models = get_models_with_infos({noise_target: noise_value}, training_manager.get_all_information()) models = get_models_with_infos({noise_target: noise_value}, training_manager.get_all_information())
print(models) print(models)
   
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.set_title(target_score) ax.set_title(target_score)
   
avg_models = models + [base_model, base_inc_model] avg_models = models + [base_model, base_inc_model]
averages = evaluation_manager.model_evaluation.get_averages(include_models=avg_models) averages = evaluation_manager.model_evaluation.get_averages(include_models=avg_models)
   
y_vals = [] y_vals = []
x_vals = [] x_vals = []
for model in models: for model in models:
x_vals.append(training_manager.get_information(model)['smooth_value']) x_vals.append(training_manager.get_information(model)['smooth_value'])
y_vals.append(averages[model][target_score]) y_vals.append(averages[model][target_score])
   
ax.plot(x_vals, y_vals) ax.plot(x_vals, y_vals)
ax.axhline(averages[base_model][target_score], color='red', linestyle='--', label='base model') ax.axhline(averages[base_model][target_score], color='red', linestyle='--', label='base model')
ax.axhline(averages[base_inc_model][target_score], color='green', linestyle='--', label='base inc model') ax.axhline(averages[base_inc_model][target_score], color='green', linestyle='--', label='base inc model')
   
   
def plot_noise(collection, smooth_value, noise_target, target_score): def plot_noise(collection, smooth_value, noise_target, target_score):
training_manager = evaluation_manager.get_training_manager_of_collection(target_collection) training_manager = evaluation_manager.get_training_manager_of_collection(target_collection)
base_model = collection_config['base_model'] base_model = collection_config['base_model']
base_inc_model = collection_config['base_inc_model'] base_inc_model = collection_config['base_inc_model']
steady_val = 'rand_noise' steady_val = 'rand_noise'
if noise_target == 'rand_noise': if noise_target == 'rand_noise':
steady_val = 'rand_hw' steady_val = 'rand_hw'
models = get_models_with_infos({'smooth_value': smooth_value, steady_val: 0}, training_manager.get_all_information()) models = get_models_with_infos({'smooth_value': smooth_value, steady_val: 0}, training_manager.get_all_information())
print(models) print(models)
   
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.set_title(target_score) ax.set_title(target_score)
ax.set_xlabel(noise_target) ax.set_xlabel(noise_target)
   
avg_models = models + [base_model, base_inc_model] avg_models = models + [base_model, base_inc_model]
averages = evaluation_manager.model_evaluation.get_averages(include_models=avg_models) averages = evaluation_manager.model_evaluation.get_averages(include_models=avg_models)
   
y_vals = [] y_vals = []
x_vals = [] x_vals = []
for model in models: for model in models:
x_vals.append(training_manager.get_information(model)[noise_target]) x_vals.append(training_manager.get_information(model)[noise_target])
y_vals.append(averages[model][target_score]) y_vals.append(averages[model][target_score])
   
ax.plot(x_vals, y_vals) ax.plot(x_vals, y_vals)
ax.axhline(averages[base_model][target_score], color='red', linestyle='--', label='base model') ax.axhline(averages[base_model][target_score], color='red', linestyle='--', label='base model')
ax.axhline(averages[base_inc_model][target_score], color='green', linestyle='--', label='base inc model') ax.axhline(averages[base_inc_model][target_score], color='green', linestyle='--', label='base inc model')
``` ```
   
%% Cell type:code id:66b691f6 tags: %% Cell type:code id:66b691f6 tags:
   
``` python ``` python
plot_smoothing(collection, 'rand_hw',0.02, 's1') plot_smoothing(collection, 'rand_hw',0.02, 's1')
plot_smoothing(collection, 'rand_noise',0.002, 's1') plot_smoothing(collection, 'rand_noise',0.002, 's1')
``` ```
   
%% Output %% Output
   
['01_run1_rnrhw02sv.pt', '01_run1_rnrhw02sv1.pt', '01_run1_rnrhw02sv2.pt', '01_run1_rnrhw02sv3.pt', '01_run1_rnrhw02sv4.pt', '01_run1_rnrhw02sv49.pt'] ['01_run1_rnrhw02sv.pt', '01_run1_rnrhw02sv1.pt', '01_run1_rnrhw02sv2.pt', '01_run1_rnrhw02sv3.pt', '01_run1_rnrhw02sv4.pt', '01_run1_rnrhw02sv49.pt']
   
   
   
['01_run1_rn002rhwsv.pt', '01_run1_rn002rhwsv1.pt', '01_run1_rn002rhwsv2.pt', '01_run1_rn002rhwsv3.pt', '01_run1_rn002rhwsv4.pt', '01_run1_rn002rhwsv49.pt'] ['01_run1_rn002rhwsv.pt', '01_run1_rn002rhwsv1.pt', '01_run1_rn002rhwsv2.pt', '01_run1_rn002rhwsv3.pt', '01_run1_rn002rhwsv4.pt', '01_run1_rn002rhwsv49.pt']
   
   
   
%% Cell type:code id:73ae7e52 tags: %% Cell type:code id:73ae7e52 tags:
   
``` python ``` python
collection = evaluation_manager.get_collections()[0] collection = evaluation_manager.get_collections()[0]
rand_noise = set() rand_noise = set()
for info in evaluation_manager.get_training_manager_of_collection(collection).get_all_information().values(): for info in evaluation_manager.get_training_manager_of_collection(collection).get_all_information().values():
rand_noise.add(info['rand_noise']) rand_noise.add(info['rand_noise'])
print(rand_noise) print(rand_noise)
``` ```
   
%% Output %% Output
   
{0, 0.05, 0.1, 0.3, 0.01, 0.005, 0.02, 0.002} {0, 0.05, 0.1, 0.3, 0.01, 0.005, 0.02, 0.002}
   
%% Cell type:markdown id:6fca2705 tags: %% Cell type:markdown id:6fca2705 tags:
   
# Old Stuff # Old Stuff
   
%% Cell type:code id:7c7c2cbc tags: %% Cell type:code id:7c7c2cbc tags:
   
``` python ``` python
predictions_db = './data/cluster/smoothed_noise/predictions_db' predictions_db = './data/cluster/smoothed_noise/predictions_db'
dataset_db = './data/synthetic_dataset_db' dataset_db = './data/synthetic_dataset_db'
training_db = './data/cluster/smoothed_noise/training_db' training_db = './data/cluster/smoothed_noise/training_db'
models_directory = './data/cluster/smoothed_noise/models/' models_directory = './data/cluster/smoothed_noise/models/'
base_model = 'HandWashingDeepConvLSTMA_trunc_01.pt' base_model = 'HandWashingDeepConvLSTMA_trunc_01.pt'
base_inc_model = 'synthetic_inc_01.pt' base_inc_model = 'synthetic_inc_01.pt'
``` ```
   
%% Cell type:code id:7e2228e5 tags: %% Cell type:code id:7e2228e5 tags:
   
``` python ``` python
model_evaluation = ModelEvaluation() model_evaluation = ModelEvaluation()
model_evaluation.load_predictions(predictions_db) model_evaluation.load_predictions(predictions_db)
``` ```
   
%% Cell type:code id:3420099f tags: %% Cell type:code id:3420099f tags:
   
``` python ``` python
dataset_manager = DatasetManager(dataset_db) dataset_manager = DatasetManager(dataset_db)
#unseen_datasets = set(dataset_manager.filter_by_category('base_synthetic_01').keys()) - set(dataset_manager.filter_by_category('base_synthetic_01_training').keys()) #unseen_datasets = set(dataset_manager.filter_by_category('base_synthetic_01').keys()) - set(dataset_manager.filter_by_category('base_synthetic_01_training').keys())
unseen_datasets = list(dataset_manager.filter_by_category('01_test').keys()) unseen_datasets = list(dataset_manager.filter_by_category('01_test').keys())
collection = dataset_manager.get_collection(unseen_datasets) collection = dataset_manager.get_collection(unseen_datasets)
model_evaluation.clear_datasets() model_evaluation.clear_datasets()
model_evaluation.add_collection(collection) model_evaluation.add_collection(collection)
``` ```
   
%% Cell type:code id:4315d623 tags: %% Cell type:code id:4315d623 tags:
   
``` python ``` python
training_manager = TrainingsManager(training_db) training_manager = TrainingsManager(training_db)
``` ```
   
%% Cell type:code id:f83cba8b tags: %% Cell type:code id:f83cba8b tags:
   
``` python ``` python
training_manager.get_all_training_runs() training_manager.get_all_training_runs()
``` ```
   
%% Output %% Output
   
{} {}
   
%% Cell type:code id:7707ddd7 tags: %% Cell type:code id:7707ddd7 tags:
   
``` python ``` python
training_runs = training_manager.get_all_training_runs() training_runs = training_manager.get_all_training_runs()
for run in training_runs.keys(): for run in training_runs.keys():
print('add run:', run) print('add run:', run)
for model in training_runs[run]: for model in training_runs[run]:
model_evaluation.add_model(models_directory + model) model_evaluation.add_model(models_directory + model)
model_evaluation.add_model('./data/' + base_model) model_evaluation.add_model('./data/' + base_model)
model_evaluation.add_model('./data/' + base_inc_model) model_evaluation.add_model('./data/' + base_inc_model)
#model_evaluation.add_model(models_directory + 'train_base.pt') #model_evaluation.add_model(models_directory + 'train_base.pt')
``` ```
   
%% Output %% Output
   
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
NameError Traceback (most recent call last) NameError Traceback (most recent call last)
<ipython-input-15-aa506fe8d4b6> in <module> <ipython-input-15-aa506fe8d4b6> in <module>
4 for model in training_runs[run]: 4 for model in training_runs[run]:
5 model_evaluation.add_model(models_directory + model) 5 model_evaluation.add_model(models_directory + model)
----> 6 model_evaluation.add_model('./data/' + base_model) ----> 6 model_evaluation.add_model('./data/' + base_model)
7 model_evaluation.add_model('./data/' + base_inc_model) 7 model_evaluation.add_model('./data/' + base_inc_model)
8 #model_evaluation.add_model(models_directory + 'train_base.pt') 8 #model_evaluation.add_model(models_directory + 'train_base.pt')
NameError: name 'model_evaluation' is not defined NameError: name 'model_evaluation' is not defined
   
%% Cell type:code id:38aa7af0 tags: %% Cell type:code id:38aa7af0 tags:
   
``` python ``` python
# model_evaluation.clear_evaluations() # model_evaluation.clear_evaluations()
model_evaluation.do_predictions() model_evaluation.do_predictions()
``` ```
   
%% Output %% Output
   
Do Predictions for: Do Predictions for:
Use device: cuda Use device: cuda
   
%% Cell type:code id:585abe8d tags: %% Cell type:code id:585abe8d tags:
   
``` python ``` python
model_evaluation.save_predictions(predictions_db) model_evaluation.save_predictions(predictions_db)
``` ```
   
%% Output %% Output
   
saved predictions saved predictions
   
%% Cell type:code id:104cf460 tags: %% Cell type:code id:104cf460 tags:
   
``` python ``` python
evaluations = model_evaluation.get_evaluations(sort_by='S1') evaluations = model_evaluation.get_evaluations(sort_by='S1')
evaluations evaluations
``` ```
   
%% Output %% Output
   
<pandas.io.formats.style.Styler at 0x7f405fef3b50> <pandas.io.formats.style.Styler at 0x7f405fef3b50>
   
%% Cell type:code id:ee5041d9 tags: %% Cell type:code id:ee5041d9 tags:
   
``` python ``` python
def plot_selected(target_dataset, collection, target_score, target_value_name): def plot_selected(target_dataset, collection, target_score, target_value_name):
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.set_title(target_score) ax.set_title(target_score)
   
y_vals = [] y_vals = []
x_vals = [] x_vals = []
for model in collection: for model in collection:
   
selection_val = training_manager.get_information(model)[target_value_name] selection_val = training_manager.get_information(model)[target_value_name]
val = getattr(model_evaluation.predictions[(model, target_dataset)].evaluation, target_score) val = getattr(model_evaluation.predictions[(model, target_dataset)].evaluation, target_score)
y_vals.append(val) y_vals.append(val)
x_vals.append(selection_val) x_vals.append(selection_val)
   
ax.plot(x_vals, y_vals, label=target_dataset) ax.plot(x_vals, y_vals, label=target_dataset)
val = getattr(model_evaluation.predictions[(base_model, target_dataset)].evaluation, target_score) val = getattr(model_evaluation.predictions[(base_model, target_dataset)].evaluation, target_score)
ax.axhline(val, color='red', linestyle='--', label='base model') ax.axhline(val, color='red', linestyle='--', label='base model')
val = getattr(model_evaluation.predictions[(base_inc_model, target_dataset)].evaluation, target_score) val = getattr(model_evaluation.predictions[(base_inc_model, target_dataset)].evaluation, target_score)
ax.axhline(val, color='green', linestyle='--', label='base inc model') ax.axhline(val, color='green', linestyle='--', label='base inc model')
val = getattr(model_evaluation.predictions[('train_base.pt', target_dataset)].evaluation, target_score) val = getattr(model_evaluation.predictions[('train_base.pt', target_dataset)].evaluation, target_score)
ax.axhline(val, color='orange', linestyle='--', label='train inc model') ax.axhline(val, color='orange', linestyle='--', label='train inc model')
ax.set_xlabel('value') ax.set_xlabel('value')
# plt.xticks(rotation=45, ha='right') # plt.xticks(rotation=45, ha='right')
fig.legend() fig.legend()
``` ```
   
%% Cell type:code id:ffd7f7bb tags: %% Cell type:code id:ffd7f7bb tags:
   
``` python ``` python
smoothed_models = training_manager.get_all_training_runs()['false_smooth_01_run1'] smoothed_models = training_manager.get_all_training_runs()['false_smooth_01_run1']
smoothed_models = sorted(smoothed_models, key=lambda m: training_manager.get_information(m)['random_value']) smoothed_models = sorted(smoothed_models, key=lambda m: training_manager.get_information(m)['random_value'])
plot_selected('01_generated_0', smoothed_models[:], 's1', 'random_value') # -1 deprecated if no smooth_value == 0.5 in list plot_selected('01_generated_0', smoothed_models[:], 's1', 'random_value') # -1 deprecated if no smooth_value == 0.5 in list
``` ```
   
%% Output %% Output
   
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
KeyError Traceback (most recent call last) KeyError Traceback (most recent call last)
<ipython-input-13-bb29f85c7fbe> in <module> <ipython-input-13-bb29f85c7fbe> in <module>
----> 1 smoothed_models = training_manager.get_all_training_runs()['false_smooth_01_run1'] ----> 1 smoothed_models = training_manager.get_all_training_runs()['false_smooth_01_run1']
2 smoothed_models = sorted(smoothed_models, key=lambda m: training_manager.get_information(m)['random_value']) 2 smoothed_models = sorted(smoothed_models, key=lambda m: training_manager.get_information(m)['random_value'])
3 plot_selected('01_generated_0', smoothed_models[:], 's1', 'random_value') # -1 deprecated if no smooth_value == 0.5 in list 3 plot_selected('01_generated_0', smoothed_models[:], 's1', 'random_value') # -1 deprecated if no smooth_value == 0.5 in list
KeyError: 'false_smooth_01_run1' KeyError: 'false_smooth_01_run1'
   
%% Cell type:code id:b0935bb5 tags: %% Cell type:code id:b0935bb5 tags:
   
``` python ``` python
for model in model_evaluation.models.keys(): for model in model_evaluation.models.keys():
print(training_manager.get_information(model)) print(training_manager.get_information(model))
``` ```
   
%% Output %% Output
   
{'smooth_value': 0, 'rand_hw': 0.02, 'rand_noise': 0} {'smooth_value': 0, 'rand_hw': 0.02, 'rand_noise': 0}
{'smooth_value': 0.1, 'rand_hw': 0.02, 'rand_noise': 0} {'smooth_value': 0.1, 'rand_hw': 0.02, 'rand_noise': 0}
{'smooth_value': 0.2, 'rand_hw': 0.02, 'rand_noise': 0} {'smooth_value': 0.2, 'rand_hw': 0.02, 'rand_noise': 0}
{'smooth_value': 0.3, 'rand_hw': 0.02, 'rand_noise': 0} {'smooth_value': 0.3, 'rand_hw': 0.02, 'rand_noise': 0}
{'smooth_value': 0.35, 'rand_hw': 0.02, 'rand_noise': 0} {'smooth_value': 0.35, 'rand_hw': 0.02, 'rand_noise': 0}
{'smooth_value': 0.4, 'rand_hw': 0.02, 'rand_noise': 0} {'smooth_value': 0.4, 'rand_hw': 0.02, 'rand_noise': 0}
{'smooth_value': 0.45, 'rand_hw': 0.02, 'rand_noise': 0} {'smooth_value': 0.45, 'rand_hw': 0.02, 'rand_noise': 0}
{'smooth_value': 0.49, 'rand_hw': 0.02, 'rand_noise': 0} {'smooth_value': 0.49, 'rand_hw': 0.02, 'rand_noise': 0}
{'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.02} {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.02}
{'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.02} {'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.02}
{'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.02} {'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.02}
{'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.02} {'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.02}
{'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.02} {'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.02}
{'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.02} {'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.02}
{'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.02} {'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.02}
{'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.02} {'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.02}
{'smooth_value': 0, 'rand_hw': 0.05, 'rand_noise': 0} {'smooth_value': 0, 'rand_hw': 0.05, 'rand_noise': 0}
{'smooth_value': 0.1, 'rand_hw': 0.05, 'rand_noise': 0} {'smooth_value': 0.1, 'rand_hw': 0.05, 'rand_noise': 0}
{'smooth_value': 0.2, 'rand_hw': 0.05, 'rand_noise': 0} {'smooth_value': 0.2, 'rand_hw': 0.05, 'rand_noise': 0}
{'smooth_value': 0.3, 'rand_hw': 0.05, 'rand_noise': 0} {'smooth_value': 0.3, 'rand_hw': 0.05, 'rand_noise': 0}
{'smooth_value': 0.35, 'rand_hw': 0.05, 'rand_noise': 0} {'smooth_value': 0.35, 'rand_hw': 0.05, 'rand_noise': 0}
{'smooth_value': 0.4, 'rand_hw': 0.05, 'rand_noise': 0} {'smooth_value': 0.4, 'rand_hw': 0.05, 'rand_noise': 0}
{'smooth_value': 0.45, 'rand_hw': 0.05, 'rand_noise': 0} {'smooth_value': 0.45, 'rand_hw': 0.05, 'rand_noise': 0}
{'smooth_value': 0.49, 'rand_hw': 0.05, 'rand_noise': 0} {'smooth_value': 0.49, 'rand_hw': 0.05, 'rand_noise': 0}
{'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.05} {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.05}
{'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.05} {'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.05}
{'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.05} {'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.05}
{'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.05} {'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.05}
{'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.05} {'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.05}
{'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.05} {'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.05}
{'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.05} {'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.05}
{'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.05} {'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.05}
{'smooth_value': 0, 'rand_hw': 0.1, 'rand_noise': 0} {'smooth_value': 0, 'rand_hw': 0.1, 'rand_noise': 0}
{'smooth_value': 0.1, 'rand_hw': 0.1, 'rand_noise': 0} {'smooth_value': 0.1, 'rand_hw': 0.1, 'rand_noise': 0}
{'smooth_value': 0.2, 'rand_hw': 0.1, 'rand_noise': 0} {'smooth_value': 0.2, 'rand_hw': 0.1, 'rand_noise': 0}
{'smooth_value': 0.3, 'rand_hw': 0.1, 'rand_noise': 0} {'smooth_value': 0.3, 'rand_hw': 0.1, 'rand_noise': 0}
{'smooth_value': 0.35, 'rand_hw': 0.1, 'rand_noise': 0} {'smooth_value': 0.35, 'rand_hw': 0.1, 'rand_noise': 0}
{'smooth_value': 0.4, 'rand_hw': 0.1, 'rand_noise': 0} {'smooth_value': 0.4, 'rand_hw': 0.1, 'rand_noise': 0}
{'smooth_value': 0.45, 'rand_hw': 0.1, 'rand_noise': 0} {'smooth_value': 0.45, 'rand_hw': 0.1, 'rand_noise': 0}
{'smooth_value': 0.49, 'rand_hw': 0.1, 'rand_noise': 0} {'smooth_value': 0.49, 'rand_hw': 0.1, 'rand_noise': 0}
{'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.1} {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.1}
{'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.1} {'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.1}
{'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.1} {'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.1}
{'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.1} {'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.1}
{'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.1} {'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.1}
{'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.1} {'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.1}
{'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.1} {'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.1}
{'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.1} {'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.1}
{'smooth_value': 0, 'rand_hw': 0.2, 'rand_noise': 0} {'smooth_value': 0, 'rand_hw': 0.2, 'rand_noise': 0}
{'smooth_value': 0.1, 'rand_hw': 0.2, 'rand_noise': 0} {'smooth_value': 0.1, 'rand_hw': 0.2, 'rand_noise': 0}
{'smooth_value': 0.2, 'rand_hw': 0.2, 'rand_noise': 0} {'smooth_value': 0.2, 'rand_hw': 0.2, 'rand_noise': 0}
{'smooth_value': 0.3, 'rand_hw': 0.2, 'rand_noise': 0} {'smooth_value': 0.3, 'rand_hw': 0.2, 'rand_noise': 0}
{'smooth_value': 0.35, 'rand_hw': 0.2, 'rand_noise': 0} {'smooth_value': 0.35, 'rand_hw': 0.2, 'rand_noise': 0}
{'smooth_value': 0.4, 'rand_hw': 0.2, 'rand_noise': 0} {'smooth_value': 0.4, 'rand_hw': 0.2, 'rand_noise': 0}
{'smooth_value': 0.45, 'rand_hw': 0.2, 'rand_noise': 0} {'smooth_value': 0.45, 'rand_hw': 0.2, 'rand_noise': 0}
{'smooth_value': 0.49, 'rand_hw': 0.2, 'rand_noise': 0} {'smooth_value': 0.49, 'rand_hw': 0.2, 'rand_noise': 0}
{'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.2} {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.2}
{'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.2} {'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.2}
{'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.2} {'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.2}
{'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.2} {'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.2}
{'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.2} {'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.2}
{'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.2} {'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.2}
{'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.2} {'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.2}
{'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.2} {'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.2}
{'smooth_value': 0, 'rand_hw': 0.3, 'rand_noise': 0} {'smooth_value': 0, 'rand_hw': 0.3, 'rand_noise': 0}
{'smooth_value': 0.1, 'rand_hw': 0.3, 'rand_noise': 0} {'smooth_value': 0.1, 'rand_hw': 0.3, 'rand_noise': 0}
{'smooth_value': 0.2, 'rand_hw': 0.3, 'rand_noise': 0} {'smooth_value': 0.2, 'rand_hw': 0.3, 'rand_noise': 0}
{'smooth_value': 0.3, 'rand_hw': 0.3, 'rand_noise': 0} {'smooth_value': 0.3, 'rand_hw': 0.3, 'rand_noise': 0}
{'smooth_value': 0.35, 'rand_hw': 0.3, 'rand_noise': 0} {'smooth_value': 0.35, 'rand_hw': 0.3, 'rand_noise': 0}
{'smooth_value': 0.4, 'rand_hw': 0.3, 'rand_noise': 0} {'smooth_value': 0.4, 'rand_hw': 0.3, 'rand_noise': 0}
{'smooth_value': 0.45, 'rand_hw': 0.3, 'rand_noise': 0} {'smooth_value': 0.45, 'rand_hw': 0.3, 'rand_noise': 0}
{'smooth_value': 0.49, 'rand_hw': 0.3, 'rand_noise': 0} {'smooth_value': 0.49, 'rand_hw': 0.3, 'rand_noise': 0}
{'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.3} {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.3}
{'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.3} {'smooth_value': 0.1, 'rand_hw': 0, 'rand_noise': 0.3}
{'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.3} {'smooth_value': 0.2, 'rand_hw': 0, 'rand_noise': 0.3}
{'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.3} {'smooth_value': 0.3, 'rand_hw': 0, 'rand_noise': 0.3}
{'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.3} {'smooth_value': 0.35, 'rand_hw': 0, 'rand_noise': 0.3}
{'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.3} {'smooth_value': 0.4, 'rand_hw': 0, 'rand_noise': 0.3}
{'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.3} {'smooth_value': 0.45, 'rand_hw': 0, 'rand_noise': 0.3}
{'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.3} {'smooth_value': 0.49, 'rand_hw': 0, 'rand_noise': 0.3}
   
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
KeyError Traceback (most recent call last) KeyError Traceback (most recent call last)
<ipython-input-14-ea654deea349> in <module> <ipython-input-14-ea654deea349> in <module>
1 for model in model_evaluation.models.keys(): 1 for model in model_evaluation.models.keys():
----> 2 print(training_manager.get_information(model)) ----> 2 print(training_manager.get_information(model))
   
~/gitrepos/uni/handwashing_personalizer/src/personalization_tools/trainings_manager.py in get_information(self, model_name) ~/gitrepos/uni/handwashing_personalizer/src/personalization_tools/trainings_manager.py in get_information(self, model_name)
36 36
37 def get_information(self, model_name: str): 37 def get_information(self, model_name: str):
---> 38 return self.database['model_infos'][model_name] ---> 38 return self.database['model_infos'][model_name]
39 39
40 def get_all_information(self): 40 def get_all_information(self):
KeyError: 'HandWashingDeepConvLSTMA_trunc_01.pt' KeyError: 'HandWashingDeepConvLSTMA_trunc_01.pt'
   
%% Cell type:code id:4b243171 tags: %% Cell type:code id:4b243171 tags:
   
``` python ``` python
def plot_smoothing(collection, noise_target, noise_value,target_score): def plot_smoothing(collection, noise_target, noise_value,target_score):
models = get_models_with_infos({noise_target: noise_value}, training_manager.get_all_information()) models = get_models_with_infos({noise_target: noise_value}, training_manager.get_all_information())
print(models) print(models)
   
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.set_title(target_score) ax.set_title(target_score)
   
avg_models = models + [base_model, base_inc_model] avg_models = models + [base_model, base_inc_model]
averages = model_evaluation.get_averages(include_models=avg_models) averages = model_evaluation.get_averages(include_models=avg_models)
   
y_vals = [] y_vals = []
x_vals = [] x_vals = []
for model in models: for model in models:
x_vals.append(training_manager.get_information(model)['smooth_value']) x_vals.append(training_manager.get_information(model)['smooth_value'])
y_vals.append(averages[model][target_score]) y_vals.append(averages[model][target_score])
   
ax.plot(x_vals, y_vals) ax.plot(x_vals, y_vals)
ax.axhline(averages[base_model][target_score], color='red', linestyle='--', label='base model') ax.axhline(averages[base_model][target_score], color='red', linestyle='--', label='base model')
ax.axhline(averages[base_inc_model][target_score], color='green', linestyle='--', label='base inc model') ax.axhline(averages[base_inc_model][target_score], color='green', linestyle='--', label='base inc model')
   
   
def plot_noise(collection, smooth_value, noise_target, target_score): def plot_noise(collection, smooth_value, noise_target, target_score):
steady_val = 'rand_noise' steady_val = 'rand_noise'
if noise_target == 'rand_noise': if noise_target == 'rand_noise':
steady_val = 'rand_hw' steady_val = 'rand_hw'
models = get_models_with_infos({'smooth_value': smooth_value, steady_val: 0}, training_manager.get_all_information()) models = get_models_with_infos({'smooth_value': smooth_value, steady_val: 0}, training_manager.get_all_information())
print(models) print(models)
   
fig, ax = plt.subplots() fig, ax = plt.subplots()
ax.set_title(target_score) ax.set_title(target_score)
ax.set_xlabel(noise_target) ax.set_xlabel(noise_target)
   
avg_models = models + [base_model, base_inc_model] avg_models = models + [base_model, base_inc_model]
averages = model_evaluation.get_averages(include_models=avg_models) averages = model_evaluation.get_averages(include_models=avg_models)
   
y_vals = [] y_vals = []
x_vals = [] x_vals = []
for model in models: for model in models:
x_vals.append(training_manager.get_information(model)[noise_target]) x_vals.append(training_manager.get_information(model)[noise_target])
y_vals.append(averages[model][target_score]) y_vals.append(averages[model][target_score])
   
ax.plot(x_vals, y_vals) ax.plot(x_vals, y_vals)
ax.axhline(averages[base_model][target_score], color='red', linestyle='--', label='base model') ax.axhline(averages[base_model][target_score], color='red', linestyle='--', label='base model')
ax.axhline(averages[base_inc_model][target_score], color='green', linestyle='--', label='base inc model') ax.axhline(averages[base_inc_model][target_score], color='green', linestyle='--', label='base inc model')
   
   
plot_noise(collection, 0, 'rand_noise', 'mcc') plot_noise(collection, 0, 'rand_noise', 'mcc')
``` ```
   
%% Output %% Output
   
['01_run1_rn02rhwsv.pt', '01_run1_rn05rhwsv.pt', '01_run1_rn1rhwsv.pt', '01_run1_rn2rhwsv.pt', '01_run1_rn3rhwsv.pt'] ['01_run1_rn02rhwsv.pt', '01_run1_rn05rhwsv.pt', '01_run1_rn1rhwsv.pt', '01_run1_rn2rhwsv.pt', '01_run1_rn3rhwsv.pt']
   
   
   
%% Cell type:code id:fc5b5caf tags: %% Cell type:code id:fc5b5caf tags:
   
``` python ``` python
plot_smoothing(collection, 'rand_hw',0.02, 's1') plot_smoothing(collection, 'rand_hw',0.02, 's1')
plot_smoothing(collection, 'rand_noise',0.02, 's1') plot_smoothing(collection, 'rand_noise',0.02, 's1')
``` ```
   
%% Output %% Output
   
['01_run1_rnrhw02sv.pt', '01_run1_rnrhw02sv1.pt', '01_run1_rnrhw02sv2.pt', '01_run1_rnrhw02sv3.pt', '01_run1_rnrhw02sv35.pt', '01_run1_rnrhw02sv4.pt', '01_run1_rnrhw02sv45.pt', '01_run1_rnrhw02sv49.pt'] ['01_run1_rnrhw02sv.pt', '01_run1_rnrhw02sv1.pt', '01_run1_rnrhw02sv2.pt', '01_run1_rnrhw02sv3.pt', '01_run1_rnrhw02sv35.pt', '01_run1_rnrhw02sv4.pt', '01_run1_rnrhw02sv45.pt', '01_run1_rnrhw02sv49.pt']
   
   
   
['01_run1_rn02rhwsv.pt', '01_run1_rn02rhwsv1.pt', '01_run1_rn02rhwsv2.pt', '01_run1_rn02rhwsv3.pt', '01_run1_rn02rhwsv35.pt', '01_run1_rn02rhwsv4.pt', '01_run1_rn02rhwsv45.pt', '01_run1_rn02rhwsv49.pt'] ['01_run1_rn02rhwsv.pt', '01_run1_rn02rhwsv1.pt', '01_run1_rn02rhwsv2.pt', '01_run1_rn02rhwsv3.pt', '01_run1_rn02rhwsv35.pt', '01_run1_rn02rhwsv4.pt', '01_run1_rn02rhwsv45.pt', '01_run1_rn02rhwsv49.pt']
   
   
   
%% Cell type:code id:677c9cec tags: %% Cell type:code id:677c9cec tags:
   
``` python ``` python
training_manager.get_all_information() training_manager.get_all_information()
``` ```
   
%% Output %% Output
   
{'01_run1_rnrhw02sv.pt': {'smooth_value': 0, 'rand_hw': 0.02, 'rand_noise': 0}, {'01_run1_rnrhw02sv.pt': {'smooth_value': 0, 'rand_hw': 0.02, 'rand_noise': 0},
'01_run1_rnrhw02sv1.pt': {'smooth_value': 0.1, '01_run1_rnrhw02sv1.pt': {'smooth_value': 0.1,
'rand_hw': 0.02, 'rand_hw': 0.02,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw02sv2.pt': {'smooth_value': 0.2, '01_run1_rnrhw02sv2.pt': {'smooth_value': 0.2,
'rand_hw': 0.02, 'rand_hw': 0.02,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw02sv3.pt': {'smooth_value': 0.3, '01_run1_rnrhw02sv3.pt': {'smooth_value': 0.3,
'rand_hw': 0.02, 'rand_hw': 0.02,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw02sv35.pt': {'smooth_value': 0.35, '01_run1_rnrhw02sv35.pt': {'smooth_value': 0.35,
'rand_hw': 0.02, 'rand_hw': 0.02,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw02sv4.pt': {'smooth_value': 0.4, '01_run1_rnrhw02sv4.pt': {'smooth_value': 0.4,
'rand_hw': 0.02, 'rand_hw': 0.02,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw02sv45.pt': {'smooth_value': 0.45, '01_run1_rnrhw02sv45.pt': {'smooth_value': 0.45,
'rand_hw': 0.02, 'rand_hw': 0.02,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw02sv49.pt': {'smooth_value': 0.49, '01_run1_rnrhw02sv49.pt': {'smooth_value': 0.49,
'rand_hw': 0.02, 'rand_hw': 0.02,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rn02rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.02}, '01_run1_rn02rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.02},
'01_run1_rn02rhwsv1.pt': {'smooth_value': 0.1, '01_run1_rn02rhwsv1.pt': {'smooth_value': 0.1,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.02}, 'rand_noise': 0.02},
'01_run1_rn02rhwsv2.pt': {'smooth_value': 0.2, '01_run1_rn02rhwsv2.pt': {'smooth_value': 0.2,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.02}, 'rand_noise': 0.02},
'01_run1_rn02rhwsv3.pt': {'smooth_value': 0.3, '01_run1_rn02rhwsv3.pt': {'smooth_value': 0.3,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.02}, 'rand_noise': 0.02},
'01_run1_rn02rhwsv35.pt': {'smooth_value': 0.35, '01_run1_rn02rhwsv35.pt': {'smooth_value': 0.35,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.02}, 'rand_noise': 0.02},
'01_run1_rn02rhwsv4.pt': {'smooth_value': 0.4, '01_run1_rn02rhwsv4.pt': {'smooth_value': 0.4,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.02}, 'rand_noise': 0.02},
'01_run1_rn02rhwsv45.pt': {'smooth_value': 0.45, '01_run1_rn02rhwsv45.pt': {'smooth_value': 0.45,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.02}, 'rand_noise': 0.02},
'01_run1_rn02rhwsv49.pt': {'smooth_value': 0.49, '01_run1_rn02rhwsv49.pt': {'smooth_value': 0.49,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.02}, 'rand_noise': 0.02},
'01_run1_rnrhw05sv.pt': {'smooth_value': 0, 'rand_hw': 0.05, 'rand_noise': 0}, '01_run1_rnrhw05sv.pt': {'smooth_value': 0, 'rand_hw': 0.05, 'rand_noise': 0},
'01_run1_rnrhw05sv1.pt': {'smooth_value': 0.1, '01_run1_rnrhw05sv1.pt': {'smooth_value': 0.1,
'rand_hw': 0.05, 'rand_hw': 0.05,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw05sv2.pt': {'smooth_value': 0.2, '01_run1_rnrhw05sv2.pt': {'smooth_value': 0.2,
'rand_hw': 0.05, 'rand_hw': 0.05,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw05sv3.pt': {'smooth_value': 0.3, '01_run1_rnrhw05sv3.pt': {'smooth_value': 0.3,
'rand_hw': 0.05, 'rand_hw': 0.05,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw05sv35.pt': {'smooth_value': 0.35, '01_run1_rnrhw05sv35.pt': {'smooth_value': 0.35,
'rand_hw': 0.05, 'rand_hw': 0.05,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw05sv4.pt': {'smooth_value': 0.4, '01_run1_rnrhw05sv4.pt': {'smooth_value': 0.4,
'rand_hw': 0.05, 'rand_hw': 0.05,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw05sv45.pt': {'smooth_value': 0.45, '01_run1_rnrhw05sv45.pt': {'smooth_value': 0.45,
'rand_hw': 0.05, 'rand_hw': 0.05,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw05sv49.pt': {'smooth_value': 0.49, '01_run1_rnrhw05sv49.pt': {'smooth_value': 0.49,
'rand_hw': 0.05, 'rand_hw': 0.05,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rn05rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.05}, '01_run1_rn05rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.05},
'01_run1_rn05rhwsv1.pt': {'smooth_value': 0.1, '01_run1_rn05rhwsv1.pt': {'smooth_value': 0.1,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.05}, 'rand_noise': 0.05},
'01_run1_rn05rhwsv2.pt': {'smooth_value': 0.2, '01_run1_rn05rhwsv2.pt': {'smooth_value': 0.2,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.05}, 'rand_noise': 0.05},
'01_run1_rn05rhwsv3.pt': {'smooth_value': 0.3, '01_run1_rn05rhwsv3.pt': {'smooth_value': 0.3,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.05}, 'rand_noise': 0.05},
'01_run1_rn05rhwsv35.pt': {'smooth_value': 0.35, '01_run1_rn05rhwsv35.pt': {'smooth_value': 0.35,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.05}, 'rand_noise': 0.05},
'01_run1_rn05rhwsv4.pt': {'smooth_value': 0.4, '01_run1_rn05rhwsv4.pt': {'smooth_value': 0.4,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.05}, 'rand_noise': 0.05},
'01_run1_rn05rhwsv45.pt': {'smooth_value': 0.45, '01_run1_rn05rhwsv45.pt': {'smooth_value': 0.45,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.05}, 'rand_noise': 0.05},
'01_run1_rn05rhwsv49.pt': {'smooth_value': 0.49, '01_run1_rn05rhwsv49.pt': {'smooth_value': 0.49,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.05}, 'rand_noise': 0.05},
'01_run1_rnrhw1sv.pt': {'smooth_value': 0, 'rand_hw': 0.1, 'rand_noise': 0}, '01_run1_rnrhw1sv.pt': {'smooth_value': 0, 'rand_hw': 0.1, 'rand_noise': 0},
'01_run1_rnrhw1sv1.pt': {'smooth_value': 0.1, '01_run1_rnrhw1sv1.pt': {'smooth_value': 0.1,
'rand_hw': 0.1, 'rand_hw': 0.1,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw1sv2.pt': {'smooth_value': 0.2, '01_run1_rnrhw1sv2.pt': {'smooth_value': 0.2,
'rand_hw': 0.1, 'rand_hw': 0.1,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw1sv3.pt': {'smooth_value': 0.3, '01_run1_rnrhw1sv3.pt': {'smooth_value': 0.3,
'rand_hw': 0.1, 'rand_hw': 0.1,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw1sv35.pt': {'smooth_value': 0.35, '01_run1_rnrhw1sv35.pt': {'smooth_value': 0.35,
'rand_hw': 0.1, 'rand_hw': 0.1,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw1sv4.pt': {'smooth_value': 0.4, '01_run1_rnrhw1sv4.pt': {'smooth_value': 0.4,
'rand_hw': 0.1, 'rand_hw': 0.1,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw1sv45.pt': {'smooth_value': 0.45, '01_run1_rnrhw1sv45.pt': {'smooth_value': 0.45,
'rand_hw': 0.1, 'rand_hw': 0.1,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw1sv49.pt': {'smooth_value': 0.49, '01_run1_rnrhw1sv49.pt': {'smooth_value': 0.49,
'rand_hw': 0.1, 'rand_hw': 0.1,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rn1rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.1}, '01_run1_rn1rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.1},
'01_run1_rn1rhwsv1.pt': {'smooth_value': 0.1, '01_run1_rn1rhwsv1.pt': {'smooth_value': 0.1,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.1}, 'rand_noise': 0.1},
'01_run1_rn1rhwsv2.pt': {'smooth_value': 0.2, '01_run1_rn1rhwsv2.pt': {'smooth_value': 0.2,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.1}, 'rand_noise': 0.1},
'01_run1_rn1rhwsv3.pt': {'smooth_value': 0.3, '01_run1_rn1rhwsv3.pt': {'smooth_value': 0.3,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.1}, 'rand_noise': 0.1},
'01_run1_rn1rhwsv35.pt': {'smooth_value': 0.35, '01_run1_rn1rhwsv35.pt': {'smooth_value': 0.35,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.1}, 'rand_noise': 0.1},
'01_run1_rn1rhwsv4.pt': {'smooth_value': 0.4, '01_run1_rn1rhwsv4.pt': {'smooth_value': 0.4,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.1}, 'rand_noise': 0.1},
'01_run1_rn1rhwsv45.pt': {'smooth_value': 0.45, '01_run1_rn1rhwsv45.pt': {'smooth_value': 0.45,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.1}, 'rand_noise': 0.1},
'01_run1_rn1rhwsv49.pt': {'smooth_value': 0.49, '01_run1_rn1rhwsv49.pt': {'smooth_value': 0.49,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.1}, 'rand_noise': 0.1},
'01_run1_rnrhw2sv.pt': {'smooth_value': 0, 'rand_hw': 0.2, 'rand_noise': 0}, '01_run1_rnrhw2sv.pt': {'smooth_value': 0, 'rand_hw': 0.2, 'rand_noise': 0},
'01_run1_rnrhw2sv1.pt': {'smooth_value': 0.1, '01_run1_rnrhw2sv1.pt': {'smooth_value': 0.1,
'rand_hw': 0.2, 'rand_hw': 0.2,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw2sv2.pt': {'smooth_value': 0.2, '01_run1_rnrhw2sv2.pt': {'smooth_value': 0.2,
'rand_hw': 0.2, 'rand_hw': 0.2,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw2sv3.pt': {'smooth_value': 0.3, '01_run1_rnrhw2sv3.pt': {'smooth_value': 0.3,
'rand_hw': 0.2, 'rand_hw': 0.2,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw2sv35.pt': {'smooth_value': 0.35, '01_run1_rnrhw2sv35.pt': {'smooth_value': 0.35,
'rand_hw': 0.2, 'rand_hw': 0.2,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw2sv4.pt': {'smooth_value': 0.4, '01_run1_rnrhw2sv4.pt': {'smooth_value': 0.4,
'rand_hw': 0.2, 'rand_hw': 0.2,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw2sv45.pt': {'smooth_value': 0.45, '01_run1_rnrhw2sv45.pt': {'smooth_value': 0.45,
'rand_hw': 0.2, 'rand_hw': 0.2,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw2sv49.pt': {'smooth_value': 0.49, '01_run1_rnrhw2sv49.pt': {'smooth_value': 0.49,
'rand_hw': 0.2, 'rand_hw': 0.2,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rn2rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.2}, '01_run1_rn2rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.2},
'01_run1_rn2rhwsv1.pt': {'smooth_value': 0.1, '01_run1_rn2rhwsv1.pt': {'smooth_value': 0.1,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.2}, 'rand_noise': 0.2},
'01_run1_rn2rhwsv2.pt': {'smooth_value': 0.2, '01_run1_rn2rhwsv2.pt': {'smooth_value': 0.2,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.2}, 'rand_noise': 0.2},
'01_run1_rn2rhwsv3.pt': {'smooth_value': 0.3, '01_run1_rn2rhwsv3.pt': {'smooth_value': 0.3,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.2}, 'rand_noise': 0.2},
'01_run1_rn2rhwsv35.pt': {'smooth_value': 0.35, '01_run1_rn2rhwsv35.pt': {'smooth_value': 0.35,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.2}, 'rand_noise': 0.2},
'01_run1_rn2rhwsv4.pt': {'smooth_value': 0.4, '01_run1_rn2rhwsv4.pt': {'smooth_value': 0.4,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.2}, 'rand_noise': 0.2},
'01_run1_rn2rhwsv45.pt': {'smooth_value': 0.45, '01_run1_rn2rhwsv45.pt': {'smooth_value': 0.45,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.2}, 'rand_noise': 0.2},
'01_run1_rn2rhwsv49.pt': {'smooth_value': 0.49, '01_run1_rn2rhwsv49.pt': {'smooth_value': 0.49,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.2}, 'rand_noise': 0.2},
'01_run1_rnrhw3sv.pt': {'smooth_value': 0, 'rand_hw': 0.3, 'rand_noise': 0}, '01_run1_rnrhw3sv.pt': {'smooth_value': 0, 'rand_hw': 0.3, 'rand_noise': 0},
'01_run1_rnrhw3sv1.pt': {'smooth_value': 0.1, '01_run1_rnrhw3sv1.pt': {'smooth_value': 0.1,
'rand_hw': 0.3, 'rand_hw': 0.3,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw3sv2.pt': {'smooth_value': 0.2, '01_run1_rnrhw3sv2.pt': {'smooth_value': 0.2,
'rand_hw': 0.3, 'rand_hw': 0.3,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw3sv3.pt': {'smooth_value': 0.3, '01_run1_rnrhw3sv3.pt': {'smooth_value': 0.3,
'rand_hw': 0.3, 'rand_hw': 0.3,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw3sv35.pt': {'smooth_value': 0.35, '01_run1_rnrhw3sv35.pt': {'smooth_value': 0.35,
'rand_hw': 0.3, 'rand_hw': 0.3,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw3sv4.pt': {'smooth_value': 0.4, '01_run1_rnrhw3sv4.pt': {'smooth_value': 0.4,
'rand_hw': 0.3, 'rand_hw': 0.3,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw3sv45.pt': {'smooth_value': 0.45, '01_run1_rnrhw3sv45.pt': {'smooth_value': 0.45,
'rand_hw': 0.3, 'rand_hw': 0.3,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rnrhw3sv49.pt': {'smooth_value': 0.49, '01_run1_rnrhw3sv49.pt': {'smooth_value': 0.49,
'rand_hw': 0.3, 'rand_hw': 0.3,
'rand_noise': 0}, 'rand_noise': 0},
'01_run1_rn3rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.3}, '01_run1_rn3rhwsv.pt': {'smooth_value': 0, 'rand_hw': 0, 'rand_noise': 0.3},
'01_run1_rn3rhwsv1.pt': {'smooth_value': 0.1, '01_run1_rn3rhwsv1.pt': {'smooth_value': 0.1,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.3}, 'rand_noise': 0.3},
'01_run1_rn3rhwsv2.pt': {'smooth_value': 0.2, '01_run1_rn3rhwsv2.pt': {'smooth_value': 0.2,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.3}, 'rand_noise': 0.3},
'01_run1_rn3rhwsv3.pt': {'smooth_value': 0.3, '01_run1_rn3rhwsv3.pt': {'smooth_value': 0.3,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.3}, 'rand_noise': 0.3},
'01_run1_rn3rhwsv35.pt': {'smooth_value': 0.35, '01_run1_rn3rhwsv35.pt': {'smooth_value': 0.35,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.3}, 'rand_noise': 0.3},
'01_run1_rn3rhwsv4.pt': {'smooth_value': 0.4, '01_run1_rn3rhwsv4.pt': {'smooth_value': 0.4,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.3}, 'rand_noise': 0.3},
'01_run1_rn3rhwsv45.pt': {'smooth_value': 0.45, '01_run1_rn3rhwsv45.pt': {'smooth_value': 0.45,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.3}, 'rand_noise': 0.3},
'01_run1_rn3rhwsv49.pt': {'smooth_value': 0.49, '01_run1_rn3rhwsv49.pt': {'smooth_value': 0.49,
'rand_hw': 0, 'rand_hw': 0,
'rand_noise': 0.3}} 'rand_noise': 0.3}}
   
%% Cell type:code id:c0cccf96 tags: %% Cell type:code id:c0cccf96 tags:
   
``` python ``` python
``` ```
......
...@@ -41,7 +41,7 @@ def build_simulated_running_mean(dataset, model_name, kernel_size=20, kernel_thr ...@@ -41,7 +41,7 @@ def build_simulated_running_mean(dataset, model_name, kernel_size=20, kernel_thr
kernel = np.ones(kernel_size) / kernel_size kernel = np.ones(kernel_size) / kernel_size
mean = np.convolve(prediction[:, 1], kernel, mode='same') mean = np.convolve(prediction[:, 1], kernel, mode='same')
r_mean = np.empty_like(mean) r_mean = np.empty_like(mean)
r_mean[kernel_size:] = mean[:-kernel_size] # r_mean[kernel_size:] = mean[:-kernel_size]
return r_mean, prediction return r_mean, prediction
......
...@@ -61,7 +61,7 @@ def run_randomized_learning(random_noise, random_hw): ...@@ -61,7 +61,7 @@ def run_randomized_learning(random_noise, random_hw):
new_modelname = f'{training_run_name}_rn{str(random_noise)[2:]}rhw{str(random_hw)[2:]}.pt' new_modelname = f'{training_run_name}_rn{str(random_noise)[2:]}rhw{str(random_hw)[2:]}.pt'
actual_random_values = dict() actual_random_values = dict()
average_evaluation = [0, 0, 0, 0] average_evaluation = [0, 0, 0, 0, 0]
for dataset in collection: for dataset in collection:
y_true = dataset.y_win.copy() y_true = dataset.y_win.copy()
...@@ -78,12 +78,14 @@ def run_randomized_learning(random_noise, random_hw): ...@@ -78,12 +78,14 @@ def run_randomized_learning(random_noise, random_hw):
print('Train model:', new_modelname, 'on', collection) print('Train model:', new_modelname, 'on', collection)
trainings_manager.add_model_information(new_modelname, {'random_noise': random_noise, 'random_hw': random_hw, 'evaluation': average_evaluation}) trainings_manager.add_model_information(new_modelname, {'random_noise': random_noise, 'random_hw': random_hw, 'evaluation': average_evaluation})
if new_modelname not in trainings_manager.database['training_runs'][training_run_name]: if not args.skip or new_modelname not in trainings_manager.database['training_runs'][training_run_name]:
trainings_manager.database['training_runs'][training_run_name].append(new_modelname) trainings_manager.database['training_runs'][training_run_name].append(new_modelname)
save_path = os.path.join(args.models_dir, new_modelname) save_path = os.path.join(args.models_dir, new_modelname)
personalizer.incremental_learn_series_gt(collection, save_model_as=save_path, epochs=100) personalizer.incremental_learn_series_gt(collection, save_model_as=save_path, epochs=100)
else: else:
print('skip', new_modelname) print('skip', new_modelname)
trainings_manager.db_update()
return {new_modelname: actual_random_values} return {new_modelname: actual_random_values}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment