Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
henkela
handwashing_personalizer
Commits
bc2a6dbd
Commit
bc2a6dbd
authored
Jul 04, 2022
by
Alexander Henkel
Browse files
more data
parent
deae672b
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/notebooks/SelfSupervisedLearning.ipynb
0 → 100644
View file @
bc2a6dbd
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "44888f7d",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2 \n",
"\n",
"%matplotlib notebook"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "71dc2d5a",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import display, Markdown\n",
"import copy\n",
"import torch\n",
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f4e26527",
"metadata": {},
"outputs": [],
"source": [
"module_path = os.path.abspath(os.path.join('..'))\n",
"os.chdir(module_path)\n",
"if module_path not in sys.path:\n",
" sys.path.append(module_path)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "46663f57",
"metadata": {},
"outputs": [],
"source": [
"from tools.load_data_sets import *\n",
"from tools.helpers import *\n",
"from tools.learner_pipeline import LearnerPipeline, Evaluation\n",
"from tools.dataset_builder import *\n",
"from tools.model_evaluation import ModelEvaluation\n",
"from tools.sensor_recorder_data_reader import SensorRecorderDataReader\n",
"from tools.dataset_manager import DatasetManager\n",
"from tools.dataset import SyntheticDataset, ManualDataset\n",
"from tools.models import HandWashingDeepConvLSTMA\n",
"from tools.metrics import sensitivity, specificity, S1_score\n",
"from sklearn.metrics import f1_score\n",
"from sklearn.utils import class_weight\n",
"from tools.personalizer import Personalizer\n",
"from tools.trainings_manager import TrainingsManager"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "32f8acdc",
"metadata": {},
"outputs": [],
"source": [
"dataset_db = './data/synthetic_dataset_db'\n",
"base_model = './data/HandWashingDeepConvLSTMA_trunc_01.pt'\n",
"collection_name = 'base_synthetic_01'\n",
"training_run_db = 'self_supervised_training_db'\n",
"models_directory = './'"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "efecde7a",
"metadata": {},
"outputs": [],
"source": [
"model = HandWashingDeepConvLSTMA(input_shape=6)\n",
"model.load_state_dict(torch.load(base_model))\n",
"model = nn.Sequential(model, nn.Softmax(dim=-1))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "df6391d8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['base_synthetic_01', 'base_synthetic_01_training']\n"
]
}
],
"source": [
"dataset_manager = DatasetManager(dataset_db)\n",
"print(dataset_manager.show_categories())"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "45989ca5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['01_generated_0', '01_generated_1', '01_generated_2', '01_generated_3', '01_generated_4', '01_generated_5'])\n"
]
}
],
"source": [
"collection = dataset_manager.filter_by_category(collection_name)\n",
"print(collection.keys())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "44d64cb2",
"metadata": {},
"outputs": [],
"source": [
"predicitons = dict()\n",
"for dataset in collection.values():\n",
" dataset.indicators_synthetic = None\n",
" with torch.no_grad():\n",
" x_pred = torch.Tensor(dataset.x_win)\n",
" pred = model(x_pred.to(\"cpu\")).detach().cpu().numpy()\n",
" predicitons[dataset.name] = pred"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "098da396",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
%% Cell type:code id:44888f7d tags:
```
python
%
load_ext
autoreload
%
autoreload
2
%
matplotlib
notebook
```
%% Cell type:code id:71dc2d5a tags:
```
python
import
sys
import
os
import
numpy
as
np
import
matplotlib.pyplot
as
plt
from
IPython.display
import
display
,
Markdown
import
copy
import
torch
import
pandas
as
pd
```
%% Cell type:code id:f4e26527 tags:
```
python
module_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
'..'
))
os
.
chdir
(
module_path
)
if
module_path
not
in
sys
.
path
:
sys
.
path
.
append
(
module_path
)
```
%% Cell type:code id:46663f57 tags:
```
python
from
tools.load_data_sets
import
*
from
tools.helpers
import
*
from
tools.learner_pipeline
import
LearnerPipeline
,
Evaluation
from
tools.dataset_builder
import
*
from
tools.model_evaluation
import
ModelEvaluation
from
tools.sensor_recorder_data_reader
import
SensorRecorderDataReader
from
tools.dataset_manager
import
DatasetManager
from
tools.dataset
import
SyntheticDataset
,
ManualDataset
from
tools.models
import
HandWashingDeepConvLSTMA
from
tools.metrics
import
sensitivity
,
specificity
,
S1_score
from
sklearn.metrics
import
f1_score
from
sklearn.utils
import
class_weight
from
tools.personalizer
import
Personalizer
from
tools.trainings_manager
import
TrainingsManager
```
%% Cell type:code id:32f8acdc tags:
```
python
dataset_db
=
'./data/synthetic_dataset_db'
base_model
=
'./data/HandWashingDeepConvLSTMA_trunc_01.pt'
collection_name
=
'base_synthetic_01'
training_run_db
=
'self_supervised_training_db'
models_directory
=
'./'
```
%% Cell type:code id:efecde7a tags:
```
python
model
=
HandWashingDeepConvLSTMA
(
input_shape
=
6
)
model
.
load_state_dict
(
torch
.
load
(
base_model
))
model
=
nn
.
Sequential
(
model
,
nn
.
Softmax
(
dim
=-
1
))
```
%% Cell type:code id:df6391d8 tags:
```
python
dataset_manager
=
DatasetManager
(
dataset_db
)
print
(
dataset_manager
.
show_categories
())
```
%% Output
['base_synthetic_01', 'base_synthetic_01_training']
%% Cell type:code id:45989ca5 tags:
```
python
collection
=
dataset_manager
.
filter_by_category
(
collection_name
)
print
(
collection
.
keys
())
```
%% Output
dict_keys(['01_generated_0', '01_generated_1', '01_generated_2', '01_generated_3', '01_generated_4', '01_generated_5'])
%% Cell type:code id:44d64cb2 tags:
```
python
predicitons
=
dict
()
for
dataset
in
collection
.
values
():
dataset
.
indicators_synthetic
=
None
with
torch
.
no_grad
():
x_pred
=
torch
.
Tensor
(
dataset
.
x_win
)
pred
=
model
(
x_pred
.
to
(
"cpu"
)).
detach
().
cpu
().
numpy
()
predicitons
[
dataset
.
name
]
=
pred
```
%% Cell type:code id:098da396 tags:
```
python
```
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment