본문 바로가기
PS/BOJ

[자바] 백준 14257 - XOR 방정식 (java)

by Nahwasa 2023. 8. 11.

목차

    문제 : boj14257

     

     

    필요 알고리즘

    • 조합론, 수학
      • 수학적인 직관이 필요한 문제이다.

    ※ 제 코드에서 왜 main 함수에 로직을 직접 작성하지 않았는지, 왜 Scanner를 쓰지 않고 BufferedReader를 사용했는지 등에 대해서는 '자바로 백준 풀 때의 팁 및 주의점' 글을 참고해주세요. 백준을 자바로 풀어보려고 시작하시는 분이나, 백준에서 자바로 풀 때의 팁을 원하시는 분들도 보시는걸 추천드립니다.

     

     

    풀이

      당연히 S와 X를 숫자 그 자체로 생각해보면 딱히 할 수 있는게 없다. 그리고 A+B=S야 그렇다 쳐도, A^B=X 는 어차피 비트단위로 연산되는 녀석이라, 비트단위로 봐야 뭔가 보일꺼라 생각했다. 그래서 비트로 변경해두고 보다보니 풀이가 보여서 풀게 되었다.

     

      일단 당연하게도 A^B의 결과인 X를 bit로 바꿨을 때 '1' 부분만 A와 B를 서로 다르게 구성할 수 있다. 따라서 이 문제의 답은 단순하게 X를 비트로(=2진수로) 바꿨을 때 '1'인 부분이 총 Y개였다면, 2의Y승이 답이다.

     

      근데 문제는 불가능한 경우를 찾아야 한다는 점이다. 그래서 해본게, 2진수로 X와 S를 변경해두고, A와 B의 비트가 어떤식으로 구성되어야 가능할지를 생각해봤다.

     

      예를들어 입력이 '9 5' 였다면 S=9, X=5 이다. 이걸 아래처럼 비트로 변경하고 1의 자리부터 높은 자리쪽으로 차례대로 살펴볼 때 어떤 경우가 가능하고 어떤 경우가 불가능한지 살펴본거다.

     

      이하 표에서 X는 현재 보고 있는 X의 비트값이다. S는 현재 보고 있는 S의 비트값이다. Cin은 이전 자리수로부터 carry가 0인지 1인지를 뜻한다. 이 3가지를 기준으로 봤으므로 총 8가지 경우가 있다. P는 가능 여부이다. O면 해당 X,S,Cin에 맞춰 A와 B의 해당 자리수를 만들 수 있음을 나타낸다. Cout은 해당 경우 다음 자리수로 carry가 넘어가는지를 나타낸다. 0||1 이면 넘길수도 있고 안넘길수도 있음을 나타낸다.

     

      그럼 이걸 코드로 나타내보면 아래와 같다. c는 Cin을 나타내고, return값이 Cout과 P를 합친 값인데, P가 'X'면 -1을 리턴하고 아니라면 Cout을 리턴한다. 이 때 0||1은 2로 리턴하기로 했다.

    private int chk(final int x, final int s, final int c) {
        if (x==0 && s==0) return c==1?-1:2;
        if (x==1 && s==0) return c==0?-1:1;
        if (x==1 && s==1) return c==1?-1:0;
        return c==0?-1:2;
    }

     

      그럼 이제 해줄건 위에서 말한대로 우측부터 좌측으로 진행하면서 

      하나씩 chk 함수를 돌려서 불가능한 경우 0을 출력하고 종료하는거다.

    가능한 경우, X의 해당 비트가 '1'인 경우만 카운팅 해준다. 

    for (int i = len-1; i >= 0; i--) {
        int xBit = x.charAt(i)-'0';
        int sBit = s.charAt(i)-'0';
    
        if (i==0 && c==1) {
            System.out.println(0);
            return;
        }
    
        c = chk(xBit, sBit, c);
        if (c == INVALID) {
            System.out.println(0);
            return;
        }
    
        if (xBit == 1) cnt++;
    }

     

      최종적으로 2의 cnt승이 답이된다. 이 때 S와 X가 동일한 수일 경우, A 또는 B가 0인 경우가 존재할 수 있다. 근데 문제에서 '모든 양의 정수 A, B' 라고 했으므로 이 경우를 빼줘야 한다.

    long ans = 1;
    while (cnt-->0) ans*=2;	// 2의 cnt승
    
    if (sNum==xNum) ans-=2;	// A=0 또는 B=0이 가능한 경우를 답에서 제외

     

     

    코드 : github

    import java.io.BufferedReader;
    import java.io.InputStreamReader;
    import java.util.StringTokenizer;
    
    public class Main {
        static BufferedReader br = new BufferedReader(new InputStreamReader(System.in), 1<<16);
        static final int INVALID = -1;
    
        public static void main(String[] args) throws Exception {
            new Main().solution();
        }
    
        public void solution() throws Exception {
            StringTokenizer st = new StringTokenizer(br.readLine());
            long sNum = Long.parseLong(st.nextToken());
            long xNum = Long.parseLong(st.nextToken());
            String s = Long.toBinaryString(sNum);
            String x = Long.toBinaryString(xNum);
    
            if (s.length() == x.length()) {
                s = lpad(s, 1);
                x = lpad(x, 1);
            } else if (s.length() > x.length()) {
                s = lpad(s, 1);
                x = lpad(x, s.length()-x.length());
            } else {
                x = lpad(x, 1);
                s = lpad(s, x.length()-s.length());
            }
    
            int len = s.length();
            int c = 0;
            int cnt = 0;
            for (int i = len-1; i >= 0; i--) {
                int xBit = x.charAt(i)-'0';
                int sBit = s.charAt(i)-'0';
    
                if (i==0 && c==1) {
                    System.out.println(0);
                    return;
                }
    
                c = chk(xBit, sBit, c);
                if (c == INVALID) {
                    System.out.println(0);
                    return;
                }
    
                if (xBit == 1) cnt++;
            }
    
            if (cnt == 0) {
                System.out.println(0);
                return;
            }
    
            long ans = 1;
            while (cnt-->0) ans*=2;
    
            if (sNum==xNum) ans-=2;
    
            System.out.println(ans);
        }
    
        private int chk(final int x, final int s, final int c) {
            if (x==0 && s==0) return c==1?-1:2;
            if (x==1 && s==0) return c==0?-1:1;
            if (x==1 && s==1) return c==1?-1:0;
            return c==0?-1:2;
        }
    
        private String lpad(String str, int gap) {
            StringBuilder sb = new StringBuilder();
            while (gap-->0) {
                sb.append('0');
            }
            sb.append(str);
            return sb.toString();
        }
    }

     

    댓글