Commit ba3f4871 authored by Alexander Henkel's avatar Alexander Henkel
Browse files

labeling

parent d3a5577c
...@@ -63,7 +63,7 @@ class Dataset: ...@@ -63,7 +63,7 @@ class Dataset:
self.pseudo_labels = PseudoLabels(self.feedback_areas) self.pseudo_labels = PseudoLabels(self.feedback_areas)
def plot(self, ax=None, plot_sensor=True, scatter=False): def plot(self, ax=None, plot_sensor=True, scatter=False, plot_indicators=True, plot_y=True):
x = self.x_data x = self.x_data
y = self.y_data y = self.y_data
indicators = self.indicators indicators = self.indicators
...@@ -72,24 +72,27 @@ class Dataset: ...@@ -72,24 +72,27 @@ class Dataset:
fig, ax = plt.subplots() fig, ax = plt.subplots()
if plot_sensor: if plot_sensor:
ax.plot(np.arange(x.shape[0]), x[:, :3], linewidth=0.5, label='acc') labels = ['x', 'y', 'z']
if scatter: ax.plot(np.arange(x.shape[0]), x[:, :3], linewidth=0.5, label=labels)
y_color = np.zeros(y.shape)
y_color[self.gt_hw] = 100 if plot_y:
cm = plt.cm.get_cmap('seismic') if scatter:
y_color = np.zeros(y.shape)
ax.scatter(np.arange(y.shape[0]), y * np.max(x[:, :3]), alpha=0.7, label='y', s=0.7, c=y_color, cmap=cm) y_color[self.gt_hw] = 100
else: cm = plt.cm.get_cmap('seismic')
ax.plot(np.arange(y.shape[0]), y*np.max(x[:, :3]), label='y')
ax.scatter(np.arange(y.shape[0]), y * np.max(x[:, :3]), alpha=0.7, label='y', s=0.7, c=y_color, cmap=cm)
else:
ax.plot(np.arange(y.shape[0]), y*np.max(x[:, :3]), label='y')
# plt.plot(np.arange(x_sen.shape[0]), x_sen[:, 0]) # plt.plot(np.arange(x_sen.shape[0]), x_sen[:, 0])
# plt.scatter(indicators[0], 1) # plt.scatter(indicators[0], 1)
if indicators: if plot_indicators and indicators:
ax.scatter(indicators[0][:]*75, np.ones((indicators[0].shape[0]))*np.max(x[:, :3])+0.02, alpha=0.5, label='manual', marker='x', c='purple') ax.scatter(indicators[0][:]*75, np.ones((indicators[0].shape[0]))*np.max(x[:, :3])+0.02, alpha=0.5, label='manual', marker='x', c='purple')
return ax return ax, fig
def plot_windows(self, ax=None, scatter=False, plot_markers=True, custom_label_names=None): def plot_windows(self, ax=None, scatter=False, plot_markers=True, custom_label_names=None):
label_names = {'y null': 'y null', 'y hw': 'y hw', 'activity': 'activity'} label_names = {'y null': 'y null', 'y hw': 'y hw', 'activity': 'ground truth activity'}
if custom_label_names is not None: if custom_label_names is not None:
label_names.update(custom_label_names) label_names.update(custom_label_names)
y_win = self.y_win y_win = self.y_win
......
...@@ -213,6 +213,10 @@ class PseudoLabels: ...@@ -213,6 +213,10 @@ class PseudoLabels:
def exclude_all(self): def exclude_all(self):
self.scope[:] = False self.scope[:] = False
def restore_hw(self):
for area in self.feedback_areas.labeled_regions_hw:
self.y_win[area[0]: area[1]] = self.y_win_base[area[0]: area[1]]
def correct_neuts_to_noise(self): def correct_neuts_to_noise(self):
for area in self.feedback_areas.labeled_regions_neut: for area in self.feedback_areas.labeled_regions_neut:
self.y_win[area[0]: area[1]] = (1, 0) self.y_win[area[0]: area[1]] = (1, 0)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment