Commit bc2a6dbd authored by Alexander Henkel's avatar Alexander Henkel
Browse files

more data

parent deae672b
%% Cell type:code id:44888f7d tags:
``` python
%load_ext autoreload
%autoreload 2
%matplotlib notebook
```
%% Cell type:code id:71dc2d5a tags:
``` python
import sys
import os
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, Markdown
import copy
import torch
import pandas as pd
```
%% Cell type:code id:f4e26527 tags:
``` python
module_path = os.path.abspath(os.path.join('..'))
os.chdir(module_path)
if module_path not in sys.path:
sys.path.append(module_path)
```
%% Cell type:code id:46663f57 tags:
``` python
from tools.load_data_sets import *
from tools.helpers import *
from tools.learner_pipeline import LearnerPipeline, Evaluation
from tools.dataset_builder import *
from tools.model_evaluation import ModelEvaluation
from tools.sensor_recorder_data_reader import SensorRecorderDataReader
from tools.dataset_manager import DatasetManager
from tools.dataset import SyntheticDataset, ManualDataset
from tools.models import HandWashingDeepConvLSTMA
from tools.metrics import sensitivity, specificity, S1_score
from sklearn.metrics import f1_score
from sklearn.utils import class_weight
from tools.personalizer import Personalizer
from tools.trainings_manager import TrainingsManager
```
%% Cell type:code id:32f8acdc tags:
``` python
dataset_db = './data/synthetic_dataset_db'
base_model = './data/HandWashingDeepConvLSTMA_trunc_01.pt'
collection_name = 'base_synthetic_01'
training_run_db = 'self_supervised_training_db'
models_directory = './'
```
%% Cell type:code id:efecde7a tags:
``` python
model = HandWashingDeepConvLSTMA(input_shape=6)
model.load_state_dict(torch.load(base_model))
model = nn.Sequential(model, nn.Softmax(dim=-1))
```
%% Cell type:code id:df6391d8 tags:
``` python
dataset_manager = DatasetManager(dataset_db)
print(dataset_manager.show_categories())
```
%% Output
['base_synthetic_01', 'base_synthetic_01_training']
%% Cell type:code id:45989ca5 tags:
``` python
collection = dataset_manager.filter_by_category(collection_name)
print(collection.keys())
```
%% Output
dict_keys(['01_generated_0', '01_generated_1', '01_generated_2', '01_generated_3', '01_generated_4', '01_generated_5'])
%% Cell type:code id:44d64cb2 tags:
``` python
predicitons = dict()
for dataset in collection.values():
dataset.indicators_synthetic = None
with torch.no_grad():
x_pred = torch.Tensor(dataset.x_win)
pred = model(x_pred.to("cpu")).detach().cpu().numpy()
predicitons[dataset.name] = pred
```
%% Cell type:code id:098da396 tags:
``` python
```
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment