// Object container: KNearest instances created by the "new" method are stored
// here, keyed by an integer id that is handed back to MATLAB.
19 map<int,Ptr<KNearest> > obj_;
// String option -> KNearest algorithm type. The declaration line is not visible
// here; presumably this is KNNAlgType, which the "set" branch indexes by string.
23 (
"BruteForce", KNearest::BRUTE_FORCE)
24 (
"KDTree", KNearest::KDTREE);
// Inverse lookup: KNearest algorithm type -> string option. Presumably this is
// InvKNNAlgType, which the "get" branch indexes with getAlgorithmType().
28 (KNearest::BRUTE_FORCE,
"BruteForce")
29 (KNearest::KDTREE,
"KDTree");
// Main entry called from MATLAB; dispatches on a method name.
//   rhs[0]   -- integer object id (key into obj_)
//   rhs[1]   -- method name string
//   rhs[2..] -- method-specific arguments
// An unrecognized method raises a "mexopencv:error" MATLAB error.
39 void mexFunction(
int nlhs, mxArray *plhs[],
int nrhs,
const mxArray *prhs[])
// Wrap the raw inputs so they can be converted with the MxArray helpers.
45 vector<MxArray> rhs(prhs, prhs+nrhs);
46 int id = rhs[0].toInt();
47 string method(rhs[1].toString());
// Constructor: store a freshly created KNearest under the next id.
50 if (method ==
"new") {
52 obj_[++last_id] = KNearest::create();
// NOTE(review): std::map::operator[] default-inserts an empty Ptr when id is
// unknown, and no obj.empty() check is visible in these lines. Confirm that
// such a check exists on an elided line before obj is dereferenced below.
58 Ptr<KNearest> obj = obj_[id];
59 if (method ==
"delete") {
63 else if (method ==
"clear") {
// load: rhs[2] is a file name, or the serialized model itself when the
// FromString option is true. Requiring an odd nrhs guarantees that every
// option key at rhs[i] (i = 3, 5, ...) has a value at rhs[i+1].
67 else if (method ==
"load") {
68 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs==0);
70 bool loadFromString =
false;
71 for (
int i=3; i<nrhs; i+=2) {
72 string key(rhs[i].toString());
// The key comparison for this assignment is on an elided line
// (presumably an "ObjName" option naming the node to read).
74 objname = rhs[i+1].toString();
75 else if (key ==
"FromString")
76 loadFromString = rhs[i+1].toBool();
78 mexErrMsgIdAndTxt(
"mexopencv:error",
79 "Unrecognized option %s", key.c_str());
// Replace the stored instance with the one read from the string or the file.
81 obj_[id] = (loadFromString ?
82 Algorithm::loadFromString<KNearest>(rhs[2].toString(), objname) :
83 Algorithm::load<KNearest>(rhs[2].toString(), objname));
// save: serializes the model under a node named by getDefaultName(). The
// in-memory FileStorage path returns the serialized text as plhs[0].
// The conditions selecting this path are on elided lines.
85 else if (method ==
"save") {
87 string fname(rhs[2].toString());
90 FileStorage fs(fname, FileStorage::WRITE + FileStorage::MEMORY);
91 fs << obj->getDefaultName() <<
"{";
95 plhs[0] =
MxArray(fs.releaseAndGetString());
// Simple queries forwarded directly to the KNearest instance.
101 else if (method ==
"empty") {
103 plhs[0] =
MxArray(obj->empty());
105 else if (method ==
"getDefaultName") {
107 plhs[0] =
MxArray(obj->getDefaultName());
109 else if (method ==
"getVarCount") {
111 plhs[0] =
MxArray(obj->getVarCount());
113 else if (method ==
"isClassifier") {
115 plhs[0] =
MxArray(obj->isClassifier());
117 else if (method ==
"isTrained") {
119 plhs[0] =
MxArray(obj->isTrained());
// train: rhs[2] holds the samples and rhs[3] the responses. Options start at
// rhs[4]; requiring an even nrhs guarantees each key has a value.
121 else if (method ==
"train") {
122 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=1);
123 vector<MxArray> dataOptions;
125 for (
int i=4; i<nrhs; i+=2) {
126 string key(rhs[i].toString());
// The key comparison for this assignment is on an elided line
// (presumably a "Data" option forwarded to TrainData construction).
128 dataOptions = rhs[i+1].toVector<
MxArray>();
129 else if (key ==
"Flags")
130 flags = rhs[i+1].
toInt();
// UPDATE_FLAG sets or clears the StatModel::UPDATE_MODEL bit in flags,
// depending on the bool value.
131 else if (key ==
"UpdateModel")
132 UPDATE_FLAG(flags, rhs[i+1].toBool(), StatModel::UPDATE_MODEL);
134 mexErrMsgIdAndTxt(
"mexopencv:error",
135 "Unrecognized option %s", key.c_str());
// Tail of a TrainData construction whose head is on an elided line
// (presumably loadTrainData from a file name).
140 dataOptions.begin(), dataOptions.end());
// Samples are converted to CV_32F. Responses stay CV_32S when MATLAB passed
// int32 (e.g. class labels); otherwise they are converted to CV_32F.
143 rhs[2].toMat(CV_32F),
144 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
145 dataOptions.begin(), dataOptions.end());
146 bool b = obj->train(data, flags);
// calcError: same data arguments and option layout as "train". The optional
// second output receives the per-sample responses.
149 else if (method ==
"calcError") {
150 nargchk(nrhs>=4 && (nrhs%2)==0 && nlhs<=2);
151 vector<MxArray> dataOptions;
153 for (
int i=4; i<nrhs; i+=2) {
154 string key(rhs[i].toString());
156 dataOptions = rhs[i+1].toVector<
MxArray>();
157 else if (key ==
"TestError")
160 mexErrMsgIdAndTxt(
"mexopencv:error",
161 "Unrecognized option %s", key.c_str());
166 dataOptions.begin(), dataOptions.end());
169 rhs[2].toMat(CV_32F),
170 rhs[3].toMat(rhs[3].isInt32() ? CV_32S : CV_32F),
171 dataOptions.begin(), dataOptions.end());
// Responses are only computed when the caller asked for a second output.
173 float err = obj->calcError(data, test, (nlhs>1 ? resp : noArray()));
// predict: rhs[2] holds the samples, converted to CV_32F. Options start at
// rhs[3]; the one key handled assigns flags (its name is on an elided line).
178 else if (method ==
"predict") {
179 nargchk(nrhs>=3 && (nrhs%2)==1 && nlhs<=2);
181 for (
int i=3; i<nrhs; i+=2) {
182 string key(rhs[i].toString());
184 flags = rhs[i+1].toInt();
186 mexErrMsgIdAndTxt(
"mexopencv:error",
187 "Unrecognized option %s", key.c_str());
189 Mat samples(rhs[2].toMat(CV_32F)),
191 float f = obj->predict(samples, results, flags);
// findNearest: rhs[2] holds the samples (CV_32F) and rhs[3] the neighbor count k.
// Outputs include the neighbor responses as plhs[1]. No argument-count check is
// visible here; one may exist on an elided line.
196 else if (method ==
"findNearest") {
198 Mat samples(rhs[2].toMat(CV_32F));
199 int k = rhs[3].toInt();
200 Mat results, neighborResponses, dist;
201 float f = obj->findNearest(samples, k, results, neighborResponses, dist);
204 plhs[1] =
MxArray(neighborResponses);
// get: returns the named property. AlgorithmType is mapped back to its string name.
210 else if (method ==
"get") {
212 string prop(rhs[2].toString());
213 if (prop ==
"AlgorithmType")
214 plhs[0] =
MxArray(InvKNNAlgType[obj->getAlgorithmType()]);
215 else if (prop ==
"DefaultK")
216 plhs[0] =
MxArray(obj->getDefaultK());
217 else if (prop ==
"Emax")
218 plhs[0] =
MxArray(obj->getEmax());
219 else if (prop ==
"IsClassifier")
220 plhs[0] =
MxArray(obj->getIsClassifier());
222 mexErrMsgIdAndTxt(
"mexopencv:error",
223 "Unrecognized property %s", prop.c_str());
// set: assigns the named property from rhs[3]. AlgorithmType is given as a string name.
225 else if (method ==
"set") {
227 string prop(rhs[2].toString());
228 if (prop ==
"AlgorithmType")
229 obj->setAlgorithmType(KNNAlgType[rhs[3].toString()]);
230 else if (prop ==
"DefaultK")
231 obj->setDefaultK(rhs[3].toInt());
232 else if (prop ==
"Emax")
233 obj->setEmax(rhs[3].toInt());
234 else if (prop ==
"IsClassifier")
235 obj->setIsClassifier(rhs[3].toBool());
237 mexErrMsgIdAndTxt(
"mexopencv:error",
238 "Unrecognized property %s", prop.c_str());
// Fallthrough: the method name matched none of the branches above.
241 mexErrMsgIdAndTxt(
"mexopencv:error",
"Unrecognized operation");
#define UPDATE_FLAG(NUM, TF, BIT)
Set or clear a bit in a flag, depending on a bool value.
int toInt() const
Convert MxArray to int.
cv::Ptr< cv::ml::TrainData > loadTrainData(const std::string &filename, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Read a dataset from a CSV file.
mxArray object wrapper for data conversion and manipulation.
void nargchk(bool cond)
Alias for checking the number of input/output arguments.
bool toBool() const
Convert MxArray to bool.
Global constant definitions.
cv::Ptr< cv::ml::TrainData > createTrainData(const cv::Mat &samples, const cv::Mat &responses, std::vector< MxArray >::const_iterator first, std::vector< MxArray >::const_iterator last)
Create an instance of TrainData using options in arguments.
std::map wrapper with one-line initialization and lookup method.
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Main entry called from Matlab.
Common definitions for the ml module.