15.2 k 최근접 이웃
15.2.1 k 최근접 이웃 알고리즘
kNN 알고리즘 : 분류 또는 회귀에 사용되는 지도 학습 알고리즘의 하나
분류 - 특징 공간에서 테스트 데이터와 가장 가까운 k개의 훈련 데이터를 찾고, k개의 훈련 데이터 중에서 가장 많은 클래스를 테스트 데이터의 클래스로 지정
회귀 - 테스트 데이터에 인접한 k개의 훈련 데이터 평균을 테스트 데이터 값으로 설정
k를 1로 설정하면 최근접 이웃 알고리즘이 됨. 그러므로 보통 k는 1보다 큰 값으로 설정하며, k 값을 어떻게 설정하느냐에 따라 분류 및 회귀 결과가 달라질 수 있음.
최선의 k 값을 결정하는 것은 주어진 데이터에 의존적이며, 보통 k 값이 커질수록 잡음 또는 이상치 데이터의 영향이 감소함. 그러나 k 값이 어느 정도 이상으로 커질 경우 오히려 분류 및 회귀 성능이 떨어질 수 있음.
15.2.2 KNearest 클래스 사용하기
KNearest 클래스 : ml모듈에 포함되어 있음.
KNearest::create() : 단순히 비어있는 KNearest 객체를 생성하여 Ptr<KNearest> 타입으로 변환
static Ptr<KNearest> KNearest::create();
기본적으로 k 값을 10으로 설정. 이 값을 변경하려면 KNearest::setDefaultK() 사용
virtual void KNearest::setDefaultK(int val);
KNearest 객체는 기본적으로 분류를 위한 용도로 생성됨. 회귀에 적용하려면 KNearest::setIsClassifier() 멤버 함수에 false 를 지정하여 호출해야 함
virtual void KNearest::setIsClassifier(bool val);
// val : true이면 분류, false이면 회귀
KNearest 클래스에서 훈련 데이터를 학습한 후, 테스트 데이터에 대한 예측을 수행할 때에는 주로 KNearest::findNearest() 사용
예측 결과와 관련된 정보를 더 많이 반환하기 때문에 유용
virtual float KNearest::findNearest(InputArray samples, int k, OutputArray results,
OutputArray neighborResponses = noArray(), OutputArray dist = noArray()) const;
samples : 행렬의 행 개수는 예측할 테스트 데이터 개수와 같고, 열 개수는 학습 시 사용한 훈련 데이터의 차원과 같아야 함
results : 분류 및 회귀 결과가 저장되는 행렬. samples 행렬과 같은 행 개수를 가지고, 열 개수는 항상 1. 즉, samples 행렬에서 i번째 행에 대한 응답이 results 행렬의 i번째 행에 저장됨
// kNN 알고리즘을 이용한 2차원 점 분류
Mat img;
Mat train, label;
Ptr<KNearest> knn;
int k_value = 1;
void on_k_changed(int, void*);
void addPoint(const Point& pt, int cls);
void trainAndDisplay();
int main(void)
{
img = Mat::zeros(Size(500, 500), CV_8UC3);
knn = KNearest::create();
namedWindow("knn");
const int NUM = 30;
Mat rn(NUM, 2, CV_32SC1);
randn(rn, 0, 50);
for (int i = 0; i < NUM; i++)
{
addPoint(Point(rn.at<int>(i, 0) + 150, rn.at<int>(i, 1) + 150), 0);
}
randn(rn, 0, 50);
for (int i = 0; i < NUM; i++)
{
addPoint(Point(rn.at<int>(i, 0) + 350, rn.at<int>(i, 1) + 150), 1);
}
randn(rn, 0, 70);
for (int i = 0; i < NUM; i++)
{
addPoint(Point(rn.at<int>(i, 0) + 250, rn.at<int>(i, 1) + 400), 2);
}
trainAndDisplay();
createTrackbar("k", "knn", &k_value, 5, on_k_changed);
waitKey();
return 0;
}
void on_k_changed(int, void*)
{
if (k_value < 1) k_value = 1;
trainAndDisplay();
}
void addPoint(const Point& pt, int cls)
{
Mat new_sample = (Mat_<float>(1, 2) << pt.x, pt.y);
train.push_back(new_sample);
Mat new_label = (Mat_<int>(1, 1) << cls);
label.push_back(new_label);
}
void trainAndDisplay()
{
knn->train(train, ROW_SAMPLE, label);
for (int i = 0; i < img.rows; ++i)
{
for (int j = 0; j < img.cols; ++j)
{
Mat sample = (Mat_<float>(1, 2) << j, i);
Mat res;
knn->findNearest(sample, k_value, res);
int response = cvRound(res.at<float>(0, 0));
if (response == 0)
img.at<Vec3b>(i, j) = Vec3b(128, 128, 255); //R
else if(response == 1)
img.at<Vec3b>(i, j) = Vec3b(128, 255, 128); // G
else if (response == 2)
img.at<Vec3b>(i, j) = Vec3b(255, 128, 128); // B
}
}
for (int i = 0; i < train.rows; i++)
{
int x = cvRound(train.at<float>(i, 0));
int y = cvRound(train.at<float>(i, 1));
int l = label.at<int>(i, 0);
if (l == 0)
circle(img, Point(x, y), 5, Scalar(0, 0, 128), -1, LINE_AA);
else if (l == 1)
circle(img, Point(x, y), 5, Scalar(0, 128, 0), -1, LINE_AA);
else if (l == 2)
circle(img, Point(x, y), 5, Scalar(128, 0, 0), -1, LINE_AA);
}
imshow("knn", img);
}
15.2.3 kNN을 이용한 필기체 숫자 인식
머신 러닝으로 특정 문제를 해결하려면 많은 양의 훈련 데이터가 필요.
OPENCV-SRC\samples\data\digits.png 0부터 9까지의 필기체 숫자가 5000개 적혀 있는 영상
위 영상을 KNearest 클래스를 이용하여 필기체 숫자 인식 프로그램을 구현
// KNeartest 클래스를 이용한 필기체 숫자 인식
int main(void)
{
Ptr<KNearest> knn = train_knn();
if (knn.empty())
{
return -1;
}
Mat img = Mat::zeros(400, 400, CV_8U);
imshow("img", img);
setMouseCallback("img", on_mouse, (void*)&img);
while (true)
{
int c = waitKey();
if (c == 27)
{
break;
}
else if (c == ' ')
{
Mat img_resize, img_float, img_flatten, res;
resize(img, img_resize, Size(20, 20), 0, 0, INTER_AREA);
img_resize.convertTo(img_float, CV_32F);
img_flatten = img_float.reshape(1, 1);
knn->findNearest(img_flatten, 3, res);
cout << cvRound(res.at<float>(0, 0)) << endl;
img.setTo(0);
imshow("img", img);
}
}
return 0;
}
Ptr<KNearest> train_knn()
{
Mat digits = imread("digits.png", IMREAD_GRAYSCALE);
if (digits.empty())
{
return 0;
}
Mat train_images, train_labels;
for (int j = 0; j < 50 ; j++)
{
for (int i = 0; i < 100; i++)
{
Mat roi, roi_float, roi_flatten;
roi = digits(Rect(i * 20, j * 20, 20, 20));
roi.convertTo(roi_float, CV_32F);
roi_flatten = roi_float.reshape(1, 1);
train_images.push_back(roi_flatten);
train_labels.push_back(j / 5);
}
}
Ptr<KNearest> knn = KNearest::create();
knn->train(train_images, ROW_SAMPLE, train_labels);
return knn;
}
Point ptPrev(-1, -1);
void on_mouse(int event, int x, int y, int flags, void* userdata)
{
Mat img = *(Mat*)userdata;
if (event == EVENT_LBUTTONDOWN)
{
ptPrev = Point(x, y);
}
else if (event == EVENT_LBUTTONUP)
{
ptPrev = Point(-1, -1);
}
else if (event == EVENT_MOUSEMOVE && (flags & EVENT_FLAG_LBUTTON))
{
line(img, ptPrev, Point(x, y), Scalar::all(255), 40, LINE_AA, 0);
ptPrev = Point(x, y);
imshow("img", img);
}
}
#OpenCV 4로 배우는 컴퓨터 비전과 머신 러닝 - 23 (0) | 2022.09.07 |
---|---|
#OpenCV 4로 배우는 컴퓨터 비전과 머신 러닝 - 21 (0) | 2022.08.19 |
#OpenCV 4로 배우는 컴퓨터 비전과 머신 러닝 - 20 (0) | 2022.08.03 |
#OpenCV 4로 배우는 컴퓨터 비전과 머신 러닝 - 19 (0) | 2022.08.02 |
#OpenCV 4로 배우는 컴퓨터 비전과 머신 러닝 - 18 (0) | 2022.07.27 |