18 (
"Row", cv::ml::ROW_SAMPLE)
19 (
"Col", cv::ml::COL_SAMPLE);
23 (
"Numerical", cv::ml::VAR_NUMERICAL)
24 (
"Ordered", cv::ml::VAR_ORDERED)
25 (
"Categorical", cv::ml::VAR_CATEGORICAL)
26 (
"N", cv::ml::VAR_NUMERICAL)
27 (
"O", cv::ml::VAR_ORDERED)
28 (
"C", cv::ml::VAR_CATEGORICAL);
35 const char* fields[] = {
"value",
"classIdx",
"parent",
"left",
"right",
36 "defaultDir",
"split"};
38 for (
size_t i=0; i<nodes.size(); ++i) {
39 s.
set(
"value", nodes[i].value, i);
40 s.
set(
"classIdx", nodes[i].classIdx, i);
41 s.
set(
"parent", nodes[i].parent, i);
42 s.
set(
"left", nodes[i].left, i);
43 s.
set(
"right", nodes[i].right, i);
44 s.
set(
"defaultDir", nodes[i].defaultDir, i);
45 s.
set(
"split", nodes[i].split, i);
52 const char* fields[] = {
"varIdx",
"inversed",
"quality",
"next",
"c",
55 for (
size_t i=0; i<splits.size(); ++i) {
56 s.
set(
"varIdx", splits[i].varIdx, i);
57 s.
set(
"inversed", splits[i].inversed, i);
58 s.
set(
"quality", splits[i].quality, i);
59 s.
set(
"next", splits[i].next, i);
60 s.
set(
"c", splits[i].c, i);
61 s.
set(
"subsetOfs", splits[i].subsetOfs, i);
70 const Mat& samples,
const Mat& responses,
71 vector<MxArray>::const_iterator first,
72 vector<MxArray>::const_iterator last)
74 nargchk((std::distance(first, last) % 2) == 0);
75 int layout = cv::ml::ROW_SAMPLE;
76 Mat varIdx, sampleIdx, sampleWeights, varType;
79 double splitRatio = -1.0;
80 bool splitShuffle =
true;
81 for (; first != last; first += 2) {
82 string key(first->toString());
83 const MxArray& val = *(first + 1);
85 layout = SampleTypesMap[val.toString()];
86 else if (key ==
"VarIdx")
88 (val.isUint8() || val.isLogical()) ? CV_8U : CV_32S);
89 else if (key ==
"SampleIdx")
90 sampleIdx = val.toMat(
91 (val.isUint8() || val.isLogical()) ? CV_8U : CV_32S);
92 else if (key ==
"SampleWeights")
93 sampleWeights = val.toMat(CV_32F);
94 else if (key ==
"VarType") {
96 vector<string> vtypes(val.toVector<
string>());
97 varType.create(1, vtypes.size(), CV_8U);
98 for (
size_t idx = 0; idx < vtypes.size(); idx++)
99 varType.at<uchar>(idx) = VariableTypeMap[vtypes[idx]];
101 else if (val.isChar()) {
102 string str(val.toString());
103 varType.create(1, str.size(), CV_8U);
104 for (
size_t idx = 0; idx < str.size(); idx++)
105 varType.at<uchar>(idx) = VariableTypeMap[string(1,str[idx])];
107 else if (val.isNumeric())
108 varType = val.toMat(CV_8U);
110 mexErrMsgIdAndTxt(
"mexopencv:error",
"Invalid VarType value");
112 else if (key ==
"MissingMask")
113 missing = val.toMat(CV_8U);
114 else if (key ==
"TrainTestSplitCount")
115 splitCount = val.toInt();
116 else if (key ==
"TrainTestSplitRatio")
117 splitRatio = val.toDouble();
118 else if (key ==
"TrainTestSplitShuffle")
119 splitShuffle = val.toBool();
121 mexErrMsgIdAndTxt(
"mexopencv:error",
122 "Unrecognized option %s", key.c_str());
124 Ptr<TrainData> p = TrainData::create(samples, layout, responses,
125 varIdx, sampleIdx, sampleWeights, varType);
127 p->setTrainTestSplit(splitCount, splitShuffle);
128 else if (splitRatio >= 0)
129 p->setTrainTestSplitRatio(splitRatio, splitShuffle);
134 vector<MxArray>::const_iterator first,
135 vector<MxArray>::const_iterator last)
137 nargchk((std::distance(first, last) % 2) == 0);
138 int headerLineCount = 1;
139 int responseStartIdx = -1;
140 int responseEndIdx = -1;
142 char delimiter =
',';
145 double splitRatio = -1.0;
146 bool splitShuffle =
true;
147 for (; first != last; first += 2) {
148 string key(first->toString());
149 const MxArray& val = *(first + 1);
150 if (key ==
"HeaderLineCount")
151 headerLineCount = val.
toInt();
152 else if (key ==
"ResponseStartIdx")
153 responseStartIdx = val.toInt();
154 else if (key ==
"ResponseEndIdx")
155 responseEndIdx = val.toInt();
156 else if (key ==
"VarTypeSpec")
157 varTypeSpec = val.toString();
158 else if (key ==
"Delimiter")
159 delimiter = (!val.isEmpty()) ? val.toString()[0] :
' ';
160 else if (key ==
"Missing")
161 missch = (!val.isEmpty()) ? val.toString()[0] :
'?';
162 else if (key ==
"TrainTestSplitCount")
163 splitCount = val.toInt();
164 else if (key ==
"TrainTestSplitRatio")
165 splitRatio = val.toDouble();
166 else if (key ==
"TrainTestSplitShuffle")
167 splitShuffle = val.toBool();
169 mexErrMsgIdAndTxt(
"mexopencv:error",
170 "Unrecognized option %s", key.c_str());
172 Ptr<TrainData> p = TrainData::loadFromCSV(filename, headerLineCount,
173 responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch);
175 mexErrMsgIdAndTxt(
"mexopencv:error",
176 "Failed to load dataset '%s'", filename.c_str());
178 p->setTrainTestSplit(splitCount, splitShuffle);
179 else if (splitRatio >= 0)
180 p->setTrainTestSplitRatio(splitRatio, splitShuffle);
const ConstMap< string, int > VariableTypeMap
Option values for variable types.
void set(mwIndex index, const T &value)
Template for numeric array element write accessor.
MxArray toStruct(const std::vector< cv::ml::DTrees::Node > &nodes)
Convert tree nodes to struct array.
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for input/ouput arguments number check.
static MxArray Struct(const char **fields=NULL, int nfields=0, mwSize m=1, mwSize n=1)
Create a new struct array.
const ConstMap< string, int > SampleTypesMap
Option values for sample layouts.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
Common definitions for the ml module.