취미/Programming Problems

HackerRank - Merge Sort: Counting Inversions

Lero God 2023. 11. 8. 00:55

머지 소트 구현!

이전부터 공부해보고 싶었던 알고리즘인데 마침 해커랭크 문제가 있어서 구현해봤습니다.

어려워 보일 수도 있으나 원리만 정확히 이해하면 누구나 충분히 구현 할 수 있을 것 같습니다.

소트나 검색 알고리즘에서 항상 트리를 이용하기 때문에 트리 검색이나 소트 할 때 가장 최적화된 방식으로 하는 걸 공부하고 고민하면서 머지 소트에 대한 이해도를 높이는 것도 도움이 될 것 같습니다. 

다음엔 대부분의 라이브러리에서 소트된 키-밸류 컨테이너를 구현할 때 사용하는 레드 블랙 트리를 직접 구현해보고 싶네요 👻

struct MergeArray
{
	vector<int> left;
	vector<int> right;
};

shared_ptr<MergeArray> Split(const vector<int>& arr)
{
	MergeArray mergeArray;

	int size = arr.size();
	int mid = size / 2;

	auto leftArrEndIter = arr.begin() + mid;
	mergeArray.left.insert(mergeArray.left.end(), arr.begin(), leftArrEndIter);
	auto midIter = arr.begin() + mid;
	mergeArray.right.insert(mergeArray.right.end(), midIter, arr.end());

	return make_shared<MergeArray>(mergeArray);
}

vector<int> Merge(shared_ptr<MergeArray>& mergeArray, long& inversionCount)
{
	if (mergeArray->left.size() >= 2)
	{
		auto mergeArr = Split(mergeArray->left);
		mergeArray->left = Merge(mergeArr, inversionCount);
	}

	if (mergeArray->right.size() >= 2)
	{
		auto mergeArr = Split(mergeArray->right);
		mergeArray->right = Merge(mergeArr, inversionCount);
	}

	vector<int> mergedArr;
	auto leftBegin = mergeArray->left.begin();
	auto rightBegin = mergeArray->right.begin();
	int leftLeftCount = mergeArray->left.size();

	while (leftBegin != mergeArray->left.end() || rightBegin != mergeArray->right.end())
	{
		if (leftBegin == mergeArray->left.end())
		{
			while (rightBegin != mergeArray->right.end())
			{
				mergedArr.emplace_back(*rightBegin);
				++rightBegin;
			}
			break;
		}

		if (rightBegin == mergeArray->right.end())
		{
			while (leftBegin != mergeArray->left.end())
			{
				mergedArr.emplace_back(*leftBegin);
				++leftBegin;
			}
			break;
		}

		auto leftVal = *leftBegin;
		auto rightVal = *rightBegin;

		if (leftVal > rightVal)
		{
			mergedArr.emplace_back(rightVal);
			++rightBegin;
			inversionCount += leftLeftCount;
		}
		else
		{
			mergedArr.emplace_back(leftVal);
			++leftBegin;
			--leftLeftCount;
		}
	}

	return mergedArr;
}

long countInversions(vector<int> arr)
{
	if (arr.size() < 2)
	{
		return 0;
	}


	long inversionCount = 0;
	auto mergeArray = Split(arr);
	auto mergedArr = Merge(mergeArray, inversionCount);

	return inversionCount;
}

int main()
{
	string t_temp;
	getline(cin, t_temp);

	int t = stoi(ltrim(rtrim(t_temp)));

	for (int t_itr = 0; t_itr < t; t_itr++) {
		string n_temp;
		getline(cin, n_temp);

		int n = stoi(ltrim(rtrim(n_temp)));

		string arr_temp_temp;
		getline(cin, arr_temp_temp);

		vector<string> arr_temp = split(rtrim(arr_temp_temp));

		vector<int> arr(n);

		for (int i = 0; i < n; i++) {
			int arr_item = stoi(arr_temp[i]);

			arr[i] = arr_item;
		}

		long result = countInversions(arr);

		cout << result << "\n";
	}

	return 0;
}