poj 2155 二维树状数组或二维线段树

阅读量: searchstar 2019-08-21 00:30:13
Categories: Tags:

题目链接:<http://poj.org/problem?id=2155>

大意:一个初始值全为0的01矩阵,有两种操作,一种是将某个矩形区域的值都取反,另一种是查询某一单点的值。

受到另一题启发:https://seekstar.github.io/2019/07/08/codeforces-1184-c2-%E6%9B%BC%E5%93%88%E9%A1%BF%E8%B7%9D%E7%A6%BB%E8%BD%AC%E5%88%87%E6%AF%94%E9%9B%AA%E5%A4%AB%E8%B7%9D%E7%A6%BB%EF%BC%8C%E7%9F%A9%E5%BD%A2%E8%A6%86%E7%9B%96%E6%9C%80%E5%A4%A7%E7%82%B9%E6%95%B0%E8%BD%AC%E6%9C%80%E5%A4%A7%E5%89%8D%E7%BC%80/

我们定义另一个矩阵C,使得原矩阵中点(x, y)的值等于C的二维前缀异或,即所有满足i ≤ x且j ≤ y的C[i][j]的异或和。那么如果我们要将原矩阵中左下角为(x1, y1)、右上角为(x2, y2)的矩形区域的值取反,只需要将C矩阵的(x1, y1), (x1, y2 + 1), (x2 + 1, y1), (x2 + 1, y2 + 1)这四个点的值取反即可(即二维差分)。这样单点查询就变成了求C的二维前缀异或,用二维树状数组维护即可。

代码:

#include <cstdio>
#include <cctype>

using namespace std;

// Debug switch; no code visible in this listing reads it.
#define DEBUG 0
// Defined unconditionally, so the `#ifndef ONLINE_JUDGE` section in main()
// (which redirects stdin/stdout to local files) is compiled out.
#define ONLINE_JUDGE

// Returns the value of the lowest set bit of x, or 0 when x is 0.
// In two's complement, x and -x agree only on that bit and the zeros
// below it, so AND-ing them isolates it.
int lowbit(int x) {
	const int negated = -x;
	return negated & x;
}
// Two-dimensional binary indexed (Fenwick) tree that accumulates with XOR.
//
// Indices are 1-based in both dimensions. Add() XORs a value into one cell and
// Sum() returns the XOR of every cell in the prefix rectangle
// [1..x1] x [1..x2]; despite the names no arithmetic addition takes place.
//
// T must be assignable from 0 and support operator^=.
template <typename T>
struct BIT_2DIM {
	// Extents of the backing array; the largest usable index is maxn - 1.
	const static int maxn1 = 1010, maxn2 = 1010;
	T s[maxn1][maxn2];
	// Largest index currently in use in each dimension (set by Init).
	int n1, n2;

	// Makes [1.._n1] x [1.._n2] the active index range and zeroes it.
	// Requested sizes are clamped to [0, maxn - 1] so that a negative or
	// oversized argument can never index outside s.
	void Init(int _n1, int _n2) {
		n1 = _n1;
		if (n1 < 0) n1 = 0;
		if (n1 >= maxn1) n1 = maxn1 - 1;
		n2 = _n2;
		if (n2 < 0) n2 = 0;
		if (n2 >= maxn2) n2 = maxn2 - 1;
		for (int i = 0; i <= n1; ++i)
			for (int j = 0; j <= n2; ++j)
				s[i][j] = 0;
	}
	// XORs v into cell (x1, x2). Positions outside the active range are ignored.
	void Add(int x1, int x2, T v) {
		// lowbit(0) == 0, so an index of 0 would never advance and the loop
		// would spin forever; a negative index would write before the array.
		if (x1 <= 0 || x2 <= 0) return;
		for (int i = x1; i <= n1; i += lowbit(i))
			for (int j = x2; j <= n2; j += lowbit(j))
				s[i][j] ^= v;
	}
	// Returns the XOR of all cells (i, j) with 1 <= i <= x1 and 1 <= j <= x2.
	// An empty rectangle (x1 <= 0 or x2 <= 0) yields 0.
	T Sum(int x1, int x2) {
		T ans = 0;
		if (x1 <= 0 || x2 <= 0) return ans;
		// Nodes above n1 / n2 are never updated by Add(), so a query that
		// reaches past the active range is answered at the range boundary.
		if (x1 > n1) x1 = n1;
		if (x2 > n2) x2 = n2;
		// Clearing the lowest set bit steps to the next node of the prefix.
		for (int i = x1; i > 0; i ^= lowbit(i))
			for (int j = x2; j > 0; j ^= lowbit(j))
				ans ^= s[i][j];
		return ans;
	}
};

// Entry point of the Fenwick-tree solution.
//
// Input (stdin): the number of test cases T; for each case the matrix size n
// and the number of commands m, followed by m commands:
//   C x1 y1 x2 y2   toggle every cell of the rectangle
//   <other> x y     treated as a query: print the value of cell (x, y)
// Output (stdout): one line per query, with a blank line between test cases.
// Processing stops quietly if the input ends early or is malformed.
int main() {
#ifndef ONLINE_JUDGE
	freopen("in.in", "r", stdin);
	freopen("out.out", "w", stdout);
#endif
	int T;
	static BIT_2DIM<bool> bit;
	bool first = true;

	if (scanf("%d", &T) != 1) return 0;
	while (T-- > 0) {
		int n, m;

		if (scanf("%d%d", &n, &m) != 2) return 0;

		// Separator goes between cases only, never before the first one.
		if (first) first = false;
		else putchar('\n');

		// The updates below touch index x2 + 1 / y2 + 1, so the tree must be
		// larger than n; n + 5 leaves some slack.
		bit.Init(n + 5, n + 5);
		while (m-- > 0) {
			char op;
			// The leading blank in the format skips any run of whitespace, so
			// "\r\n" line endings or trailing spaces cannot be mistaken for
			// the command letter.
			if (scanf(" %c", &op) != 1) return 0;
			if ('C' == op) {
				int x1, y1, x2, y2;
				if (scanf("%d%d%d%d", &x1, &y1, &x2, &y2) != 4) return 0;
				// Toggle the four corners of the difference matrix; their
				// prefix XOR flips exactly the cells inside the rectangle.
				bit.Add(x1, y1, 1);
				bit.Add(x1, y2 + 1, 1);
				bit.Add(x2 + 1, y1, 1);
				bit.Add(x2 + 1, y2 + 1, 1);
			} else {
				int x, y;
				if (scanf("%d%d", &x, &y) != 2) return 0;
				printf("%d\n", static_cast<int>(bit.Sum(x, y)));
			}
		}
	}

	return 0;
}

另外还有一种二维线段树(树套树)写法,这里用标记永久化实现:修改时只在被完全覆盖的节点上打标记、不下传;查询时把从根到叶路径上经过的所有标记异或起来。

#include <cstdio>
#include <cstring>
#include <cctype>

using namespace std;

// Capacities of the first and second coordinate. The segment trees below
// size their node arrays as four times these values (MAXN1 << 2, MAXN2 << 2).
#define MAXN1 1011
#define MAXN2 1011

// Index of the left child of node rt in an array-backed binary tree
// whose root is stored at index 1.
inline int ls(int rt) {
    return 2 * rt;
}
// Index of the right child of node rt in an array-backed binary tree
// whose root is stored at index 1.
inline int rs(int rt) {
    return 2 * rt + 1;
}
// One-dimensional segment tree supporting "toggle a range" and "read one
// position" with permanent marks: s[rt] is true when the whole interval of
// node rt has been toggled an odd number of times. Marks are never pushed
// down; a point query XORs every mark on the root-to-leaf path instead.
struct SGT {
    const static int max_node = MAXN2 << 2;
    bool s[max_node];

    // Clears the marks of a tree built over n positions (its first 4 * n
    // nodes). The byte count is clamped to the array size and a non-positive
    // n clears nothing, so memset can never run past s.
    void Init(int n) {
        if (n <= 0) return;
        int count = max_node;
        if (n <= max_node / 4) count = n << 2;
        memset(s, 0, count * sizeof(s[0]));
    }
    // Toggles every position in [L, R].
    // rt is the current node and [l, r] the interval it represents.
    void Reverse(int rt, int l, int r, int L, int R) {
        // Nothing to do for an empty request or one that misses this node.
        // Without this check a disjoint request descends into an empty
        // interval (l > r) and recurses without ever terminating.
        if (L > R || l > r || R < l || r < L) return;
        if (L <= l && r <= R) {
            // Fully covered: record the toggle here and stop.
            s[rt] = !s[rt];
            return;
        }
        int mid = (l + r) >> 1;
        if (L <= mid) Reverse(ls(rt), l, mid, L, R);
        if (mid < R) Reverse(rs(rt), mid+1, r, L, R);
    }
    // Returns the current value of position x: the XOR of all marks stored on
    // the path from node rt (covering [l, r]) down to the leaf for x.
    // An empty interval yields false.
    bool Query(int rt, int l, int r, int x) {
        bool ans = false;
        while (l <= r) {
            ans ^= s[rt];
            if (l == r) break;
            int mid = (l + r) >> 1;
            if (x <= mid) {
                rt = ls(rt);
                r = mid;
            } else {
                rt = rs(rt);
                l = mid + 1;
            }
        }
        return ans;
    }
};
struct SGT2 {
    const static int max_node = MAXN1 << 2;
    SGT s[max_node];
    int n2;

    void Init(int n1, int _n2) {
        n2 = _n2;
        for (int i = 0; i <= (n1 << 2); ++i)
            s[i].Init(n2);
    }
    void Reverse(int rt, int l, int r, int L1, int R1, int L2, int R2) {
        if (L1 <= l && r <= R1) {
            s[rt].Reverse(1, 1, n2, L2, R2);
        } else {
            int mid = (l + r) >> 1;
            if (L1 <= mid) Reverse(ls(rt), l, mid, L1, R1, L2, R2);
            if (mid < R1) Reverse(rs(rt), mid+1, r, L1, R1, L2, R2);
        }
    }
    bool Query(int rt, int l, int r, int x1, int x2) {
        bool ans = s[rt].Query(1, 1, n2, x2);
        if (l != r) {
            int mid = (l + r) >> 1;
            if (x1 <= mid) ans ^= Query(ls(rt), l, mid, x1, x2);
            else ans ^= Query(rs(rt), mid+1, r, x1, x2);
        }
        return ans;
    }
};
// Reads stdin until a non-whitespace character is found and returns it.
// At end of input the result is EOF narrowed to char.
char GetChar() {
    // Keep getchar()'s int result for the test: isspace() is only defined for
    // EOF and values representable as unsigned char, which a plain char
    // (possibly negative) is not guaranteed to be.
    int c = getchar();
    while (c != EOF && isspace(c))
        c = getchar();
    return static_cast<char>(c);
}
// Entry point of the segment-tree solution.
//
// Reads the number of test cases, then for each case the matrix size n and
// the command count m. A command starting with 'Q' prints the value of one
// cell as '0' or '1'; any other command toggles a rectangle. A blank line is
// written between consecutive test cases.
int main() {
    static SGT2 sgt2;
    int T;

    scanf("%d", &T);
    for (bool needSeparator = false; T--; needSeparator = true) {
        if (needSeparator) putchar('\n');

        int n, m;
        scanf("%d%d", &n, &m);
        sgt2.Init(n, n);

        for (; m--; ) {
            const char op = GetChar();
            if (op != 'Q') {
                // Rectangle toggle: corners are given as x1 y1 x2 y2.
                int x1, y1, x2, y2;
                scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
                sgt2.Reverse(1, 1, n, x1, x2, y1, y2);
                continue;
            }
            // Point query.
            int x, y;
            scanf("%d%d", &x, &y);
            puts(sgt2.Query(1, 1, n, x, y) ? "1" : "0");
        }
    }
    return 0;
}