ESP  0.1
The Example-based Sensor Predictions (ESP) system tries to bring machine learning to the maker community.
training-data-manager.h
Go to the documentation of this file.
1 
10 #pragma once
11 
12 #include <tuple>
13 #include <vector>
14 
15 #include <GRT/GRT.h>
16 
17 using std::vector;
18 
39  public:
40  // Constructor
41  TrainingDataManager(uint32_t num_classes);
42 
43  // Set the dimension of the training data
44  bool setNumDimensions(uint32_t dim);
45 
46  // Set the name of the training data
47  bool setDatasetName(const std::string name);
48 
49  // Set the name of the training data (C-string overload)
50  bool setDatasetName(const char* const name);
51 
52  GRT::TimeSeriesClassificationData getAllData() { return data_; }  // Returns data_ by value, i.e. the caller receives a copy of the underlying data set.
53 
54  uint32_t getNumLabels() { return num_classes_; }  // Returns the stored class count (num_classes_).
55 
56  uint32_t getTotalNumSamples() { return data_.getNumSamples(); }  // Returns the sample count reported by data_.getNumSamples().
57 
58  // =================================================
59  // Functions that enable per-sample naming
60  // =================================================
61 
64  bool setNameForLabel(const std::string name, uint32_t label);
65  std::string getLabelName(uint32_t label);
66 
70  std::string getSampleName(uint32_t label, uint32_t index);
71  bool setSampleName(uint32_t, uint32_t, const std::string);
72 
73  // =================================================
74  // Functions that simplify editing
75  // =================================================
76 
79  bool addSample(uint32_t label, const GRT::MatrixDouble& sample);
80 
81  uint32_t getNumSampleForLabel(uint32_t label);
82 
84  GRT::MatrixDouble getSample(uint32_t label, uint32_t index);
85 
87  bool deleteSample(uint32_t label, uint32_t index);
88 
90  bool deleteAllSamples();
91 
93  bool deleteAllSamplesWithLabel(uint32_t label);
94 
96  bool relabelSample(uint32_t label, uint32_t index, uint32_t new_label);
97 
99  bool trimSample(uint32_t label, uint32_t index, uint32_t start,
100  uint32_t end);
101 
102  // =================================================
103  // Functions that manage per-sample scores
104  // =================================================
105 
106  bool hasSampleScore(uint32_t label, uint32_t index);
107  double getSampleScore(uint32_t label, uint32_t index);
108  bool setSampleScore(uint32_t label, uint32_t index, double score);
109 
110  bool hasSampleClassLikelihoods(uint32_t label, uint32_t index);
111  vector<double> getSampleClassLikelihoods(uint32_t label, uint32_t index);
112  bool setSampleClassLikelihoods(uint32_t label, uint32_t index,
113  vector<double> likelihoods);
114 
115  // =================================================
116  // Functions for saving/loading training data
117  // =================================================
118 
119  inline bool save(const std::string& filename) {  // Saves the training data to `filename`.
120  return data_.save(filename);  // Forwards to data_.save and returns its result unchanged.
121  }
122 
123  bool load(const std::string& filename);
124 
125  private:
126  uint32_t num_classes_;
127 
128  // Name simulates Option<std::string> type. If `Name.first` is true, then
129  // the name is valid; else use the default name.
130  using Name = std::pair<bool, std::string>;
131  vector<vector<Name>> training_sample_names_;
132  vector<std::string> default_label_names_;
133 
134  // Score simulates Option<double> type. If `Score.first` is true, then
135  // the score is valid; else use the default score.
136  using Score = std::pair<bool, double>;
137  vector<vector<Score>> training_sample_scores_;
138 
139  // Probability that the sample belongs to each class (e.g. when the model
140  // is trained on all other samples). If ClassLikelihoods.first is true,
141  // then vector is valid.
142  using ClassLikelihoods = std::pair<bool, vector<double>>;
143  vector<vector<ClassLikelihoods>> training_sample_class_likelihoods_;
144 
145  // This variable tracks the number of samples for each label. Although we
146  // can get the number with TimeSeriesClassificationData::getClassData and
147  // then getNumSamples, caching the information here helps with bounds checks.
148  //
149  // Callers of this class must not pass a sample index beyond the count
150  // recorded in num_samples_per_label_ for the corresponding label.
151  vector<uint32_t> num_samples_per_label_;
152 
153  // The underlying data store, backed by GRT's TimeSeriesClassificationData
154  GRT::TimeSeriesClassificationData data_;
155 
156  // Disallow copy and assign
158  void operator=(TrainingDataManager) = delete;
159 };
uint32_t getNumLabels()
Definition: training-data-manager.h:54
bool setNameForLabel(const std::string name, uint32_t label)
This will modify the default name for this label, changing it from "Label X" to name.
Definition: training-data-manager.cpp:88
vector< double > getSampleClassLikelihoods(uint32_t label, uint32_t index)
Definition: training-data-manager.cpp:236
bool relabelSample(uint32_t label, uint32_t index, uint32_t new_label)
Relabel a sample from label to new_label.
Definition: training-data-manager.cpp:161
bool setNumDimensions(uint32_t dim)
Definition: training-data-manager.cpp:31
uint32_t getNumSampleForLabel(uint32_t label)
Definition: training-data-manager.cpp:101
uint32_t getTotalNumSamples()
Definition: training-data-manager.h:56
std::string getSampleName(uint32_t label, uint32_t index)
Format the sample name. Default label name is "Label X", and the sample name is "Label X [Y]" Default...
Definition: training-data-manager.cpp:46
bool setSampleClassLikelihoods(uint32_t label, uint32_t index, vector< double > likelihoods)
Definition: training-data-manager.cpp:248
bool setSampleScore(uint32_t label, uint32_t index, double score)
Definition: training-data-manager.cpp:218
std::string getLabelName(uint32_t label)
Definition: training-data-manager.cpp:83
bool hasSampleScore(uint32_t label, uint32_t index)
Definition: training-data-manager.cpp:198
bool setDatasetName(const std::string name)
Definition: training-data-manager.cpp:35
bool hasSampleClassLikelihoods(uint32_t label, uint32_t index)
Definition: training-data-manager.cpp:228
TrainingDataManager class encloses GRT::TimeSeriesClassificationData and improves upon by adding util...
Definition: training-data-manager.h:38
GRT::TimeSeriesClassificationData getAllData()
Definition: training-data-manager.h:52
bool load(const std::string &filename)
Definition: training-data-manager.cpp:258
bool deleteAllSamples()
Remove all samples.
Definition: training-data-manager.cpp:138
bool save(const std::string &filename)
Definition: training-data-manager.h:119
bool deleteAllSamplesWithLabel(uint32_t label)
Remove all samples with the given label.
Definition: training-data-manager.cpp:150
bool setSampleName(uint32_t, uint32_t, const std::string)
Definition: training-data-manager.cpp:58
TrainingDataManager(uint32_t num_classes)
Definition: training-data-manager.cpp:17
double getSampleScore(uint32_t label, uint32_t index)
Definition: training-data-manager.cpp:206
bool addSample(uint32_t label, const GRT::MatrixDouble &sample)
Add new sample. Returns false if the label is larger than configured number of classes.
Definition: training-data-manager.cpp:68
bool trimSample(uint32_t label, uint32_t index, uint32_t start, uint32_t end)
Trim sample. What's left will be [start, end], closed interval.
Definition: training-data-manager.cpp:174
bool deleteSample(uint32_t label, uint32_t index)
Remove sample by label and the index.
Definition: training-data-manager.cpp:108
GRT::MatrixDouble getSample(uint32_t label, uint32_t index)
Get the sample by label and index.
Definition: training-data-manager.cpp:95