mexopencv  0.1
mex interface for opencv library
mexopencv_ml.cpp
Go to the documentation of this file.
1 
7 #include "mexopencv_ml.hpp"
8 using std::vector;
9 using std::string;
10 using namespace cv;
11 using namespace cv::ml;
12 
13 
14 // ==================== XXX ====================
15 
18  ("Row", cv::ml::ROW_SAMPLE) // each training sample is a row of samples
19  ("Col", cv::ml::COL_SAMPLE); // each training sample occupies a column of samples
20 
23  ("Numerical", cv::ml::VAR_NUMERICAL) // same as VAR_ORDERED
24  ("Ordered", cv::ml::VAR_ORDERED) // ordered variables
25  ("Categorical", cv::ml::VAR_CATEGORICAL) // categorical variables
26  ("N", cv::ml::VAR_NUMERICAL) // shorthand for (N)umerical
27  ("O", cv::ml::VAR_ORDERED) // shorthand for (O)rdered
28  ("C", cv::ml::VAR_CATEGORICAL); // shorthand for (C)ategorical
29 
30 
31 // ==================== XXX ====================
32 
33 MxArray toStruct(const vector<DTrees::Node>& nodes)
34 {
35  const char* fields[] = {"value", "classIdx", "parent", "left", "right",
36  "defaultDir", "split"};
37  MxArray s = MxArray::Struct(fields, 7, 1, nodes.size());
38  for (size_t i=0; i<nodes.size(); ++i) {
39  s.set("value", nodes[i].value, i);
40  s.set("classIdx", nodes[i].classIdx, i);
41  s.set("parent", nodes[i].parent, i);
42  s.set("left", nodes[i].left, i);
43  s.set("right", nodes[i].right, i);
44  s.set("defaultDir", nodes[i].defaultDir, i);
45  s.set("split", nodes[i].split, i);
46  }
47  return s;
48 }
49 
50 MxArray toStruct(const vector<DTrees::Split>& splits)
51 {
52  const char* fields[] = {"varIdx", "inversed", "quality", "next", "c",
53  "subsetOfs"};
54  MxArray s = MxArray::Struct(fields, 6, 1, splits.size());
55  for (size_t i=0; i<splits.size(); ++i) {
56  s.set("varIdx", splits[i].varIdx, i);
57  s.set("inversed", splits[i].inversed, i);
58  s.set("quality", splits[i].quality, i);
59  s.set("next", splits[i].next, i);
60  s.set("c", splits[i].c, i);
61  s.set("subsetOfs", splits[i].subsetOfs, i);
62  }
63  return s;
64 }
65 
66 
67 // ==================== XXX ====================
68 
69 Ptr<TrainData> createTrainData(
70  const Mat& samples, const Mat& responses,
71  vector<MxArray>::const_iterator first,
72  vector<MxArray>::const_iterator last)
73 {
74  nargchk((std::distance(first, last) % 2) == 0);
75  int layout = cv::ml::ROW_SAMPLE;
76  Mat varIdx, sampleIdx, sampleWeights, varType;
77  Mat missing; //TODO: currently not possible through TrainData interface
78  int splitCount = -1; // [0, nsamples)
79  double splitRatio = -1.0; // [0.0, 1.0)
80  bool splitShuffle = true;
81  for (; first != last; first += 2) {
82  string key(first->toString());
83  const MxArray& val = *(first + 1);
84  if (key == "Layout")
85  layout = SampleTypesMap[val.toString()];
86  else if (key == "VarIdx")
87  varIdx = val.toMat(
88  (val.isUint8() || val.isLogical()) ? CV_8U : CV_32S);
89  else if (key == "SampleIdx")
90  sampleIdx = val.toMat(
91  (val.isUint8() || val.isLogical()) ? CV_8U : CV_32S);
92  else if (key == "SampleWeights")
93  sampleWeights = val.toMat(CV_32F);
94  else if (key == "VarType") {
95  if (val.isCell()) {
96  vector<string> vtypes(val.toVector<string>());
97  varType.create(1, vtypes.size(), CV_8U);
98  for (size_t idx = 0; idx < vtypes.size(); idx++)
99  varType.at<uchar>(idx) = VariableTypeMap[vtypes[idx]];
100  }
101  else if (val.isChar()) {
102  string str(val.toString());
103  varType.create(1, str.size(), CV_8U);
104  for (size_t idx = 0; idx < str.size(); idx++)
105  varType.at<uchar>(idx) = VariableTypeMap[string(1,str[idx])];
106  }
107  else if (val.isNumeric())
108  varType = val.toMat(CV_8U);
109  else
110  mexErrMsgIdAndTxt("mexopencv:error", "Invalid VarType value");
111  }
112  else if (key == "MissingMask")
113  missing = val.toMat(CV_8U); //TODO: unused, see TrainData::setData
114  else if (key == "TrainTestSplitCount")
115  splitCount = val.toInt();
116  else if (key == "TrainTestSplitRatio")
117  splitRatio = val.toDouble();
118  else if (key == "TrainTestSplitShuffle")
119  splitShuffle = val.toBool();
120  else
121  mexErrMsgIdAndTxt("mexopencv:error",
122  "Unrecognized option %s", key.c_str());
123  }
124  Ptr<TrainData> p = TrainData::create(samples, layout, responses,
125  varIdx, sampleIdx, sampleWeights, varType);
126  if (splitCount >= 0)
127  p->setTrainTestSplit(splitCount, splitShuffle);
128  else if (splitRatio >= 0)
129  p->setTrainTestSplitRatio(splitRatio, splitShuffle);
130  return p;
131 }
132 
133 Ptr<TrainData> loadTrainData(const string& filename,
134  vector<MxArray>::const_iterator first,
135  vector<MxArray>::const_iterator last)
136 {
137  nargchk((std::distance(first, last) % 2) == 0);
138  int headerLineCount = 1;
139  int responseStartIdx = -1;
140  int responseEndIdx = -1;
141  string varTypeSpec;
142  char delimiter = ',';
143  char missch = '?';
144  int splitCount = -1; // [0, nsamples)
145  double splitRatio = -1.0; // [0.0, 1.0)
146  bool splitShuffle = true;
147  for (; first != last; first += 2) {
148  string key(first->toString());
149  const MxArray& val = *(first + 1);
150  if (key == "HeaderLineCount")
151  headerLineCount = val.toInt();
152  else if (key == "ResponseStartIdx")
153  responseStartIdx = val.toInt();
154  else if (key == "ResponseEndIdx")
155  responseEndIdx = val.toInt();
156  else if (key == "VarTypeSpec")
157  varTypeSpec = val.toString();
158  else if (key == "Delimiter")
159  delimiter = (!val.isEmpty()) ? val.toString()[0] : ' ';
160  else if (key == "Missing")
161  missch = (!val.isEmpty()) ? val.toString()[0] : '?';
162  else if (key == "TrainTestSplitCount")
163  splitCount = val.toInt();
164  else if (key == "TrainTestSplitRatio")
165  splitRatio = val.toDouble();
166  else if (key == "TrainTestSplitShuffle")
167  splitShuffle = val.toBool();
168  else
169  mexErrMsgIdAndTxt("mexopencv:error",
170  "Unrecognized option %s", key.c_str());
171  }
172  Ptr<TrainData> p = TrainData::loadFromCSV(filename, headerLineCount,
173  responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch);
174  if (p.empty())
175  mexErrMsgIdAndTxt("mexopencv:error",
176  "Failed to load dataset '%s'", filename.c_str());
177  if (splitCount >= 0)
178  p->setTrainTestSplit(splitCount, splitShuffle);
179  else if (splitRatio >= 0)
180  p->setTrainTestSplitRatio(splitRatio, splitShuffle);
181  return p;
182 
183 }
const ConstMap< string, int > VariableTypeMap
Option values for variable types.
void set(mwIndex index, const T &value)
Template for numeric array element write accessor.
Definition: MxArray.hpp:1310
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
int toInt() const
Convert MxArray to int.
Definition: MxArray.cpp:489
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
Definition: MxArray.hpp:123
void nargchk(bool cond)
Alias for input/output arguments number check.
Definition: mexopencv.hpp:166
static MxArray Struct(const char **fields=NULL, int nfields=0, mwSize m=1, mwSize n=1)
Create a new struct array.
Definition: MxArray.hpp:312
const ConstMap< string, int > SampleTypesMap
Option values for sample layouts.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Definition: MxArray.hpp:927
Common definitions for the ml module.