Published on

회고록-알고리즘-주사위던지기

Authors
  • avatar
    Name
    Jihwan Seong
    Twitter

배경

본론

문제 분석

  • 문제가 요구하는것은 "가장 승률이 높은 주사위 조합"입니다.
  • A와 B가 복수의 주사위를 던졌을 때의 합의 모든 경우의 수를 구해야합니다.
    • 가장 승률이 높은 주사위 조합을 확인하려면, 모든 경우의 수를 확인하여 승리 횟수를 알아내는 방법 밖에 없습니다.
    • 어떻게 하면, 효율적으로 주사위 조합에 따른, A와 B의 승률을 구할 수 있을까요?
  • 주사위의 개수 N이 변화함에 따라, 주사위 던지는 결과의 경우의 수가 달라집니다.
    • 주사위 개수가 4개면, 4C2 _ 6^2 _ 6^2
    • 주사위 개수가 6개면, 6C3 _ 6^3 _ 6^3
    • 주사위 개수가 2a개면, 2aCa _ 6^a _ 6^a
    • 어떻게하면, N이 변화함에도 모든 주사위 조합을 확인할 수 있을 까요?

사유하기

  • 제가 알고 있는 알고리즘 중에서, 문제를 한번에 해결해줄 도구를 찾지는 못했습니다.

    • 여기서 말하는 알고리즘 기본적인 알고리즘인 DFS,BFS, DP, 완전 탐색 등입니다.
  • 가장 먼저 해결해야할 것은 모든 N에 대해서 경우의 수를 어떻게 표현할 것인가 였습니다.

    • 주어진 테스트케이스를 손으로 스케칭하며 문제 풀이방법을 고민하였습니다.
  • 맨처음 고안 한 방법은, 배열 A와 B를 만들고, 각각의 주사위에 매칭되는 숫자를 각각의 배열에 삽입하는 방법이었습니다.

    • N= 4이면, 총 6가지입니다.
      • A = [1,2] ,B= [3,4]
      • A = [2,3] ,B= [1,4]
      • A = [3,4] ,B= [1,2]
      • ...
    • 이 방법의 문제는 어떻게 A와 B배열에 각 주사위 넘버를 삽입할 순서를 정할 것인가 였습니다.
    • 어떻게든 할수는 있을것 같았지만, 제가 선호하는 쉬운 방법이아니라서 일단 보류하였습니다.
      • 나중에 다른 분들 코드를 보니, 재귀함수를 사용하여 완전 탐색을 하거나, 순열 또는 조합 STD 라이브러리를 사용하여 경우의 수를 구현하셨더군요..... 그러한 알고리즘이 있다는 것을 완전히 잊었습니다...
  • 그다음에 고안한 방법은 각 경우의 수를 비트맵으로 표현하는 것이었습니다.

    • N= 4이면, 각 경우의 수는 1100 , 1010 , 0110, 1001, 0101,0011 로 표현됩니다.
      • 각 비트 자리는 주사위를 나타냅니다. i번째 비트는 주사위 i+1을 나타냅니다.
      • i번째 비트 값이 1이면, i+1번째 주사위는 A의 것이고 값이 0이면 B의 것임을 나타냅니다.
    • 해당 방법은 배열 방식보다 상대적인 장점이 2개 있습니다.
      1. 디버깅이 편하다.
        • 각 경우의수를 숫자 변수 하나로 표현하므로, 출력하여 상황을 확인하기 쉽습니다.
      2. 구현이 간단하다.
        • 숫자 변수를 1씩 증가시키며, 반복문을 사용하면 쉽게 구현할 수 있습니다.
          • 주사위 개수 N이 변화하여도, 숫자 변수의 범위를 쉽게 구할 수 있습니다.
          • N=4 이면, 숫자 변수 범위는 3(b1100)~12(b0011) 입니다.
          • N=6 이면, 숫자 변수 범위는 7(b111000)~56(b000111) 입니다.
          • N<11이므로, 4byte(32bit) 정수형 타입인 int로 모든 경우의 수를 표현할 수 있습니다.
        • 숫자 변수의 비트값이 N/2인 경우에만 다음 작업을 진행하면 됩니다.
        • 숫자 변수를 통해서 A와 B에게 주어진 주사위 조합을 쉽게 알 수 있습니다.
    • 코드는 다음과 같습니다.
    int pow(int num,int base){
        if(base == 0){
            return 1;
        }
        return num * pow(num,base-1);
    }
    
    
    vector<int> solution(vector<vector<int>> dice) {
        int n = dice.size();
        int start = 0;
        int end =  0;
    
        //initalize
        for(int i = 0; i < n/2; i++){
            start+=pow(2,i);
        }
        for(int i = n/2; i <n; i++ ){
            end+=pow(2,i);
        }
    
        while(start <= end){
            int bitCnt = 0;
            for(int i = 0; i < n; i++){
                int temp = start & (1<<i) ;
                if(temp > 0) {
                    bitCnt++;
                }
            }
    
            if(bitCnt != n/2) {
                start++;
                continue;
            }
    
            //next task
    
            //final
            start++;
        }
    }
    
  • 다음에 해결할 것은 주어진 주사위 조합에 대해서 모든 승부 승률을 계산하는 것 입니다.

  • 맨처음에 고안한 것은 단순하게 모든 경우의수를 계산하는 것이었는데요.

    • 주사위 개수 <=10 이므로, 주사위 조합 1가지에 승부의 경우의 수 최대값은 6^5 _ 6^5 = (A가 5개 주사위 던질 때 합 경우의수) _ (B가 5개 주사위 던질때 합 경우의 수)입니다.
    • 어림 잡아 계산하면, 6^10 = 2^10 _ 3^10 ~= 1000 _ 3^10 ~= 1000* 80^2 * 9
    • 주사위 개수가 10일 때, 주사위 조합은 10C2 = 45 이므로, 최대 (1000*80^2 *9 * 45)번 반복해야할 작업이 생깁니다.
    • 매우 큰 수는 아니지만, 적당히 큰수 이므로 더 좋은 방법을 고려할 필요가 있습니다.
  • 다음에 고안한 방법은 정렬을 사용한 방법입니다.

    • 중요한 것은 주어진 주사위 조합에 대해 플레이어 A와 B 각각에 대한 모든 주사위의 합은 구해야한다는 것입니다.
      • 그래야 A의 주사위 합 경우의 수 1개가 B의 주사위 합 경우의 수를 몇 개를 승리할지 계산할 수 있기 때문입니다.
    • 주어진 주사위 조합에 대해 A와 B의 모든 주사위 합을 오름차순으로 정렬하면, 손쉽게 승리할 경우의 수를 계산할 수 있습니다.
      • 정렬된 상태에서 A[i]가 i번째 주사위 합이고, B[j]가 j번째 주사위 합이라면,
        • if( A[i] > B[j])
          • A[i]로는 최소 j 승리할 수 있습니다.
        • if( A[i] <= B[j])
          • A[i]로는 더이상 승리할 수 없습니다.
        • 이러한 방법으로 모든 b[j]를 탐색하지 않고 A[i]로 승리할 수 있는 횟수를 계산할 수 있습니다.
    • 주사위 개수가 10일때, 주사위 조합 1개에 A와 B의 주사위 합을 구하는 경우의 수는 6^5+6^5 = 2*6^5 입니다.
      • 무조건 해야하는 작업입니다.
    • 모든 i에 대하여 주사위 합 A[i]의 승리 횟수를 구하는 방법은 이제 알고리즘에 따라 달라지게됩니다.
      • 단순하게 모든 경우의 수를 계산하면 고정된 성능이 보장되지만, 이제는 알고리즘에 따라 성능이 달라지게됩니다.
  • 2가지 문제에 대한 실마리를 잡았습니다.

    1. 어떻게 하면, 효율적으로 주사위 조합에 따른, A와 B의 승률을 구할 수 있을까요?
    • 주어진 주사위 조합에 따른, A와 B의 모든 주사위 합 경우의 수를 구한 뒤, 정렬하고 승리 횟수를 계산한다.
    1. 어떻게하면, N이 변화함에도 모든 주사위 조합을 확인할 수 있을 까요?
    • 모든 주사위 조합을 비트맵으로 표현한다.
  • 이제 문제를 풀어봅시다.

문제 풀이

첫번째코드

  • 첫번째 코드는 제가 처음 문제 정답을 맞췄을 때 코드입니다.
    • 주석 처리된 입출력 로직은 디버깅 흔적입니다.
  • 저는 실마리를 잡은뒤, 코드를 구현하였는데요. 이때 여러가지 실수가 있었습니다.
    1. stupid mistake
      • 배열 인덱스 변수를 잘못 전달함.
        • A[i] +=B[i] => A[i] += B[j] 등의 실수
      • 수식을 잘못 입력함
        • A+= A+B => A+=B 등의 실수
      • 연산자 우선순위를 고려하지 않음
        • if(bitCnt & (1<<i) > 0) => if((bitCnt & (1<<i)) > 0) 실수
          • &보다 > 연산잦 순위가 더 높으므로, 원치 않은 결과가 발생합니다.
    2. critical area
      • 이번에는 설계 실수가 없었습니다.
      • 모든 경우의 수를 고려하였고, 문제를 일반화하여 풀이를 만들었습니다.
  • 문제를 푸는데 약 3시간 정도 걸렸습니다. 원인은 다음과 같습니다.
    1. 디버깅에 약 1시간 소요됨.
      • stupid mistake를 잡는데 약 1시간 걸렸습니다.
      • stupid mistake를 어떻게 해결할 수 있을지 고민이 필요합니다.
    2. 설계 시간 약 1시간 소요됨.
      • 제가 가장 고민했던 문제는 어떻게하면, N이 변화함에도 모든 주사위 조합을 확인할 수 있을 까요? 였습니다.
        • 문제 풀이 당시, N에 변화에 따라 forwhile의 중첩 개수도 변화되는 것 때문에 반복문 말고 다른 방법도 같이 고민했었는데요.
        • 그때 당시에는 어떻게도 순열과 조합을 쉽게 표현할 방법을 생각하지 못했었습니다...
          • 순열과 조합 알고리즘 라이브러리를 완전히 잊어버린 것도 문제였습니다.
          • 재귀함수 혹은 맵 자료구조를 사용하는 방법도 생각해봤는데, 로직을 쉽게 표현할 방법을 찾지 못했었습니다.
      • 문제의 실마리는 테스트 케이스를 손 스케칭 노가다하면서 얻을 수 있었습니다.
        • 손으로 조합을 써내려가다 보니, 예전에 풀었던 완전 탐색 알고리즘 풀이가 스쳐갔습니다.
        • 그때 당시 저는 재귀함수 완전 탐색 알고리즘을 사용하였는데, 어느 고수분께서 비트맵으로 풀었던 기억이 있었습니다.
        • 이진적으로 상황을 표현할 수 있는 현재 상황에 비트맵이 적합하다 생각되서 비트맵을 적용하여 설계하였습니다.
#include <string>
#include <vector>
#include <algorithm>
#include <iostream>
using namespace std;


int pow(int num,int base){
    if(base == 0){
        return 1;
    }
    return num * pow(num,base-1);
}


vector<int> solution(vector<vector<int>> dice) {
    vector<int> answer;
    bitset<10> ownDice;

    int n = dice.size();
    int start = 0;
    int end = 0;

    int maxWin = 0;
    int caseNumber = 0;

    //initalize
    for(int i = 0; i < n/2; i++){
        start+=pow(2,i);
    }
    for(int i = n/2; i <n; i++ ){
        end+=pow(2,i);
    }
    //cout << "##start ,end " << start << " , " << end << endl;
    //cout << "###Debug : " << (start & (1<<1) )<< endl;
    //calculate case
    int maxSize = 10000;
    while(start<=end){
        int bitCnt = 0;
        for(int i = 0; i < n; i++){
            int temp = start & (1<<i) ;
            if(temp > 0) {
                bitCnt++;
            }
        }
        //cout << "start , bitCnt " << start << " , " << bitCnt << endl;
        if(bitCnt != n/2){
            start++;
            continue;
        }
        //cout << "###Debug  run" << endl;
        vector<int> a; a.reserve(10000);
        vector<int> b; b.reserve(10000);

        for(int i =0; i<n; i++){
            int temp = (start & (1<<i) );
            bool isOwnerA = temp > 0 ? true : false;

            vector<int>* pTarget = isOwnerA ? &a : &b;

            if(pTarget->size() == 0 ){
                for(int j = 0; j < 6; j++ ){
                    pTarget->push_back(dice[i][j]);
                }
                continue;
            }
            //insert befor sum
            int currentSize = pTarget->size();
            for(int k = 0; k < 6-1; k++){
                for(int m =0; m < currentSize; m++ ){
                    int v = (*pTarget)[m];
                    pTarget->push_back(v);
                }
            }
            //sum
            for(int k=0; k<6; k++){
                for(int m= 0; m < currentSize; m++){
                    int index = m+currentSize*k;
                    (*pTarget)[index]+=dice[i][k];
                }
            }
        }
        sort(a.begin(),a.end());
        sort(b.begin(),b.end());
        int win = 0;

        if(dice.size()==4 && start ==3){
            //cout << "[a] :";
            for(int temp : a){
                //cout << temp << " ";
            }
            //cout << endl;

            //cout << "[b] :";
             for(int temp : b){
                //cout << temp << " ";
            }
            cout << endl;
        }
        //statisic
        int lastWinIndex = 0;
        for(int i = 0; i <a.size(); i++){
            int j = lastWinIndex;
            for(j; j < b.size(); j++){
                if(a[i] <=b[j]) {
                    win+=j;
                    lastWinIndex = j;
                    break;
                }
            }
            if(j == b.size() && b[j] < a[i]){
                win+= (a.size() -i)*b.size();
                break;
            }
        }
        //cout << "Start : " << start << " , win : " << win << endl;
        if(maxWin < win){
            maxWin = win;
            caseNumber = start;
            //cout << "maxWin updated" << endl;
        }

        start++;
    }

    for(int i = 0; i < n; i++){
        int temp = caseNumber & (1 << i);
        if( temp > 0){
            answer.push_back(i+1);
        }
    }



    return answer;
}