상세 컨텐츠

본문 제목

#OpenCV 4로 배우는 컴퓨터 비전과 머신 러닝 - 22

Programing/OpenCV

by CouqueD'asse 2022. 9. 2. 16:28

본문

15.2 k 최근접 이웃

 

15.2.1 k 최근접 이웃 알고리즘

kNN 알고리즘 : 분류 또는 회귀에 사용되는 지도 학습 알고리즘의 하나

분류 - 특징 공간에서 테스트 데이터와 가장 가까운 k개의 훈련 데이터를 찾고, k개의 훈련 데이터 중에서 가장 많은 클래스를 테스트 데이터의 클래스로 지정

회귀 - 테스트 데이터에 인접한 k개의 훈련 데이터 평균을 테스트 데이터 값으로 설정

k를 1로 설정하면 최근접 이웃 알고리즘이 됨. 그러므로 보통 k는 1보다 큰 값으로 설정하며, k 값을 어떻게 설정하느냐에 따라 분류 및 회귀 결과가 달라질 수 있음.

최선의 k 값을 결정하는 것은 주어진 데이터에 의존적이며, 보통 k 값이 커질수록 잡음 또는 이상치 데이터의 영향이 감소함. 그러나 k 값이 어느 정도 이상으로 커질 경우 오히려 분류 및 회귀 성능이 떨어질 수 있음.

 

15.2.2 KNearest 클래스 사용하기

KNearest 클래스 : ml모듈에 포함되어 있음.

KNearest::create() : 단순히 비어있는 KNearest 객체를 생성하여 Ptr<KNearest> 타입으로 변환

static Ptr<KNearest> KNearest::create();

기본적으로 k 값을 10으로 설정. 이 값을 변경하려면 KNearest::setDefaultK() 사용

virtual void KNearest::setDefaultK(int val);

KNearest 객체는 기본적으로 분류를 위한 용도로 생성됨. 회귀에 적용하려면 KNearest::setIsClassifier() 멤버 함수에 false 를 지정하여 호출해야 함

virtual void KNearest::setIsClassifier(bool val);
// val : true이면 분류, false이면 회귀

KNearest 클래스에서 훈련 데이터를 학습한 후, 테스트 데이터에 대한 예측을 수행할 때에는 주로 KNearest::findNearest() 사용

예측 결과와 관련된 정보를 더 많이 반환하기 때문에 유용

virtual float KNearest::findNearest(InputArray samples, int k, OutputArray results,
	OutputArray neighborResponses = noArray(), OutputArray dist = noArray()) const;

samples : 행렬의 행 개수는 예측할 테스트 데이터 개수와 같고, 열 개수는 학습 시 사용한 훈련 데이터의 차원과 같아야 함

results : 분류 및 회귀 결과가 저장되는 행렬. samples 행렬과 같은 행 개수를 가지고, 열 개수는 항상 1. 즉, samples 행렬에서 i번째 행에 대한 응답이 results 행렬의 i번째 행에 저장됨

// kNN 알고리즘을 이용한 2차원 점 분류
Mat img;
Mat train, label;
Ptr<KNearest> knn;
int k_value = 1;

void on_k_changed(int, void*);
void addPoint(const Point& pt, int cls);
void trainAndDisplay();

int main(void)
{
	img = Mat::zeros(Size(500, 500), CV_8UC3);
	knn = KNearest::create();

	namedWindow("knn");

	const int NUM = 30;
	Mat rn(NUM, 2, CV_32SC1);

	randn(rn, 0, 50);
	for (int i = 0; i < NUM; i++)
	{
		addPoint(Point(rn.at<int>(i, 0) + 150, rn.at<int>(i, 1) + 150), 0);
	}

	randn(rn, 0, 50);
	for (int i = 0; i < NUM; i++)
	{
		addPoint(Point(rn.at<int>(i, 0) + 350, rn.at<int>(i, 1) + 150), 1);
	}

	randn(rn, 0, 70);
	for (int i = 0; i < NUM; i++)
	{
		addPoint(Point(rn.at<int>(i, 0) + 250, rn.at<int>(i, 1) + 400), 2);
	}

	trainAndDisplay();
	
	createTrackbar("k", "knn", &k_value, 5, on_k_changed);

	waitKey();
	return 0;
}

void on_k_changed(int, void*)
{
	if (k_value < 1) k_value = 1;
	trainAndDisplay();
}

void addPoint(const Point& pt, int cls)
{
	Mat new_sample = (Mat_<float>(1, 2) << pt.x, pt.y);
	train.push_back(new_sample);

	Mat new_label = (Mat_<int>(1, 1) << cls);
	label.push_back(new_label);
}

void trainAndDisplay()
{
	knn->train(train, ROW_SAMPLE, label);

	for (int i = 0; i < img.rows; ++i)
	{
		for (int j = 0; j < img.cols; ++j)
		{
			Mat sample = (Mat_<float>(1, 2) << j, i);

			Mat res;
			knn->findNearest(sample, k_value, res);

			int response = cvRound(res.at<float>(0, 0));
			if (response == 0)
				img.at<Vec3b>(i, j) = Vec3b(128, 128, 255); //R
			else if(response == 1)
				img.at<Vec3b>(i, j) = Vec3b(128, 255, 128); // G
			else if (response == 2)
				img.at<Vec3b>(i, j) = Vec3b(255, 128, 128); // B
		}
	}

	for (int i = 0; i < train.rows; i++)
	{
		int x = cvRound(train.at<float>(i, 0));
		int y = cvRound(train.at<float>(i, 1));
		int l = label.at<int>(i, 0);

		if (l == 0)
			circle(img, Point(x, y), 5, Scalar(0, 0, 128), -1, LINE_AA);
		else if (l == 1)
			circle(img, Point(x, y), 5, Scalar(0, 128, 0), -1, LINE_AA);
		else if (l == 2)
			circle(img, Point(x, y), 5, Scalar(128, 0, 0), -1, LINE_AA);
	}

	imshow("knn", img);
}

 

15.2.3 kNN을 이용한 필기체 숫자 인식

머신 러닝으로 특정 문제를 해결하려면 많은 양의 훈련 데이터가 필요.

OPENCV-SRC\samples\data\digits.png 0부터 9까지의 필기체 숫자가 5000개 적혀 있는 영상

위 영상을 KNearest 클래스를 이용하여 필기체 숫자 인식 프로그램을 구현

// KNeartest 클래스를 이용한 필기체 숫자 인식
int main(void)
{
	Ptr<KNearest> knn = train_knn();

	if (knn.empty())
	{
		return -1;
	}

	Mat img = Mat::zeros(400, 400, CV_8U);

	imshow("img", img);
	setMouseCallback("img", on_mouse, (void*)&img);

	while (true)
	{
		int c = waitKey();

		if (c == 27)
		{
			break;
		}
		else if (c == ' ')
		{
			Mat img_resize, img_float, img_flatten, res;

			resize(img, img_resize, Size(20, 20), 0, 0, INTER_AREA);
			img_resize.convertTo(img_float, CV_32F);
			img_flatten = img_float.reshape(1, 1);

			knn->findNearest(img_flatten, 3, res);
			cout << cvRound(res.at<float>(0, 0)) << endl;

			img.setTo(0);
			imshow("img", img);

		}
	}


	return 0;
}

Ptr<KNearest> train_knn()
{
	Mat digits = imread("digits.png", IMREAD_GRAYSCALE);

	if (digits.empty())
	{
		return 0;
	}

	Mat train_images, train_labels;

	for (int j = 0; j < 50 ; j++)
	{
		for (int i = 0; i < 100; i++)
		{
			Mat roi, roi_float, roi_flatten;
			roi = digits(Rect(i * 20, j * 20, 20, 20));
			roi.convertTo(roi_float, CV_32F);
			roi_flatten = roi_float.reshape(1, 1);

			train_images.push_back(roi_flatten);
			train_labels.push_back(j / 5);
		}
	}

	Ptr<KNearest> knn = KNearest::create();
	knn->train(train_images, ROW_SAMPLE, train_labels);

	return knn;
}

Point ptPrev(-1, -1);

void on_mouse(int event, int x, int y, int flags, void* userdata)
{
	Mat img = *(Mat*)userdata;

	if (event == EVENT_LBUTTONDOWN)
	{
		ptPrev = Point(x, y);
	}
	else if (event == EVENT_LBUTTONUP)
	{
		ptPrev = Point(-1, -1);
	}
	else if (event == EVENT_MOUSEMOVE && (flags & EVENT_FLAG_LBUTTON))
	{
		line(img, ptPrev, Point(x, y), Scalar::all(255), 40, LINE_AA, 0);
		ptPrev = Point(x, y);

		imshow("img", img);
	}
}

관련글 더보기