Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
Main Page
Related Pages
Modules
Namespaces
Classes
Files
File List
File Members
obj-x86_64-linux-gnu
examples
Supervised
VersatileClassificationTutorial-NN.cpp
Go to the documentation of this file.
1
2
#include <cstdlib>
#include <exception>
#include <iostream>
#include <string>

#include <shark/Data/Dataset.h>
#include <shark/Data/Csv.h>
#include <shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>

#include <shark/Models/Trees/KDTree.h>
#include <shark/Models/NearestNeighborClassifier.h>
#include <shark/Algorithms/NearestNeighbors/TreeNearestNeighbors.h>

using namespace shark;
12
13
int
main
()
14
{
15
// Load data, use 70% for training and 30% for testing.
16
// The path is hard coded; make sure to invoke the executable
17
// from a place where the data file can be found. It is located
18
// under [shark]/examples/Supervised/data.
19
ClassificationDataset
traindata, testdata;
20
importCSV
(traindata,
"data/quickstartData.csv"
,
LAST_COLUMN
,
' '
);
21
testdata =
splitAtElement
(traindata, 70 * traindata.
numberOfElements
() / 100);
22
23
unsigned
int
k = 3;
// number of neighbors
24
KDTree<RealVector>
tree(traindata.
inputs
());
25
TreeNearestNeighbors<RealVector, unsigned int>
algorithm(traindata, &tree);
26
NearestNeighborClassifier<RealVector>
model(&algorithm, k);
27
28
Data<unsigned int>
prediction = model(testdata.
inputs
());
29
30
ZeroOneLoss<unsigned int>
loss;
31
double
error_rate = loss(testdata.
labels
(), prediction);
32
33
std::cout <<
"model: "
<< model.
name
() << std::endl
34
<<
"test error rate: "
<< error_rate << std::endl;
35
}