【算法】多项式合集

开始&多项式乘法

之前学了这么多次多项式,这次用latex写下
首先是fft。
fft的目的是快速的将一个n次多项式转化成点值表示法,行如\((x_{i},y_{i})\)这样的,这有什么用呢?
对于两个点值表示的多项式,有\((x,y_{a})\)和\((x,y_{b})\)这两个点,显然,这两个多项式乘起来会得到一个新点(x,y_{a}y_{b})也就说如果我们能转化成点值表示法,那么我们就可以在O(n)的时间里面完成多项式乘法。
但是如果我们只是傻傻的去选择n个点然后暴力去算多项式的值,这样复杂度就会为O(n^{2}),所以FFT就可以派上用场了。
我们首先需要前置知识复数,基本运算见数学课本。这里只说单位根,借鉴下自为风月马前卒dalao的话,在复平面上,以原点为圆心,1为半径作圆,所得的圆叫单位圆。以圆点为起点,圆的n等分点为终点,做n个向量,设幅角为正且最小的向量对应的复数为\omega_{n},称为n次单位根。按照顺时针以此为\omega_{n}^{0},\omega_{n}^{1},\omega_{n}^{2}……\omega_{n}^{n-1}。简单来说就是把圆分成n份的向量。
欧拉公式:\omega_{n}^{k} = cos(k\frac{2\pi}{n}) + i sin(k\frac{2\pi}{n})
那么画个图,我们能得到几条比较显然的性质
1.\omega_{n}^{k} = \omega_{2n}^{2k},这个你把圆从原来n份分成了2n份,显然你原来那个单位根到了2k上。
2.\omega_{n}^{k + \frac{n}{2}} = – \omega_{n}^{k},显然你转了半圈回来肯定是与原来相反的。
那么这两个结论的重要性将会在下面体现。
现在我们有一个n次不等式
A(x) = \sum_{i = 0}^{n – 1} a_{i}x^{i} = (a_{0} + a_{2}x^{2} + a_{4}x^{4} + ……a_{n-2}x^{n-2}) \\ + (a_{1}x + a_{3}x^{3} + a_{5}x^{5} + …… + a_{n – 1}x^{n – 1})
我们按照奇偶性划分出两个新的不等式:
A_{1}(x) = a_{0} + a_{2}x + a_{4}x^{2} + …… + a_{n – 2}x^{\frac{n – 2}{2}}
A_{2}(x) = a_{1} + a_{3}x + a_{5}x^{2} + …… + a_{n – 1}x^{\frac{n – 1}{2}}
那么我们就可以得出,A(\omega_{n}^{k}) = A_{1}(\omega_{n}^{2k}) + \omega_{n}^{k}A_{2}(\omega_{n}^{2k})
同理,A(\omega_{n}^{k + \frac{n}{2}}) = A_{1}(\omega_{n}^{2k + n}) + \omega_{n}^{k + \frac{n}{2}}A_{2}(\omega_{n}^{2k + n})
根据我们刚才得到的单位根的性质,我们可以将这个式子继续变形。
A(\omega_{n}^{k + \frac{n}{2}}) = A_{1}(\omega_{n}^{2k}) – \omega_{n}^{k}A_{2}(\omega_{n}^{2k})
可以发现这两个式子除了中间的加减号以外没有变化,因此只要知道A_{1}(\omega_{n}^{2k}),\omega_{n}^{k}A_{2}(\omega_{n}^{2k})我们就可以将两个值计算出来了,也就是说,这个问题可以拆分成两个子问题,所以这是个类似于二分的过程,我们只需要倍增的计算这个东西就可以了,我们利用单位根的性质将O(n^{2})的问题转化成O(nlog_{n})了。
那么经过一次正向变换之后,我们顺利的得到了点值表达式,那么多项式乘法也就非常好做了。但是因为一般我们需要系数的表达式,因此我们还需要一次逆变换,那么逆变换怎么做呢。
逆变换……我不是很懂为什么有人会觉得很显然啊,这东西并不显然的嘛。至少我复习了一个下午,结合算导和lgj大爷的博客勉强看懂了。
我们考虑刚才得到了一个点权表示,现在我们将他再次写成一个新的多项式,设B(x) = \sum_{i = 0}^{n – 1}y_{i}x^{i},其中y是我们刚才得到的点权表示。所以我们依次代入\omega_{n}^{0},\omega_{n}^{-1},\omega_{n}^{-2}……,\omega_{n}^{-(n – 1)}
我们现在对于式子进行整理。
\begin{aligned} B(\omega_{n}^{k}) & = \sum_{i = 0}^{n – 1}(\omega_{n}^{k})^{i}\sum_{j = 0}^{n – 1}a_{j}(\omega_{n}^{i})^{j} \\ & = \sum_{i = 0}^{n – 1}(\omega_{n}^{k})^{i}\sum_{j = 0}^{n – 1}a_{j}(\omega_{n}^{j})^{i} \\ & = \sum_{i = 0}^{n – 1} \sum_{j = 0}^{n – 1}a_{j}(\omega_{n}^{j + k})^{i} \\ & = \sum_{j = 0}^{n – 1}a_{j} \sum_{i = 0}^{n – 1}(\omega_{n}^{j + k})^{i} \end{aligned}
然后我们先看后面这个式子,考虑到这显然是个等比数列,假设j+k不等于0,我们直接套用公式S(\omega_{n}^{j + k}) = \frac{1 – (\omega_{n}^{j + k})^{n}}{1 – \omega_{n}^{j + k}} = \frac{1 – (\omega_{n}^{n})^{j + k}}{1 – \omega_{n}^{j + k}} = \frac{0}{1 – \omega_{n}^{j + k}}
对于j+k为0的情况,显然S(\omega_{n}^{j + k}) = n
那么也就说,B(\omega_{n}^{k}) = na_{-k},a_{-k} = \frac{B(\omega_{n}^{k})}{n}
考虑k进去的都是负数emmmm。
再回过头看B,这个式子可以通过fft快速得到,因此我们知道了逆变换的操作办法。
ntt和这个差不多,重点在与他把单位根换成了大质数。我们考虑单位根的性质,你要找到一个数a使得,a^{x} \equiv 1 (mod p),并且这个x唯一。那么首先=1这个可以让我们想到费马小定理/欧拉定理,大多数时候这个p都是个质数。所以我们可以想到原根。原根的定义是g^{P-1} \equiv 1 (mod P)是否当且仅当指数为P-1的时候成立。然后我们现在因为把多项式的长度变成了2的幂次,所以我们需要一个质数p满足p = k2^{m} + 1。所以很多时候我们会看见ntt的题目模数是998244353,因为998244353 = 2 * 17 * 2^{23} + 1,对于这个数3是它的原根。
luoguP1919 【模板】A*B Problem升级版(fft)

// luogu-judger-enable-o2
#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int
#define pi acos(-1.0)

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const int ms = 2e5;

int n, len = 1, num, pos[ms], ans[ms];

struct in
{
    double r, l;
    in(double rr = 0.0, double ll = 0.0)
    {
        r = rr, l = ll;
    }
    inline in operator + (in x)
    {
        return in(r + x.r, l + x.l);
    }
    inline in operator - (in x)
    {
        return in(r - x.r, l - x.l);
    }
    inline in operator * (in x)
    {
        return in{r * x.r - l * x.l, r * x.l + l * x.r};
    }
}a[ms], b[ms];

inline void fft(in *a, int tp)
{
    for(ri i = 0; i <= len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    in la, lb;
    for(ri i = 1; i < len; i <<= 1)
    {
        in wn(cos(pi / i), tp * sin(pi / i));
        for(ri j = 0; j < len; j += (i << 1))
        {
            in w(1, 0);
            for(ri k = j; k < j + i; k ++)
            {
                la = a[k], lb = w * a[k + i];
                w = w * wn, a[k] = la + lb, a[k + i] = la - lb;
            }
        }
    }
}

int main()
{
    re(n); char la; int ta = n;
    while(!isdigit(la = getchar()));
    while(isdigit(la))
        a[-- ta].r = la - '0', a[ta].l = 0.0, la = getchar();
    while(!isdigit(la = getchar()));
    ta = n;
    while(isdigit(la))
        b[-- ta].r = la - '0', b[ta].l = 0.0, la = getchar();
    while(len < n + n)
        len <<= 1, num ++;
    for(ri i = n; i < len; i ++)
        a[i].r = a[i].l = b[i].l = b[i].r = 0.0;
    for(ri i = 0; i <= len; i ++)
        pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    fft(a, 1), fft(b, 1);
    for(ri i = 0; i < len; i ++)
        a[i] = a[i] * b[i];
    fft(a, -1); ta = n + n - 1;
    for(ri i = 0; i < len; i ++)
        ans[i] = a[i].r / len + 0.5;
    for(ri i = 0; i < len; i ++)
        ans[i + 1] += ans[i] / 10, ans[i] %= 10;
    ta = n + n - 1;
    while(ta >= 0 && ans[ta] <= 0)
        ta --;
    for(ri i = ta; i >= 0; i --)
        putchar(ans[i] + '0');
    system("pause");
}

还有个ntt版本的

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int
#define pi acos(-1.0)

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = gch()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = gch();
    if(b == 1)
        x = - x;
}

const int ms = 2e5;

const lo mo = 998244353;

lo n, len = 1, ans[ms], num, pos[ms], a[ms], b[ms], po[ms], inv[ms];

inline lo ksm(lo x, lo k)
{
    lo rt = 1, a = x;
    while(k)
        rt = rt * ((k & 1) ? a : 1) % mo, a = a * a % mo, k >>= 1;
    return rt;
}

inline void ntt(lo *a, lo tp)
{
    for(ri i = 0; i < len; i ++)
        if(i < pos[i])
            swap(a[pos[i]], a[i]);
    lo la, lb;
    for(ri i = 1; i < len; i <<= 1)
    {
        lo wn = (tp == 1) ? po[i] : inv[i];
        for(ri j = 0; j < len; j += (i << 1))
        {
            lo w = 1;
            for(ri k = j; k < j + i; k ++)
            {
                la = a[k], lb = w * a[k + i] % mo;
                w = w * wn % mo, a[k] = (la + lb) % mo, a[k + i] = ((la - lb) % mo + mo) % mo;
            }
        }
    }
}

int main()
{
    re(n); char la; int ta = n;
    while(!isdigit(la = gch()));
    while(isdigit(la))
        a[-- ta] = la - '0', la = gch();
    while(!isdigit(la = gch()));
    ta = n;
    while(isdigit(la))
        b[-- ta] = la - '0', la = gch();
    while(len <= n + n)
        len <<= 1, num ++;
    lo inv3 = ksm(3, mo - 2);
    for(ri i = 1; i <= len; i <<= 1)
        po[i] = ksm(3, (mo - 1) / (i << 1));//注意这里一定要左移1
    for(ri i = 1; i <= len; i <<= 1)
        inv[i] = ksm(inv3, (mo - 1) / (i << 1));
    for(ri i = 1; i < len; i ++)
        pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(a, 1), ntt(b, 1);
    for(ri i = 0; i < len; i ++)
        a[i] = a[i] * b[i] % mo;
    ntt(a, -1); inv3 = ksm(len, mo - 2);
    for(ri i = 0; i < len; i ++)
        a[i] = a[i] * inv3 % mo;
    ta = n + n - 1;
    for(ri i = 0; i < len; i ++)
        ans[i] = a[i];
    for(ri i = 0; i < len; i ++)
        ans[i + 1] += ans[i] / 10, ans[i] %= 10;
    ta = n + n - 1;
    while(ta >= 0 && ans[ta] <= 0)
        ta --;
    for(ri i = ta; i >= 0; i --)
        putchar(ans[i] + '0');
    system("pause");
}

分治fft/ntt

现在我们有两个多项式f(x) = \sum_{i = 1}^{x} f(x – i) g(i),g是给定的多项式,边界是f(0) = 1。这并不是一个直接的卷积形式,所以我们不能直接用多项式的乘法来解决这个问题,因此我们需要考虑有没有其他的做法可以快速实现。
反正这里写着分治ntt了,所以显然我们要考虑分治啊。对于[l,mid]这个区间,我们假设这个区间已经是算完的了,考虑[l, mid]这个区间对于后面[mid + 1, r]的贡献,设贡献V(x)表示前面的区间对于x号位置的贡献,那么V(x) = \sum_{i = l}^{mid}f(i)g(x – i),这是一个显然的事情。这个形式我们发现,已经不存在未知数了,可以直接卷积,也就说,这个是可以用fft算出来的。
那么我们的思路就很明显了,优先计算左侧区间的结果,然后通过左侧区间去给右边增加贡献,这样右边最开头又被完全计算了,那么右边的最开头又会把它计算的结果贡献给它右边。每次折半,所以一共有logn层,而每层nlogn,所以总复杂度是nlog^{2}n的。
luoguP4721【模板】分治 FFT(明明是ntt板子啊,只是fft太出名了吧qwq)

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int
#define pi acos(-1.0)

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 2e5 + 10;

const lo mo = 998244353;

lo n, f[ms << 2], pos[ms << 2], g[ms << 2], po[ms << 2], inv[ms << 2], x[ms << 2], y[ms << 2];

inline lo ksm(lo x, lo k)
{
    lo a = x, rt = 1;
    while(k)
        rt = rt * ((k & 1) ? a : 1) % mo, a = a * a % mo, k >>= 1;
    return rt; 
}

inline void ntt(lo *a, lo len, lo tp)
{
    for(ri i = 1; i < len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    lo la, lb, invlen = ksm(len, mo - 2);
    for(ri i = 1; i < len; i <<= 1)
    {
        lo wn = (tp == 1) ? po[i] : inv[i];
        for(ri j = 0; j < len; j += (i << 1))
        {
            lo w = 1;
            for(ri k = j; k < j + i; k ++)
            {
                la = a[k], lb = w * a[k + i] % mo;
                w = w * wn % mo, a[k] = (la + lb) % mo, a[k + i] = ((la - lb) % mo + mo) % mo;
            }
        }
    }
    if(tp == -1)
        for(ri i = 0; i < len; i ++)
            a[i] = a[i] * invlen % mo;
}

void merge(lo l, lo r)
{
    if(l == r)
        return;
    lo mid = (l + r) >> 1; merge(l, mid);//类似cdq
    lo len = 1, num = 0;
    while(len <= (r - l + 1))
        len <<= 1, num ++;
    for(ri i = 1; i < len; i ++)
        x[i] = y[i] = 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    x[0] = y[0] = x[len] = y[len] = 0;
    for(ri i = l; i <= mid; i ++)
        x[i - l] = f[i];
    for(ri i = 1; i <= r - l; i ++)
        y[i - 1] = g[i];
    ntt(x, len, 1), ntt(y, len, 1);
    for(ri i = 0; i <= len; i ++)
        x[i] = x[i] * y[i] % mo;
    ntt(x, len, -1);
    for(ri i = mid + 1; i <= r; i ++)//把贡献直接加上去
        f[i] += x[i - l - 1] % mo, f[i] %= mo;
    merge(mid + 1, r);//再考虑右边
}

int main()
{
    re(n), f[0] = 1;
    for(ri i = 1; i < n; i ++)
        re(g[i]), g[i] %= mo;
    lo inv3 = ksm(3, mo - 2);
    for(ri i = 1; i <= (n + n); i <<= 1)
        po[i] = ksm(3, (mo - 1) / (i << 1));
    for(ri i = 1; i <= (n + n); i <<= 1)
        inv[i] = ksm(inv3, (mo - 1) / (i << 1));
    merge(0, n - 1);
    for(ri i = 0; i < n; i ++)
        printf("%lld ", f[i]);
    system("pause");
}

当然,这个题目还有不用分治更为巧妙的办法,也就说多项式求逆,接下来,再详细叙述。

多项式求逆

现在有一个多项式F(x),我们想要求它在mod x^{n}意义下的逆元,即
F(x) * G(x) \equiv 1 (mod \ x^{n})
那么我们设H'(x)mod \ x^{\frac{n}{2}}的时候F(x)的逆元,也就说H(x)’F(x) \equiv 1 (mod \ x^{\frac{n}{2}})
我们现在再设H(x)是在mod \ x^{n}意义下的逆元,所以也就说
H(x) – H'(x) \equiv 0 (mod \ x^{\frac{n}{2}})
我们进一步将这个式子平方后再乘上F(x)
F(x)(H(x)’)^{2} – 2H'(x) + H(x) \equiv 0 (mod \ x^{n})
移项可得H(x) \equiv 2H'(x) – F(x)(H'(x))^{2} (mod \ x^{n})
通过这个式子,我们递归去计算多项式,边界是最后考虑到x^{0},这时显然逆元是直接快速幂就好,注意对于多项式来说随着幂次的不断变小,多项式长度也折半了。多项式求逆元的复杂度是O(nlogn)
luoguP4238【模板】多项式求逆

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int
#define pi acos(-1.0)

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 4e5;

const lo mo = 998244353;

lo n, a[ms], b[ms], c[ms], pos[ms], po[ms], inv[ms];

inline lo ksm(lo x, lo k)
{
    lo a = x, rt = 1;
    while(k)
        rt = rt * ((k & 1) ? a : 1) % mo, a = a * a % mo, k >>= 1;
    return rt;
}

inline void ntt(lo *a, lo len, lo tp)
{
    for(ri i = 0; i < len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    lo la, lb, iv = ksm(len, mo - 2);
    for(ri i = 1; i < len; i <<= 1)
    {
        lo wn = (tp == 1) ? po[i] : inv[i];
        for(ri j = 0; j < len; j += (i << 1))
        {
            lo w = 1;
            for(ri k = j; k < j + i; k ++)
            {
                la = a[k], lb = w * a[k + i] % mo;
                w = w * wn % mo, a[k] = (la + lb) % mo, a[k + i] = ((la - lb) % mo + mo) % mo;
            }
        }
    }
    if(tp == -1)
        for(ri i = 0; i < len; i ++)
            a[i] = a[i] * iv % mo;
}

void merge(lo len)
{
    if(len == 1)
    {
        b[0] = ksm(a[0], mo - 2); return;
    }
    merge((len + 1) >> 1);
    lo x = 1, num = 0;
    while(x < (len << 1))
        x <<= 1, num ++;
    for(ri i = 0; i < x; i ++)
        c[i] = (i < len) ? a[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(c, x, 1), ntt(b, x, 1);
    for(ri i = 0; i < x; i ++)
        b[i] = ((2ll - c[i] * b[i] % mo) % mo + mo) % mo * b[i] % mo;
    ntt(b, x, -1);
    for(ri i = len; i < x; i ++)
        b[i] = 0;
}

int main()
{
    re(n);
    for(ri i = 0; i < n; i ++)
        re(a[i]);
    lo inv3 = ksm(3, mo - 2);
    for(ri i = 1; i <= n + n; i <<= 1)
        po[i] = ksm(3, (mo - 1) / (i << 1));
    for(ri i = 1; i <= n + n; i <<= 1)
        inv[i] = ksm(inv3, (mo - 1) / (i << 1));
    merge(n);
    for(ri i = 0; i < n; i ++)
        printf("%lld ", b[i]);
    system("pause");
}

那么还记得刚才我们说的那个分治ntt可以用多项式求逆代替的事情吗?
我们首先通过生成函数的想法得到一个新的多项式
F(x) = \sum_{i = 0}^{+\infty}f(i)x^{i},G(x) = \sum_{i = 0}^{+\infty}g(i)x^{i}
其实也没改变什么,就是加了个x,换成了系数而已,当然了,正无穷很重要。因为有了正无穷我们才能做接下来的操作。
那么我们将两个多项式卷起来
F(x)G(x) = \sum_{i = 0}^{+\infty}x^{i}\sum_{j + k = i}f(j)g(k) = F(x) – f(0)
之所以最后是减去f(0),是因为g(0)不存在,所以为0,而重新转化成F(x)则是因为根据之前多项式的定义
f(x) = \sum_{i = 1}^{x} f(x – i) g(i)
在无穷的时候这个东西就和F(x)没什么区别了
那么也就是说F(x)G(x) \equiv F(x) – f(0) (mod p)
移项可得F(x) \equiv \frac{f(0)}{1 – G(x)}
这个形式也就是逆元的形式,也就说,我们现在把这个式子转化成了求逆元,那么套用多项式求逆元的模板就行了,唯一需要注意的是,那个1是常数项。所以我们只需要让g0为1即可。

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int
#define pi acos(-1.0)

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 4e5;

const lo mo = 998244353;

lo n, f[ms], g[ms], pos[ms], po[ms], inv[ms], h[ms];

inline lo ksm(lo x, lo k)
{
    lo rt = 1, a = x;
    while(k)
        rt = rt * ((k & 1) ? a : 1) % mo, a = a * a % mo, k >>= 1;
    return rt;
}

inline void ntt(lo *a, lo len, lo tp)
{
    for(ri i = 1; i < len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    lo la, lb, invlen = ksm(len, mo - 2);
    for(ri i = 1; i < len; i <<= 1)
    {
        lo wn = (tp == 1) ? po[i] : inv[i];
        for(ri j = 0; j < len; j += (i << 1))
        {
            lo w = 1;
            for(ri k = j; k < j + i; k ++)
            {
                la = a[k], lb = w * a[k + i] % mo;
                w = w * wn % mo, a[k] = (la + lb) % mo, a[k + i] = ((la - lb) % mo + mo) % mo;
            }
        }
    }
    if(tp == -1)
        for(ri i = 0; i < len; i ++)
            a[i] = a[i] * invlen % mo;
}

void merge(lo len)
{
    if(len == 1)
    {
        f[0] = ksm(g[0], mo - 2); return;
    }
    merge((len + 1) >> 1); lo x = 1, num = 0;
    while(x < (len << 1))
        x <<= 1, num ++;
    for(ri i = 0; i < x; i ++)
        h[i] = (i < len) ? g[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(h, x, 1), ntt(f, x, 1);
    for(ri i = 0; i < x; i ++)
        f[i] = ((2ll - f[i] * h[i] % mo) % mo + mo) * f[i] % mo;
    ntt(f, x, -1);
    for(ri i = len; i < x; i ++)
        f[i] = 0;
}

int main()
{
    re(n); lo inv3 = ksm(3, mo - 2); g[0] = 1;
    for(ri i = 1; i < n; i ++)
        re(g[i]), g[i] = mo - g[i];
    for(ri i = 1; i <= n + n; i <<= 1)
        po[i] = ksm(3, (mo - 1) / (i << 1));
    for(ri i = 1; i <= n + n; i <<= 1)
        inv[i] = ksm(inv3, (mo - 1) / (i << 1));
    merge(n);
    for(ri i = 0; i < n; i ++)
        printf("%lld ", f[i]);
    system("pause");
}

多项式除法

其实学过求逆元以后这个东西就不是很难做了,现在我们先定义一个操作A^{R}(x) = x^{n}A(\frac{1}{x}),这个操作我们来试着手算下,看看代表什么意义。
A(x) = x^{2} + 2x + 3,A^{R}(x) = 1 + 2x + 3x^{2},也就说,所有的系数发生了一个翻转。问题得到了转化。
那么现在,我们设A(x) = B(x)C(x) + D(x),其中A(x)次数为n,B(x)次数为m,C(x)次数不超过为n-m,D(x)次数不会大于m。
也就说A(x)是被除数,B(x)是除数,C(x)是商,D(x)是余数。
显然,两边同乘一个x^{n}并将括号里的x换成\frac{1}{x},式子依旧成立。
x^{n}A(\frac{1}{x}) = x^{m}B(\frac{1}{x})x^{n – m}C(\frac{1}{x}) + x^{n-m+1}x^{m-1}D(x)
而根据我们刚才的定义可以得知,A^{R}(x) = x^{n}A(\frac{1}{x}),也就说现在我们可以将式子写成这样的形式。
A^{R}(x) = B^{R}(x)C^{R}(x) + x^{n-m+1}D^{R}(x)
我们现在如果想消掉那个烦人的D^{R}(x),很显然可以将这个式子放在mod \ x^{n-m+1}意义下。
这个式子最后可以写成A^{R}(x) \equiv B^{R}(x)C^{R}(x) \ (mod \ x^{n-m+1})
继续左右两侧同除以B^{R}(x)
C^{R}(x) \equiv \frac{A^{R}(x)}{B^{R}(x)} \ (mod \ x^{n-m+1})
我们将这个问题转化成了多项式求逆,求完逆再把C^{R}(x)的系数翻转过来就行了,D(x)也很好求
D(x) = A(x) – B(x)C(x)
luoguP4512【模板】多项式除法(注意一定要严格按照刚才那些次方数)

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int
#define pi acos(-1.0)

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 3e5;

const lo mo = 998244353;

lo n, m, a[ms], b[ms], c[ms], d[ms], e[ms], f[ms], g[ms], h[ms], pos[ms], po[ms], inv[ms];

inline lo ksm(lo x, lo k)
{
    lo rt = 1, a = x;
    while(k)
        rt = rt * ((k & 1) ? a : 1) % mo, a = a * a % mo, k >>= 1;
    return rt;
}

inline void ntt(lo *a, lo len, lo tp)
{
    for(ri i = 1; i < len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    lo la, lb, invlen = ksm(len, mo - 2);
    for(ri i = 1; i < len; i <<= 1)
    {
        lo wn = (tp == 1) ? po[i] : inv[i];
        for(ri j = 0; j < len; j += (i << 1))
        {
            lo w = 1;
            for(ri k = j; k < j + i; k ++)
            {
                la = a[k], lb = w * a[k + i] % mo;
                w = w * wn % mo, a[k] = (la + lb) % mo, a[k + i] = ((la - lb) % mo + mo) % mo;
            }
        }
    }
    if(tp == -1)
        for(ri i = 0; i < len; i ++)
            a[i] = a[i] * invlen % mo;
}

void merge(lo len)
{
    if(len == 1)
    {
        b[0] = ksm(e[0], mo - 2); return;
    }
    merge((len + 1) >> 1); lo x = 1, num = 0;
    while(x < (len << 1))
        x <<= 1, num ++;
    for(ri i = 0; i < x; i ++)
        g[i] = (i < len) ? e[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(g, x, 1), ntt(b, x, 1);
    for(ri i = 0; i < x; i ++)
        b[i] = ((2ll - g[i] * b[i] % mo) % mo + mo) % mo * b[i] % mo;
    ntt(b, x, -1);
    for(ri i = len; i < x; i ++)
        b[i] = 0;
}

int main()
{
    re(n), re(m);
    for(ri i = 0; i <= n; i ++)
        re(a[i]), h[i] = a[i];
    for(ri i = 0; i <= m; i ++)
        re(e[i]), f[i] = e[i];
    reverse(a, a + n + 1), reverse(e, e + m + 1);
    lo len = 1, num = 0, inv3 = ksm(3, mo - 2);
    while(len <= (n + m))
        len <<= 1;
    for(ri i = 1; i <= len; i ++)
        po[i] = ksm(3, (mo - 1) / (i << 1)), inv[i] = ksm(inv3, (mo - 1) / (i << 1));
    len = 1;
    while(len <= (n + n - m + 2))
        len <<= 1, num ++;
    merge(n - m + 1);
    for(ri i = 0; i < len; i ++)
        a[i] = (i < n + 1) ? a[i] : 0, b[i] = (i < n - m + 1) ? b[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(a, len, 1), ntt(b, len, 1);
    for(ri i = 0; i < len; i ++)
        c[i] = a[i] * b[i] % mo;
    ntt(c, len, -1), ntt(a, len, -1);
    reverse(c, c + n - m + 1);
    for(ri i = 0; i <= n - m; i ++)
        printf("%lld ", c[i]);
    printf("\n");
    len = 1, num = 0;
    while(len <= (n + 2))
        len <<= 1, num ++;
    for(ri i = 0; i < len; i ++)
        f[i] = (i < m + 1) ? f[i] : 0, c[i] = (i < n - m + 1) ? c[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(f, len, 1), ntt(c, len, 1);
    for(ri i = 0; i < len; i ++)
        c[i] = c[i] * f[i] % mo;
    ntt(c, len, -1);
    for(ri i = 0; i < len; i ++)
        d[i] = ((h[i] - c[i]) % mo + mo) % mo;
    for(ri i = 0; i < m; i ++)
        printf("%lld ", d[i]);
}

拆系数fft/mtt

这个东西我花了一天的时间去看论文和代码理解qwq,特别感谢MikuNotFoundException,给我细致耐心的讲解了mtt,不然也就没mtt这个部分了。强烈推荐2016年国家集训队论文再探快速傅里叶变换,myy太强了qwq。
有的时候,题目给我们的模数并不是形如998244353这么标准的可以用3为原根的数,他可能是任何一个整数,序列长度1e5,模数大小1e9,这个时候我们直接ntt已经gg了,而直接fft的话得到的结果可能最大在1e23,这是直接要gg的节奏,那么我们怎么办。
还是从fft入手,我们考虑怎么样才能降低那个1e23的级别,既然叫拆系数,那么肯定是要拆系数的。我们现在设存在F(x),G(x),使得存在
F(x) = D(x)P + C(x),G(x) = B(x)P + A(x)
现在我们要求F(x),G(x)卷积。
首先介绍下变量,P在这里表示的是我选择的一个新的数,我们为了降低次数,从多项式中提出来个级别在题目给我们的模数最大范围根号,假设题目给我们模数p,那么我们这个最大P=\sqrt{p}。这是为了通过提出因数来降低我们的级别。而A(x),B(x)分别是我们提出模数p之后的系数和余数。注意,这个地方的A(x),B(x)和论文里利用共轭复数优化那里的A(x),B(x)的意义是一样的。论文中的A(x),B(x)并不是我们最后想要卷起来的结果,而是单纯的系数。为什么要求系数待会再说。
(D(x)P + C(x))(B(x)P + A(x)) = B(x)D(x)P^{2} + (B(x)C(x) + A(x)D(x))p + A(x)C(x)
对于这个式子,A(x),B(x),C(x),D(x)都没有超过\sqrt{p}这么个级别,也就说这个式子被限制在了p这么个级别,而p在1e9左右,所以我们能够承受了。所以我们A(x),C(x)可以从原来的多项式%P得到,那么B(x),D(x)至于要把A(x),C(x)这几位划掉就行。一般我们把P=32767
现在去括号之后最显然的想法是分别对A(x),B(x),C(x),D(x)进行dft,然后多项式乘法后再分别对B(x)D(x),(B(x)C(x) + A(x)D(x)),A(x)C(x)这三个进行一次idft,这样我们通过暴力的七次dft/idft得到了结果。这就是最暴力的mtt。
但是这个东西是可以优化的,在那个论文里,最重要的部分就是如何优化这个过程,我应该还是没有优化到最好的情况。我只是优化到了4次dft。
为了接下来方便,dft/idft合称为dft好了。我们首先为了方便,将A(x),B(x),C(x),D(x)这些按照复数构造起来,B(x),D(x)做虚数部,设
G'(x) = A(x) + iB(x), F'(x) = C(x) + iD(x)
那么既然是复数,为何我不能直接将转化成复数的多项式直接卷积呢。原因是,这个复数多项式只是我们用来快速计算系数的一个工具,他并不是我们所得到的直接结果。原来的式子里面还有个P,P^{2},这些我们需要求他们的系数,直接卷积都混在一起了,怎么求系数。也就说,我们至少需要去掉一个括号。但是去掉两个括号我们就没法优化了,因此我们只将式子化成
(D(x)P + C(x))A(x)+(D(x)P + C(x))B(x)
这样,现在(D(x)P + C(x))需要一次dft
A(x),B(x),(D(x)P + C(x))A(x),(D(x)P + C(x))B(x)
这四个分别需要一次dft,我们现在将7次优化到了5次。可是5次到4次又是怎么实现的呢。
这个时候我们需要用到论文上所讲的部分了,设
P(x) = A(x) + iB(x), Q(x) = A(x) – iB(x)
这里其实P(x)和G'(x)是一样的,所以这也就是我们代码里面直接将第二个多项式转化成复数的新多项式看作是P(x)的原因,那么我们先得出个结论。
P(\omega_{n}^{k}) = conj(Q(\omega_{n}^{-k})) = conj(Q(\omega_{n}^{n-k}))
conj表示的是取共轭复数,共轭复数……可以简单的理解为,以x轴为对称轴,你对你的单位根画了对称图形。
图片源自百度百科。
那么也就说虚数部变成负值了,可是Q里面虚数部本来也就带个负号,负负得正,抵消了,而根据单位根的性质,我加个n也是不会影响大小的。至于具体展开后化简,左转论文吧2333333,写不动了。
这个东西意味着什么呢,意味着,我可以通过一次dft将P(x),Q(x)同时求出来——而根据我们开始所设,可以整理得到两个式子
A(x) = \frac{P(x)+Q(x)}{2}
B(x) = \frac{P(x)-Q(x)}{2i}
一次dft我们就可以求出A(x),B(x)了,然后我们再按照5倍dft的做法,乘起来idft就好了。
luoguP4245【模板】任意模数NTT

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

typedef long double ld;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 3e5;

const lo mo = (1 << 15) - 1;

const ld pi = acosl(-1.0); 

struct in
{
    ld r, i;
    in(ld rr = 0.0, ld ii = 0.0)
    {
        r = rr, i = ii;
    }
    inline in conj()
    {
        return (in){r, -i};
    }
}a[ms], p[ms], q[ms], e[ms], f[ms], ae[ms], af[ms];

inline in operator + (in a, in b)
{
    return (in){a.r + b.r, a.i + b.i};
}

inline in operator - (in a, in b)
{
    return (in){a.r - b.r, a.i - b.i};
}

inline in operator * (in a, in b)
{
    return (in){a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r};
}

inline in operator / (in a, in b)
{
    return (in){(a.r * b.r + a.i * b.i) / (pow(b.r, 2) + pow(b.i, 2)), (a.i * b.r - a.r * b.i) / (pow(b.r, 2) + pow(b.i, 2))};
}

lo n, m, pp, ax[ms], bx[ms], pos[ms];

inline void fft(in *a, lo len, ld tp)
{
    for(ri i = 1; i < len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    in la, lb;
    for(ri i = 1; i < len; i <<= 1)
    {
        for(ri j = 0; j < len; j += (i << 1))
        {
            for(ri k = 0; k < i; k ++)
            {
                in w(cos(pi / i * k), tp * sin(pi / i * k));
                la = a[j + k], lb = w * a[j + k + i];
                a[j + k] = la + lb, a[j + k + i] = la - lb;
            }
        }
    }
}

int main()
{
    //freopen("a.in", "r", stdin);
    //freopen("a.out", "w", stdout);
    re(n), re(m), re(pp);
    for(ri i = 0; i <= n; i ++) 
        re(ax[i]), ax[i] %= pp;
    for(ri i = 0; i <= m; i ++)
        re(bx[i]), bx[i] %= pp;
    lo len = 1, num = 0;
    while(len <= (n + m))
        len <<= 1, num ++;
    for(ri i = 0; i < len; i ++)
    {
        a[i] = (in){ax[i] & mo, ax[i] >> 15};//拆成复数 
        p[i] = (in){bx[i] & mo, bx[i] >> 15};
        pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    }
    fft(a, len, 1), fft(p, len, 1);
    for(ri i = 0; i < len; i ++)//通过p求q 
        q[i] = p[(len - i) & (len - 1)].conj();
    in g1(0, 2), g2(2, 0);
    for(ri i = 0; i < len; i ++)//e是博客里面整理出来的a,f是b 
        e[i] = (p[i] + q[i]) / g2, f[i] = (p[i] - q[i]) / g1;
    for(ri i = 0; i < len; i ++)//然后如5倍dft所说,将两个式子分别乘起来 
        ae[i] = a[i] * e[i], af[i] = a[i] * f[i];
    fft(ae, len, -1), fft(af, len, -1);
    for(ri i = 0; i <= n + m; i ++)//最后把系数带进去,注意除完要膜,不膜会炸(滑稽) 
        printf("%lld ", ((((lo)(ae[i].r / len + 0.5)) % pp 
                    + (((lo)(ae[i].i / len + 0.5) % pp) << 15) % pp
                    + (((lo)(af[i].r / len + 0.5) % pp) << 15) % pp
                    + (((lo)(af[i].i / len + 0.5) % pp) << 30) % pp) % pp + pp) % pp);
}

luoguP4239【模板】多项式求逆(加强版),开了o2过的,丢人

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int
#define pi acos(-1.0)

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

typedef unsigned long long ulo;

typedef long double ld;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 4e5 + 4e4, mo = 32767, inf = 1e9 + 7;

struct in
{
    double r, i;
    in(double rr = 0.0, double ii = 0.0)
    {
        r = rr, i = ii;
    }
    inline in conj()
    {
        return (in){r, -i};
    }
}ap[ms], p[ms], q[ms], f[ms], g[ms], apf[ms], apg[ms];

lo n, s[ms], e[ms], c[ms], d[ms], pos[ms];

inline in operator + (in a, in b)
{
    return (in){a.r + b.r, a.i + b.i};
}

inline in operator - (in a, in b)
{
    return (in){a.r - b.r, a.i - b.i};
}

inline in operator * (in a, in b)
{
    return (in){a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r};
}

inline in operator / (in a, in b)
{
    double la = powl(b.r, 2) + powl(b.i, 2);
    return (in){(a.r * b.r + a.i * b.i) / la, (a.i * b.r - a.r * b.i) / la};
}

inline lo ksm(lo x, lo k)
{
    lo rt = 1, a = x;
    while(k)
        rt = rt * ((k & 1) ? a : 1) % inf, a = a * a % inf, k >>= 1;
    return rt;
}

inline void fft(in *a, lo len, lo tp)
{
    for(ri i = 1; i < len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    in la, lb;
    for(ri i = 1; i < len; i <<= 1)
        for(ri j = 0; j < len; j += (i << 1))
            for(ri k = j; k < j + i; k ++)
            {
                in w(cos((k - j) * pi / i), tp * sin((k - j) * pi / i));
                la = a[k], lb = w * a[k + i];
                a[k] = la + lb, a[k + i] = la - lb;
            }
    if(tp == -1)
        for(ri i = 0; i < len; i ++)
            a[i].r /= 1.0 * len, a[i].i /= 1.0 * len;
}

inline void mtt(lo *a, lo *b, lo *c, lo len)
{
    lo x = 1, num = 0;
    while(x <= len)
        x <<= 1, num ++;
    for(ri i = 0; i < len; i ++)
    {
        a[i] %= inf, b[i] %= inf;
        ap[i] = (in){a[i] & mo, a[i] >> 15};
        p[i] = (in){b[i] & mo, b[i] >> 15};
    }
    for(ri i = len; i < x; i ++)
        ap[i] = p[i] = (in){0, 0};
    for(ri i = 0; i < x; i ++)
        pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    fft(ap, x, 1), fft(p, x, 1);
    for(ri i = 0; i < x; i ++)
        q[i] = p[(x - i) & (x - 1)].conj();
    in g1(0, 2), g2(2, 0);
    for(ri i = 0; i < x; i ++)
        f[i] = (p[i] + q[i]) / g2, g[i] = (p[i] - q[i]) / g1;
    for(ri i = 0; i < x; i ++)
        apf[i] = ap[i] * f[i], apg[i] = ap[i] * g[i];
    fft(apf, x, -1), fft(apg, x, -1);
    for(ri i = 0; i < len; i ++)
        c[i] = ((((lo)(apf[i].r + 0.5)) % inf 
                    + (((lo)(apf[i].i + 0.5) % inf) << 15) % inf
                    + (((lo)(apg[i].r + 0.5) % inf) << 15) % inf
                    + (((((lo)(apg[i].i + 0.5) % inf) << 15) % inf) << 15) % inf) % inf + inf) % inf;
}

void merge(lo *a, lo *b, lo len)
{
    if(len == 1)
    {
        b[0] = ksm(a[0], inf - 2); return;
    }
    merge(a, b, (len + 1) >> 1); lo x = 1, num = 0;
    while(x < len)
        x <<= 1, num ++;
    mtt(a, b, c, len), mtt(c, b, d, len);
    for(ri i = 0; i < len; i ++)
        b[i] = (b[i] + b[i]) % inf;
    for(ri i = 0; i < len; i ++)
        b[i] = ((b[i] - d[i]) % inf + inf) % inf;
}

int main()
{
    re(n);
    for(ri i = 0; i < n; i ++)
        re(s[i]);
    lo x = 1;
    while(x < n)
        x <<= 1;
    merge(s, e, x);
    for(ri i = 0; i < n; i ++)
        printf("%lld ", e[i]);
}

多项式对数函数

这个东西原理还是挺简单的,推荐没有学过求导和积分的童鞋去看看高中数学选修2-2的简单定义。
设函数F(x)=lnx,现在有多项式G(x),我们想求H(x) = F(G(x))
根据复合函数求导公式可得,H'(x) = F'(G(x))G'(x),根据高中数学的求导公式我们可以进一步得到H'(x) = \frac{G'(x)}{G(x)},而对于G(x),它的导数也很好求,根据A(x) = x^{a},A'(x) = ax^{a-1}这个式子,我们可以O(n)的处理出G'(x),而分母上我们又可以通过之前的多项式求逆算出来。算出H'(x),因为求导和积分互为逆运算,所以我们根据\int x^{a}dx = \frac{1}{a+1}x^{a+1}可以再积回去,这样我们就求得H(x)了。
luoguP4725【模板】多项式对数函数

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

typedef long double ld;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 3e5;

const lo mo = 998244353;

lo n, pos[ms], f[ms], g[ms], ff[ms], fd[ms], gg[ms], po[ms], inv[ms], c[ms];

inline lo ksm(lo x, lo k)
{
    lo a = x, rt = 1;
    while(k)
        rt = rt * ((k & 1) ? a : 1) % mo, a = a * a % mo, k >>= 1;
    return rt;
}

inline void ntt(lo *a, lo len, lo tp)
{
    for(ri i = 1; i < len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    lo la, lb, invlen = ksm(len, mo - 2);
    for(ri i = 1; i < len; i <<= 1)
    {
        lo wn = (tp == 1) ? po[i] : inv[i];
        for(ri j = 0; j < len; j += (i << 1))
        {
            lo w = 1;
            for(ri k = j; k < j + i; k ++)
            {
                la = a[k], lb = w * a[k + i] % mo;
                w = w * wn % mo, a[k] = (la + lb) % mo, a[k + i] = ((la - lb) % mo + mo) % mo;
            }
        }
    }
    if(tp == -1)
        for(ri i = 0; i < len; i ++)
            a[i] = a[i] * invlen % mo;
}

void merge(lo *a, lo *b, lo len)
{
    if(len == 1)
    {
        a[0] = ksm(b[0], mo - 2); return;
    }
    merge(a, b, (len + 1) >> 1); lo x = 1, num = 0;
    while(x < (len << 1))
        x <<= 1, num ++;
    for(ri i = 0; i < x; i ++)
        c[i] = (i < len) ? b[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(a, x, 1), ntt(c, x, 1);
    for(ri i = 0; i < x; i ++)
        a[i] = ((2ll - a[i] * c[i] % mo) % mo + mo) % mo * a[i] % mo;
    ntt(a, x, -1);
    for(ri i = len; i < x; i ++)
        a[i] = 0;
}

inline void askln(lo len)
{
    for(ri i = 1; i < len; i ++)//计算导数 
        fd[i - 1] = f[i] * i % mo;
    fd[len - 1] = 0;
    merge(ff, f, len);
    lo x = 1, num = 0;
    while(x <= len)
        x <<= 1, num ++;
    for(ri i = 1; i < x; i ++)
        pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(fd, x, 1), ntt(ff, x, 1);
    for(ri i = 0; i < x; i ++)
        gg[i] = 1ll * fd[i] * ff[i] % mo;
    ntt(gg, x, -1);
    for(ri i = 1; i < x; i ++)//积分回来,注意乘i的逆元 
        g[i] = gg[i - 1] * ksm(i, mo - 2) % mo;
    g[0] = 0;
}

int main()
{
    re(n); lo inv3 = ksm(3, mo - 2), len = 1, num = 0;
    for(ri i = 0; i < n; i ++)
        re(f[i]);
    while(len <= n)
        len <<= 1, num ++;
    for(ri i = 1; i <= n + n; i <<= 1)
        po[i] = ksm(3, (mo - 1) / (i << 1));
    for(ri i = 1; i <= n + n; i <<= 1)
        inv[i] = ksm(inv3, (mo - 1) / (i << 1));
    askln(len);
    for(ri i = 0; i < n; i ++)
        printf("%lld ", g[i]);
}

多项式指数函数

现在给你一个多项式A(x),求在mod \ x^{n}意义下的一个多项式B(x),使得B(x) \equiv e^{A(x)} (mod \ x^{n})
这个简单来说就是各个板子的小集合了,你需要求逆求ln……往上面套。
先介绍个中西,泰勒展开,这里照搬下百度百科的原话。
泰勒公式是将一个在x=x_0处具有n阶导数的函数f(x)利用关于(x-x_0)的n次多项式来逼近函数的方法。
若函数f(x)在包含x0的某个闭区间[a,b]上具有n阶导数,且在开区间(a,b)上具有(n+1)阶导数,则对闭区间[a,b]上任意一点x,成立下式:
f(x) = \frac{f(x_0)}{0!}+\frac{f'(x_0)(x-x_0)}{1!}+\frac{f”(x_0)(x-x_0)^2}{2!}+……+\frac{f^{(n)}(x_0)(x-x_0)^n}{2!}+R_nx
这里f^{(n)}(x_0)f(x_0)的n阶导数,R_nx是一个余项。但是今天我们用不到余项,所以先不管他。
这个公式……反正我不会证明,所以只能当一个结论记住。那么这个公式的用处呢?你可以看到。这个公式将原来一个函数化成了导数。
我们先把问题进行转化,现在不是求B(x) = e^{A(x)}吗,那么也就是说
ln(B(x)) – A(x) \equiv 0 (mod \ x^{n})
问题转化成了求函数零点。而求函数零点,有一个大杀器就叫做牛顿迭代,至于多项式怎么迭代呢……
设存在一个函数
C(D(x)) \equiv 0 (mod \ x^n)
假设我们已经求出了
C(D(x)) \equiv 0 (mod x^{\frac{n}{2}})
那么我们利用刚才的泰勒展开,将C(D(x)) \equiv 0 (mod \ x^n)这个式子划开
C(D(x)) \equiv r_nx+\sum_{i=0}^{n}\frac{C^{(i)}(D_0(x))(D(x)-D_0(x))^i}{i!}(mod \ x^n)
其中C^{(i)}(D_0(x))表示的是C(D_0(x))的i阶导数。
因为D(x)-D_0(x)这个东西他们的前\lceil\frac{n}{2}\rceil项是相同,所以平方之后起步就是2\lceil\frac{n}{2}\rceil,在mod \ x^n下会被消掉,也就说这个式子只剩下了前两项,即
C(D(x)) \equiv C(D_0(x)) + C'(D_0(x))(D(x) – D_0(x))(mod \ x^n)
又因为C(D(x)) \equiv 0(mod \ x^n)
所以D(x) \equiv D_0(x) – \frac{C(D(x)}{C'(D(x))} (mod \ x^n)
那么我们回过头来把我们的函数带进去,设
C(B(x)) = ln(B(x)) – A(x)
A(x)这里看做是常数,那么带入公式可得
B(x) \equiv B_0(x)- \frac{C(B_0(x)}{C'(B_0(x))} (mod \ x^n) = B_0(1 – ln(B_0(x)) + A(x))
边界是A(0)=0,类似于多项式求逆,不停递归就好了。
luoguP4726【模板】多项式指数函数

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

typedef long double ld;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo mo = 998244353;

inline lo ksm(lo x, lo k)
{
    lo rt = 1, a = x;
    while(k)
        rt = rt * ((k & 1) ? a : 1) % mo, a = a * a % mo, k >>= 1;
    return rt;
}

const lo ms = 3e5;

lo n, pos[ms], c[ms], ad[ms], aa[ms], bb[ms], f[ms], g[ms], lnb[ms], po[ms], inv[ms];

inline void ntt(lo *a, lo len, lo tp)
{
    for(ri i = 1; i < len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    lo la, lb, invlen = ksm(len, mo - 2);
    for(ri i = 1; i < len; i <<= 1)
    {
        lo wn = (tp == 1) ? po[i] : inv[i];
        for(ri j = 0; j < len; j += (i << 1))
        {
            lo w = 1;
            for(ri k = j; k < j + i; k ++)
            {
                la = a[k], lb = w * a[k + i] % mo;
                w = w * wn % mo, a[k] = (la + lb) % mo, a[k + i] = ((la - lb) % mo + mo) % mo;
            }
        }
    }
    if(tp == -1)
        for(ri i = 0; i < len; i ++)
            a[i] = a[i] * invlen % mo;
}

void merge(lo *a, lo *b, lo len)
{
    if(len == 1)
    {
        a[0] = ksm(b[0], mo - 2); return;
    }
    merge(a, b, (len + 1) >> 1); lo x = 1, num = 0;
    while(x < (len << 1))
        x <<= 1, num ++;
    for(ri i = 0; i < x; i ++)
        c[i] = (i < len) ? b[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(a, x, 1), ntt(c, x, 1);
    for(ri i = 0; i < x; i ++)
        a[i] = ((2ll - a[i] * c[i] % mo) % mo + mo) % mo * a[i] % mo;
    ntt(a, x, -1);
    for(ri i = len; i < x; i ++)
        a[i] = 0;
    /*if(len == 1)
    {
        a[0] = ksm(b[0], mo - 2); return;
    }
    merge(a, b, (len + 1) >> 1); lo x = 1, num = 0;
    while(x < (len << 1))
        x <<= 1, num ++;
    for(ri i = 0; i < x; i ++)
        c[i] = (i < len) ? b[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(c, x, 1), ntt(a, x, 1);
    for(ri i = 0; i < x; i ++)
        a[i] = ((2ll - a[i] * c[i] % mo) % mo + mo) % mo * a[i] % mo;
    ntt(a, x, -1);
    for(ri i = len; i < x; i ++)
        a[i] = 0;*/
}

inline void ln(lo *a, lo *b, lo len)
{
    for(ri i = 1; i < len; i ++)
        ad[i - 1] = a[i] * i % mo;
    ad[len - 1] = 0;
    merge(aa, a, len);
    //for(ri i = 0; i < len; i ++)
    //  cout << aa[i] << ' ';
    //cout << '\n';
    lo x = 1, num = 0;
    while(x <= len)
        x <<= 1, num ++;
    for(ri i = 1; i < x; i ++)
        pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(aa, x, 1), ntt(ad, x, 1);
    //for(ri i = 0; i < x; i ++)
    //  cout << aa[i] << ' ';
    //cout << '\n';
    for(ri i = 0; i < x; i ++)
        bb[i] = 1ll * aa[i] * ad[i] % mo;
 //for(ri i = 0; i < x; i ++)
    //  cout << bb[i] << ' ';
    //cout << '\n';
    ntt(bb, x, -1);
    for(ri i = 1; i < x; i ++)
        b[i] = bb[i - 1] * ksm(i, mo - 2) % mo;
    b[0] = 0;
    //for(ri i = 0; i < x; i ++)
    //  cout << b[i] << ' ';
    //cout << '\n';
    for(ri i = 0; i < x; i ++)
        aa[i] = ad[i] = bb[i] = 0;
}

void exp(lo *a, lo *b, lo len)
{
    if(len == 1)
    {
        b[0] = 1; return;
    } 
    lo x = 1, num = 0;
    while(x < (len << 1))
        x <<= 1, num ++;
    exp(a, b, (len + 1) >> 1);
    ln(b, lnb, len);
    //for(ri i = 0; i < (len << 1); i ++)
    //  cout << lnb[i] << ' ';
    //cout << '\n';
    lnb[0] = ((a[0] + 1 - lnb[0]) % mo + mo) % mo;
    for(ri i = 1; i < len; i ++)
        lnb[i] = ((a[i] - lnb[i]) % mo + mo) % mo;
    ntt(lnb, x, 1), ntt(b, x, 1);
    for(ri i = 0; i < x; i ++)
        b[i] = 1ll * b[i] * lnb[i] % mo;
    ntt(b, x, -1);
    for(ri i = len; i < x; i ++)
        b[i] = lnb[i] = 0;
}

int main()
{
    re(n);
    for(ri i = 0; i < n; i ++)
        re(f[i]);
    lo inv3 = ksm(3, mo - 2), len = 1;
    while(len <= n)
        len <<= 1;
    for(ri i = 1; i <= n + n; i <<= 1)
        po[i] = ksm(3, (mo - 1) / (i << 1));
    for(ri i = 1; i <= n + n; i <<= 1)
        inv[i] = ksm(inv3, (mo - 1) / (i << 1));
    exp(f, g, len);
    for(ri i = 0; i < n; i ++)
        printf("%lld ", g[i]);
}

多项式开方

其实原理和刚才差不多,只不过从指数函数换成了二次函数而已。设F(x)^{2} = G(x),移项可得F(x)^{2} – G(x) = 0
继续带入刚才的结论,设H(F(x)) = F(x)^{2} – G(x),那么
F(x) \equiv F_0(x) – \frac{H(F_0(x))}{H'(F_0(x))} \equiv \frac{F_0(x)^2 + G(x)}{2F_0(x)} (mod \ x^n),然后还是递归算这个东西就好了,不过如果G(x)常数项不为1,可能是要计算二次剩余……二次剩余咕咕咕,这里是多项式qwq。
luogu并没有多项式开根的例题,所以只能放个比较接近板子的计数题:CF438E The Child and Binary Tree
这个题因为生成函数没学好,写错别打我……
首先我们设一个点的方案生成函数是H(x) = \sum_{i \in c}x^ic就是题目中的c,那么在生成树上的一个点方案F(x) = 1+H(x)F(x)^21是这里是空树,左右儿子方案数在x能到无穷的时候是能等价于F(x)的,那么解方程
F(x) = \frac{1\pm \sqrt{1-4H(x)}}{2H(x)}
因为F(0) = 1,所以上面为正号,为了方便,我们再化一步
F(x) = \frac{2}{1-\sqrt{1-4H(x)}},这样我们只需要开一次求一次逆,比刚才要简单点。

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<utility>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#include<set>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

typedef long double ld;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo mo = 998244353, inv2 = 499122177;

inline lo ksm(lo x, lo k)
{
    lo rt = 1, a = x;
    while(k)
        rt = rt * ((k & 1) ? a : 1) % mo, a = a * a % mo, k >>= 1;
    return rt;
}

const lo ms = 3e5;

lo n, m, lx, pos[ms], po[ms], inv[ms], c[ms], d[ms], bd[ms], h[ms], hh[ms];

inline void ntt(lo *a, lo len, lo tp)
{
    for(ri i = 1; i < len; i ++)
        if(i < pos[i])
            swap(a[i], a[pos[i]]);
    lo la, lb, invlen = ksm(len, mo - 2);
    for(ri i = 1; i < len; i <<= 1)
    {
        lo wn = (tp == 1) ? po[i] : inv[i];
        for(ri j = 0; j < len; j += (i << 1))
        {
            lo w = 1;
            for(ri k = j; k < j + i; k ++)
            {
                la = a[k], lb = w * a[k + i] % mo;
                w = w * wn % mo, a[k] = (la + lb) % mo, a[k + i] = ((la - lb) % mo + mo) % mo;
            }
        }
    }
    if(tp == -1)
        for(ri i = 0; i < len; i ++)
            a[i] = a[i] * invlen % mo;
}

void merge(lo *a, lo *b, lo len)
{
    if(len == 1)
    {
        b[0] = ksm(a[0], mo - 2); return;
    }
    merge(a, b, (len + 1) >> 1); lo x = 1, num = 0;
    while(x <= len)
        x <<= 1, num ++;
    for(ri i = 0; i < x; i ++)
        c[i] = (i < len) ? a[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(c, x, 1), ntt(b, x, 1);
    for(ri i = 0; i < x; i ++)
        b[i] = ((2ll - b[i] * c[i] % mo) % mo + mo) % mo * b[i] % mo;
    ntt(b, x, -1);
    for(ri i = len; i < x; i ++)
        b[i] = 0;
}

void sqr(lo *a, lo *b, lo len)
{
    if(len == 1)
    {
        b[0] = a[0]; return;
    }
    sqr(a, b, (len + 1) >> 1); lo x = 1, num = 0;
    while(x <= len)
        x <<= 1, num ++;
    merge(b, bd, len);
    for(ri i = 0; i < x; i ++)
        c[i] = (i < len) ? a[i] : 0, pos[i] = ((pos[i >> 1] >> 1) | ((i & 1) << (num - 1)));
    ntt(c, x, 1), ntt(bd, x, 1);
    for(ri i = 0; i < x; i ++)
        bd[i] = c[i] * bd[i] % mo;
    ntt(bd, x, -1);
    for(ri i = 0; i < len; i ++)
        b[i] = ((b[i] + bd[i]) % mo) * inv2 % mo;
    for(ri i = 0; i < x; i ++)
        c[i] = bd[i] = 0;
}

int main()
{
    re(n), re(m);
    for(ri i = 1; i <= n; i ++)
        re(lx), h[lx] ++;
    h[0] = 1;
    lo len = 1, num = 0, inv3 = ksm(3, mo - 2);
    while(len <= m)
        len <<= 1, num ++;
    for(ri i = 1; i <= len; i <<= 1)
        po[i] = ksm(3, (mo - 1) / (i << 1)), inv[i] = ksm(inv3, (mo - 1) / (i << 1));
    for(ri i = 1; i < len; i ++)
        h[i] = (mo - (h[i] << 2)) % mo;
    sqr(h, hh, len);
    for(ri i = 0; i < len; i ++)
        h[i] = 0;
    hh[0] = (hh[0] + 1) % mo;
    merge(hh, h, len);
    for(ri i = 0; i <= m; i ++)
        h[i] = (h[i] << 1) % mo;
    for(ri i = 1; i <= m; i ++)
        printf("%lld\n", h[i]);
}

【算法】动态dp&KD-tree

动态dp推荐题解 P4643 【【模板】动态dp】。这个博客除了他矩阵写的不好看的意外其他都海星。
我们写成矩阵以后答案就是整个区间所有矩阵乘起来,所以ldp和dp之间界限并不明显,可以转化。注意全局平衡二叉树还是老老实实的左儿子表示左边。矩阵其实这么写比较好看
\begin{gathered} \begin{bmatrix} +\infty & ldp_{i, 0} \\ ldp_{1} & ldp_{i, 1} \end{bmatrix} \begin{bmatrix} +\infty & dp_{i, 0} \\ dp_{i}{1} & dp_{i, 1} \end{bmatrix} \end{gathered}
剩下的按照上面那个博客来就行了,讲的很清楚。
贴代码。
luogu P4719 动态dp(这个我自己矩阵都写得看不懂)

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const int ms = 2e5 + 20;

int n, m, rot, lx, ly, tot = -1, sta[ms], lsz[ms], fa[ms], son[ms], v[ms], head[ms], sz[ms], ch[ms][2];

struct in
{
    int to, ne;
}ter[ms << 1];

inline void build(int f, int l)//建边
{
    ter[++ tot] = (in){l, head[f]}, head[f] = tot;
    ter[++ tot] = (in){f, head[l]}, head[l] = tot;
}

void init(int no, int f)
{
    sz[no] = 1; int p = 0;
    for(ri i = head[no]; i >= 0; i = ter[i].ne)
    {
        int to = ter[i].to;
        if(to == f)
            continue;
        init(to, no), sz[no] += sz[to], son[no] = (sz[son[no]] > sz[to]) ? son[no] : to;
    }
}

struct mar
{
    int a[2][2];
    mar()
    {
        memset(a, -63, sizeof(a));
    }
    mar(int x)
    {
        a[0][0] = a[1][1] = 0, a[1][0] = a[0][1] = -1e9 - 7;
    }
    inline int mx()
    {
        return max(max(a[0][0], a[0][1]), max(a[1][0], a[1][1]));//这里其实只需要考虑00和01
    }
    inline void print()
    {
        for(ri i = 0; i < 2; i ++, printf("\n"))
            for(ri j = 0; j < 2; j ++)
                printf("%d ", a[i][j]);
    }
}w[ms], ans[ms];

inline mar operator * (mar a, mar b)
{
    mar c; 
    for(ri k = 0; k < 2; k ++)
        for(ri i = 0; i < 2; i ++)
            for(ri j = 0; j < 2; j ++)
                c.a[i][j] = max(c.a[i][j], a.a[i][k] + b.a[k][j]);
    return c;
}

bool flag[ms];

inline void init1(int no, int to)
{
    w[no].a[1][0] += ans[to].mx(), w[no].a[0][0] = w[no].a[1][0];//关于所有不选的情况的转移
    w[no].a[0][1] += max(ans[to].a[0][0], ans[to].a[1][0]), fa[to] = no;//选的话儿子强制不选max没有意义
}

inline void up(int x)
{
    ans[x] = ans[ch[x][0]] * w[x] * ans[ch[x][1]];//左儿子在原来的树上是靠下的,右儿子靠上
    printf("%d\n", x);
    printf("ans:\n"); ans[x].print();
    printf("w:\n"); w[x].print();
}

inline bool dir(int x)
{
    return ch[fa[x]][0] != x && ch[fa[x]][1] != x;
}

int bstbuild(int l, int r)//建一个静态的树出来
{
    if(l > r)
        return 0;
    int lin = 0;
    for(ri i = l; i <= r; i ++)
        lin += lsz[sta[i]];
    for(ri i = l, j = lsz[sta[l]]; i <= r; i ++, j += lsz[sta[i]])
        if((j << 1) >= lin)//每次大小少一半,所以logn
        {
            ch[sta[i]][1] = bstbuild(l, i - 1), ch[sta[i]][0] = bstbuild(i + 1, r);
            fa[ch[sta[i]][1]] = fa[ch[sta[i]][0]] = sta[i], up(sta[i]); return sta[i];
        }
}

int rebuild(int no)//先剖链,剖成一个个的链然后再去转移
{
    for(ri i = no; i; i = son[i])
        flag[i] = 1;
    for(ri i = no; i; i = son[i])
        for(ri j = head[i]; j >= 0; j = ter[j].ne)
            if(flag[ter[j].to] == 0)
                init1(i, rebuild(ter[j].to));//得到那个临时数组
    int ta = 0;
    for(ri i = no; i; i = son[i])
        sta[++ ta] = i;
    for(ri i = no; i; i = son[i])//处理出这个子链,按重心
        lsz[i] = sz[i] - sz[son[i]];
    return bstbuild(1, ta);
}

inline void change(int no, int val)
{
    w[no].a[0][1] += val - v[no], v[no] = val;
    for(ri i = no; i; i = fa[i])
        if(dir(i) && fa[i])
        {
            w[fa[i]].a[0][0] -= ans[i].mx(), w[fa[i]].a[1][0] = w[fa[i]].a[0][0];
            w[fa[i]].a[0][1] -= max(ans[i].a[0][0], ans[i].a[1][0]), up(i);
            w[fa[i]].a[0][0] += ans[i].mx(), w[fa[i]].a[1][0] = w[fa[i]].a[0][0];
            w[fa[i]].a[0][1] += max(ans[i].a[0][0], ans[i].a[1][0]);
        }
        else
            up(i);
}

//ldp0 = simga lightson max(dp0, dp1)
//ldp1 = simga lightson dp0
/*
对于ans数组,00是dp0,01是dp1(剩下那俩不重要)
对于w数组,00 10是ldp0,01是ldp1,11负无穷
*/

int main()
{
    re(n), re(m); memset(head, -1, sizeof(head));
    w[0] = ans[0] = mar(1);
    for(ri i = 1; i <= n; i ++)
        re(v[i]);
    for(ri i = 1; i < n; i ++)
        re(lx), re(ly), build(lx, ly);
    init(1, 0);
    for(ri i = 1; i <= n; i ++)
        w[i].a[0][1] = v[i], w[i].a[0][0] = w[i].a[1][0] = 0;
    rot = rebuild(1);
    while(m --)
        re(lx), re(ly), change(lx, ly), printf("%d\n", ans[rot].mx());
    system("pause");
}

luogu P4751 动态dp【加强版】
这个题目就是刚才强制在线

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const int ms = 2e6 + 20;

int n, m, rot, lx, ly, tot = -1, sta[ms], lsz[ms], fa[ms], son[ms], v[ms], head[ms], sz[ms], ch[ms][2];

struct in
{
    int to, ne;
}ter[ms << 1];

inline void build(int f, int l)//建边
{
    ter[++ tot] = (in){l, head[f]}, head[f] = tot;
    ter[++ tot] = (in){f, head[l]}, head[l] = tot;
}

void init(int no, int f)
{
    sz[no] = 1; int p = 0;
    for(ri i = head[no]; i >= 0; i = ter[i].ne)
    {
        int to = ter[i].to;
        if(to == f)
            continue;
        init(to, no), sz[no] += sz[to], son[no] = (sz[son[no]] > sz[to]) ? son[no] : to;
    }
}

struct mar
{
    int a[2][2];
    mar()
    {
        memset(a, -63, sizeof(a));
    }
    mar(int x)
    {
        a[0][0] = a[1][1] = 0, a[1][0] = a[0][1] = -1e9 - 7;
    }
    inline int mx()
    {
        return max(max(a[0][0], a[0][1]), max(a[1][0], a[1][1]));
    }
}w[ms], ans[ms];

inline mar operator * (mar a, mar b)
{
    mar c; 
    for(ri k = 0; k < 2; k ++)
        for(ri i = 0; i < 2; i ++)
            for(ri j = 0; j < 2; j ++)
                c.a[i][j] = max(c.a[i][j], a.a[i][k] + b.a[k][j]);
    return c;
}

bool flag[ms];

inline void init1(int no, int to)
{
    w[no].a[1][0] += ans[to].mx(), w[no].a[0][0] = w[no].a[1][0];//关于所有不选的情况的转移
    w[no].a[0][1] += max(ans[to].a[0][0], ans[to].a[1][0]), fa[to] = no;
}

inline void up(int x)
{
    ans[x] = ans[ch[x][0]] * w[x] * ans[ch[x][1]];
}

inline bool dir(int x)
{
    return ch[fa[x]][0] != x && ch[fa[x]][1] != x;
}

int bstbuild(int l, int r)
{
    if(l > r)
        return 0;
    int lin = 0;
    for(ri i = l; i <= r; i ++)
        lin += lsz[sta[i]];
    for(ri i = l, j = lsz[sta[l]]; i <= r; i ++, j += lsz[sta[i]])
        if((j << 1) >= lin)
        {
            ch[sta[i]][1] = bstbuild(l, i - 1), ch[sta[i]][0] = bstbuild(i + 1, r);
            fa[ch[sta[i]][1]] = fa[ch[sta[i]][0]] = sta[i], up(sta[i]); return sta[i];
        }
}

int rebuild(int no)//先剖链,剖成一个个的链然后再去转移
{
    for(ri i = no; i; i = son[i])
        flag[i] = 1;
    for(ri i = no; i; i = son[i])
        for(ri j = head[i]; j >= 0; j = ter[j].ne)
            if(flag[ter[j].to] == 0)
                init1(i, rebuild(ter[j].to));//得到那个临时数组
    int ta = 0;
    for(ri i = no; i; i = son[i])
        sta[++ ta] = i;
    for(ri i = no; i; i = son[i])
        lsz[i] = sz[i] - sz[son[i]];
    return bstbuild(1, ta);
}

inline void change(int no, int val)
{
    w[no].a[0][1] += val - v[no], v[no] = val;
    for(ri i = no; i; i = fa[i])
        if(dir(i) && fa[i])
        {
            w[fa[i]].a[0][0] -= ans[i].mx(), w[fa[i]].a[1][0] = w[fa[i]].a[0][0];
            w[fa[i]].a[0][1] -= max(ans[i].a[0][0], ans[i].a[1][0]), up(i);
            w[fa[i]].a[0][0] += ans[i].mx(), w[fa[i]].a[1][0] = w[fa[i]].a[0][0];
            w[fa[i]].a[0][1] += max(ans[i].a[0][0], ans[i].a[1][0]);
        }
        else
            up(i);
}

int main()
{
    re(n), re(m); memset(head, -1, sizeof(head));
    w[0] = ans[0] = mar(1);
    for(ri i = 1; i <= n; i ++)
        re(v[i]);
    for(ri i = 1; i < n; i ++)
        re(lx), re(ly), build(lx, ly);
    init(1, 0);
    for(ri i = 1; i <= n; i ++)
        w[i].a[0][1] = v[i], w[i].a[0][0] = w[i].a[1][0] = 0;
    rot = rebuild(1); int lans = 0;
    for(ri i = 1; i <= m; i ++)
        re(lx), re(ly), lx = (i > 1) ? (lx ^ lans) : lx, change(lx, ly), lans = ans[rot].mx(), printf("%d\n", ans[rot].mx());
    system("pause");
}

之后抽空写noip2018题解的时候再放保卫王国
kdtree先记个复杂度,\(O(n^{\frac{k+k-1}{k}})\)
其实构造挺好理解的,每次取中间,k维轮着来就是了
luogu P4631 [APIO2018] Circle selection 选圆圈

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h1 = buf, *h2 = buf;
    return h1 == h2 && (h2 = (h1 = buf) + fread(buf, 1, 100000, stdin), h1 == h2) ? EOF : *h1 ++;
}

typedef long long lo;

typedef long double ld;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 3e5 + 10;

const ld ex = acos(-1.0) / 5;

lo n, no, lx, ly, lz, opt, rot, ch[ms][2], ans[ms];

ld lp[ms][2], rp[ms][2];

struct node
{
    ld p[2], r; lo pos;
    inline bool operator < (const node &a) const
    {
        return p[opt] < a.p[opt];
    }
}poi[ms], ter[ms];

inline void up(lo w)
{
    for(ri i = 0; i < 2; i ++)
    {
        lp[w][i] = ter[w].p[i] - ter[w].r, rp[w][i] = ter[w].p[i] + ter[w].r;
        if(ch[w][0] > 0)
        {
            lp[w][i] = min(lp[w][i], lp[ch[w][0]][i]);
            rp[w][i] = max(rp[w][i], rp[ch[w][0]][i]);
        }
        if(ch[w][1] > 0)
        {
            lp[w][i] = min(lp[w][i], lp[ch[w][1]][i]);
            rp[w][i] = max(rp[w][i], rp[ch[w][1]][i]);
        }
    }
}

lo build(lo l, lo r, lo f)
{
    if(l > r)
        return 0;
    lo mid = (l + r) >> 1; opt = f;
    nth_element(poi + l, poi + mid, poi + r + 1);//从中劈开 
    ter[mid] = poi[mid], ch[mid][0] = build(l, mid - 1, f ^ 1), ch[mid][1] = build(mid + 1, r, f ^ 1);//每次k维度轮着来 
    up(mid); return mid;
}

inline bool cmp(node a, node b)
{
    if(fabs(a.r - b.r) < 1e-6)
        return a.pos < b.pos;
    return a.r > b.r;
}

inline bool check(lo w, ld x, ld y, ld r)
{
    ld xx = 0; 
    if(x < lp[w][0] || x > rp[w][0])
        xx += min(pow(lp[w][0] - x, 2), pow(rp[w][0] - x, 2));
    if(y < lp[w][1] || y > rp[w][1])
        xx += min(pow(lp[w][1] - y, 2), pow(rp[w][1] - y, 2));
    return r * r - xx < 1e-6;
}

void change(lo w, ld x, ld y, ld r)//暴力遍历 
{
    if(w == 0 || check(w, x, y, r))//越界就不遍历了 
        return;
    if(ans[ter[w].pos] <= 0 && pow(ter[w].p[0] - x, 2) + pow(ter[w].p[1] - y, 2)  - pow(ter[w].r + r, 2) < 1e-6)
        ans[ter[w].pos] = lx;
    change(ch[w][0], x, y, r), change(ch[w][1], x, y, r);
}

int main()
{
    re(n);
    for(ri i = 1; i <= n; i ++)
    {
        re(lx), re(ly), re(lz);
        poi[i].p[0] = lx * cos(ex) - ly * sin(ex);
        poi[i].p[1] = lx * sin(ex) + ly * cos(ex);
        poi[i].r = lz, poi[i].pos = i;
    }
    rot = build(1, n, 0); sort(poi + 1, poi + 1 + n, cmp);
    for(ri i = 1; i <= n; i ++)
        if(!ans[poi[i].pos])
            ans[poi[i].pos] = lx = poi[i].pos, change(rot, poi[i].p[0], poi[i].p[1], poi[i].r);
    for(ri i = 1; i <= n; i ++)
        printf("%lld ", ans[i]);
}

【算法】斜率优化

这个东西其实挺有意思的,之前学了好几次,这次专门来写写
斜率优化顾名思义,跟斜率有关系。我理解的斜率优化就是,将你的dp式子转化成一个函数形式。这个函数应该多半是一次函数。然后你的目的就是去根据题目去按照斜率寻找转移状态。
接下来放例题
luoguP2365任务安排
这个题目我们首先要用一个小技巧。我们现在这个dp很蛋疼,因为我不知道转移到i这个点的时候我划分成了几块。所以我们可能会考虑\(dp[i][j]\)这么一个状态来表示目前考虑到第i个人j块的最小代价。这样我们转移就成\(n^{3}\)了。
但是事实上这并不需要。我们不考虑之前有多少块了,而是这么想,\(dp[i]\)表示目前考虑完前i个,并且i是最后一块结尾的最小代价。转移的时候我们这么转移,考虑新生成的一块对后面所有点产生的影响,比如我们从\(j\)这个点开始划分块,一直划分到\(i\),那么你考虑我们目前对于j后面那些点都增加了s时间。
所以这个地方我就只需要一个dp状态\(dp[i]\)就可以了。这样显然是没有后效性的。我们复杂度也随之降到了\(n^{2}\)。对于luogu这个题目来说这是一个足够的复杂度了。但是在bzoj还有一个加强了两轮的版本,其中加强的一部分是要求复杂度在\(nlogn\)及其以下的复杂度。所以我们首先需要从这里优化复杂度。
首先我们对时间和花费求个前缀和,这样便于我们接下来计算
我们考虑现在这个dp方程
\(dp[i] = min(dp[j] + s(c[n] – c[j]) + t[i](c[i] – c[j]))\)
s和题意一致,c[i]表示花费的前缀和,t[i]表示所消耗是时间的前缀和,那么对于每一个j来说,我们都可以移项变形,那么方程变成了
\(dp[j] = dp[i] – s(c[n] – c[j]) – t[i](c[i] – c[j]) = (t[i] + s)c[j] + dp[i] – t[i] * c[i] – s * c[n]\)
我们观察这个最后的式子,是不是有点类似于\(y = kx + b\)
也就是说,在这个式子里面,\(dp[j]\)是\(y\),\((t[i] + s)\)是斜率
\(dp[i] – t[i] * c[i] – s * c[n]\)这是一个常数
那么我要让\(dp[i]\)尽可能小,也就说这个常数要尽可能的小。
那么你手画下就会发现这么一个事情,对于我新生成的这条直线,也就是\(dp[j]\)的表达式,我把这条直线向上不停的移动,我最早碰到的那个点,他就是\(b\)最小的一个


也就是这两个图画的这样。我们继续手玩会进一步得出结论,我们按照\(x\)排序,这个时候你会发现,设我们这个直线的斜率为\(k_{0}\)那么也就是\(k_{1} < k_{0} < k_{2}\)。稍微思考下你就会发现这是个单调递增的函数
大概就是长这样



这样我们就可以发现,我只需要维护一个下凸壳就可以了。因为这个地方\(x\)具有单调性,所以我们可以直接队列维护,考虑到\(c[i]\)也是如此,我们可以发现在这个题目里面甚至可以直接一个单调队列就行了,每次转移队首的。因为你看刚才那个不等式,那些\(k_{1}\)现在如果用不到,以后也不会,所以可以直接不考虑。
然后这个题目就做完了。放一下dalao的博客,图也是从dalao那里借来的,已经说过了qwq斜率优化学习笔记

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

int n, s, dp[5050], ta, he, t[5050], c[5050], q[5050];

int main()
{
    re(n), re(s);
    for(ri i = 1; i <= n; i ++)
        re(t[i]), re(c[i]), t[i] += t[i - 1], c[i] += c[i - 1];
    q[he = ta = 1] = 0;
    for(ri i = 1; i <= n; i ++)
    {
        while(he < ta && dp[q[he + 1]] - dp[q[he]] <= (t[i] + s) * (c[q[he + 1]] - c[q[he]]))
            he ++;
        dp[i] = dp[q[he]] + s * (c[n] - c[q[he]]) + t[i] * (c[i] - c[q[he]]);
        while(he < ta && (dp[i] - dp[q[ta]]) * (c[q[ta]] - c[q[ta - 1]]) <= (dp[q[ta]] - dp[q[ta - 1]]) * (c[i] - c[q[ta]]))
            ta --;
        q[++ ta] = i;
    }
    printf("%d", dp[n]);
    system("pause");
}

但是刚才说过了,这个题目还有进击版,bzoj2726任务安排
这个题目首先……你数组按30w吧,时间可正可负,并且所有的变量longlong都存的下。
那么现在\(t[i]\)不单调了,但是\(c[i]\)还是单调的
还记得刚才我说的\(nlog_{n}\)吗,现在确实截的直线的斜率并不单调,但我们依然可以用单调队列维护下凸包,然后二分去查找最早满足那个不等式的地方就行了。

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdlib>
#include<cstdio>
#include<vector>
#include<queue>
#include<stack>
#include<cmath>
#include<map>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 300030;

lo n, s, t[ms], c[ms], dp[ms], q[ms], he, ta;

inline lo find(lo x)
{
    lo l = he, r = ta;
    if(l == r)
        return q[he];
    while(l < r)
    {
        lo mid = (l + r) >> 1;
        if(dp[q[mid + 1]] - dp[q[mid]] > x * (c[q[mid + 1]] - c[q[mid]]))
            r = mid;
        else
            l = mid + 1;
    }
    return q[l];
}

int main()
{
    re(n), re(s);
    for(ri i = 1; i <= n; i ++)
        re(t[i]), re(c[i]), t[i] += t[i - 1], c[i] += c[i - 1];
    he = ta = 1;
    for(ri i = 1; i <= n; i ++)
    {
        lo pos = find(s + t[i]);
        dp[i] = dp[pos] + s * (c[n] - c[pos]) + t[i] * (c[i] - c[pos]);
        while(he < ta && (dp[i] - dp[q[ta]]) * (c[q[ta]] - c[q[ta - 1]]) <= (dp[q[ta]] - dp[q[ta - 1]]) * (c[i] - c[q[ta]]))
            ta --;
        q[++ ta] = i;
    }
    printf("%lld", dp[n]); system("pause");
}

还是dalao的博客2018.09.05 bzoj2726: [SDOI2012]任务安排(斜率优化dp+二分)

【codevs2382】挂缀

这个题……服气了
我们假设之前已经选择了\(w\)重量的珠子,现在我们要考虑a挂b上比较优还是b挂a上比较优
b挂a上的时候,\(c_{a} \geq w_{b} + w, c_{b} \geq w\)
a挂b上的时候,\(c_{b} \geq w_{a} + w, c_{a} \geq w\)
所以如果a挂b上比较优的话\(c_{b} – (w_{a} + w) \geq c_{a} – (w_{b} + w)\),既\(c_{b} + w_{b} \geq c_{a} + w_{a}\)
反之\(c_{b} – (w_{a} + w) \leq c_{a} – (w_{b} + w)\),既\(c_{a} + w_{a} \geq c_{b} + w_{b}\)
我们要留尽可能多的承重力挂别的珠缀,这样我们降序排序即可,但是并不好处理。所以我们正着排,如果\(c_{a} \geq w\),就直接往上挂,否则从里面调出一个最重的踢出去就行了

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<queue>
#include<map>
#define ri register int

using namespace std;

typedef long long lo;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

lo n, w, lx, ly, ans;

struct in
{
    lo c, w, s;
}ter[200020];

inline bool cmp(in a, in b)
{
    return a.s < b.s;
}

priority_queue <lo> qwq;

int main()
{
    re(n);
    for(ri i = 1; i <= n; i ++)
        re(lx), re(ly), ter[i] = (in){lx, ly, lx + ly};
    sort(ter + 1, ter + 1 + n, cmp);
    for(ri i = 1; i <= n; i ++)
    {
        if(w <= ter[i].c)
            w += ter[i].w, ans ++, qwq.push(ter[i].w);
        else
        {
            lo qaq = qwq.top();
            if(qaq > ter[i].w)
                w += ter[i].w - qaq, qwq.pop(), qwq.push(ter[i].w); 
        }
    }
    printf("%lld\n%lld", (lo)qwq.size(), w);
}

【算法】博弈论

已经看了好几次博弈论了,前几次没记最后都忘了,所以这次试着记一下,以后忘了还可以顺着记忆找回来
日常安利 [学习笔记] (博弈论)Nim游戏和SG函数
博弈论一般都是从nim游戏开始的,这是博弈论中最为经典的问题之一。一般都是给你n堆石子,两个人轮流去拿一堆里面的石子,最少1个最多全拿走,然后问你先手必胜还是必败
这个东西是有结论的,我们可以把每一堆作为一个子游戏去考虑的话,分别求出他们的sg函数,之后求一个异或和,如果异或和为0必败,异或和不为0必胜
为什么?我们这么考虑,假设有n堆石子,现在轮到后手操作,之前先手把异或和变为了非0,既异或和\(x \neq 0\),那么一定存在一堆石子\(a_{i}\)的,他的最高位和\(x\)最高位都为1,那么我们可以让这堆石子变小,变成\(a_{i} xor x\)(\(a_{i} xor x\)一定小于\(a_{i}\)的),那么我们既消掉了最高位,又消掉了后面那些1,所以异或和又变为了0。因此无论怎么改变对于先手异或和都为0,这个过程肯定不是无限的,所以最后到先手手上的时候一定必败
那么sg函数是什么
sg函数就是说给你一些数,其中最小的没有出现的非负整数即为sg函数
那么对于一个空集,\(sg(x) = 0\)
因此对于\(sg(x) \neq 0\)的情况来说,在他的集合里面一定存在一个数为0;反之亦然
这和博弈论有什么关系
引入接下来两个定义
P-position:在当前的局面下,先手必败
N-position:在当前的局面下,先手必胜
那么我们将nim游戏看成一张图,对于每次转移都建边,sg函数的集合就是他能到达的点的sg值。对于一个必胜态,一定sg值不为0,否则为0
那么结合上面的来看我们可以通过在这个图上移动实现必胜态必败态之间的转换
这就是sg值和博弈论的关系
那么sg值怎么求?
对于从1-m的转移,\(sg(x)=x%(m+1\))
对于全部都能取的情况,\(sg(x)=x\)
对于不连续的转移,我们需要背模板
先写一个第二种情况的(其实与第一种情况是一样的)
luogu P2197 nim游戏

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<queue>
#include<map>
#define ri register int

using namespace std;

typedef long long lo;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

int t, n, lx;

int main()
{
    re(t);
    while(t --)
    {
        re(n), lx = 0;
        for(ri i = 1, j; i <= n; i ++)
            re(j), lx ^= j;
        if(lx > 0)
            printf("Yes\n");
        else
            printf("No\n");
    }
}

除了nim游戏还有一些模型,比如二分图博弈
二分图博弈的思想就是先将博弈转化成二分图最大匹配,这个时候如果二分图完美匹配的话,先手必胜,否则先手就可以通过选择匹配点获得胜利
难就难在所有可能匹配点的都是答案,我习惯写dinic来判断二分图最大匹配,这个时候怎么做?
从源点出发找出所有源点能到达的,并且本来与源点也相连的点,汇点也这么干一次。
分情况讨论,当源点直接和它有边相连,那么肯定是不一定在最大匹配的,还有一个情况,就是通过来回交错到达的,这里我们考虑从s集合走到t集合,这里一定是个非匹配边,从t走到s,这里一定是匹配边的反向边。那么也就说可以非匹配->匹配->……这样走。这里从s集合内一个点走到另一个点一定是偶数条边,也就说他们之间可以互换了……所以不一定在。
luogu P4055 [JSOI2009] 游戏

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<queue>
#include<map>
#define ri register int

using namespace std;

typedef long long lo;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

int n, m, t, tot = -1, cnt, head[10010], he[10010], pos[110][110], px[10010], py[10010];

int dx[4] = {-1, 0, 1, 0}, dy[4] = {0, 1, 0, -1}, de[10010], q[10010];

int ax[10010], ay[10010];

char lx;

bool mmp[110][110], flag[10010], col[10010];

struct in
{
    int to, ne, co;
}ter[100010];

inline void build(int f, int l, int c)
{
    ter[++ tot] = (in){l, head[f], c}, head[f] = tot;
    ter[++ tot] = (in){f, head[l], 0}, head[l] = tot;
}

inline bool bfs()
{
    int hh = 0, tt = 0;
    q[0] = 0, memset(de, 0, sizeof(de)), de[0] = 1;
    while(hh <= tt)
    {
        int qaq = q[hh ++];
        for(ri i = head[qaq]; i >= 0; i = ter[i].ne)
        {
            int to = ter[i].to;
            if(ter[i].co > 0 && de[to] == 0)
                q[++ tt] = to, de[to] = de[qaq] + 1;
        }
    }
    return de[t] > 0;
}

int dfs(int no, int fl)
{
    if(no == t)
        return fl;
    for(ri &i = he[no]; i >= 0; i = ter[i].ne)
    {
        int to = ter[i].to;
        if(ter[i].co > 0 && de[to] == de[no] + 1)
        {
            int rt = dfs(to, min(fl, ter[i].co));
            if(rt > 0)
            {
                ter[i].co -= rt, ter[i ^ 1].co += rt; return rt;
            }
        }
    }
    return 0;
}

void dfss(int no, int f)
{
    flag[no] = 1;
    if(col[no] == f)
        ax[++ cnt] = no;
    for(ri i = head[no]; i >= 0; i = ter[i].ne)
    {
        int to = ter[i].to;
        if(ter[i].co == f && flag[to] == 0)
            dfss(to, f);
    }
}

int main()
{
    re(n), re(m), memset(head, -1, sizeof(head));
    for(ri i = 1; i <= n; i ++)
    {
        while((lx = getchar()) != '#' && lx != '.');
        for(ri j = 1; j <= m; j ++)
        {
            mmp[i][j] = (lx == '.');
            if(mmp[i][j] > 0)
                pos[i][j] = ++ cnt, px[cnt] = i, py[cnt] = j;
            lx = getchar();
        }
    }
    t = ++ cnt;
    for(ri i = 1; i <= n; i ++)
        for(ri j = 1; j <= m; j ++)
            if(mmp[i][j] == 1)
            {
                if((i + j) & 1)
                {
                    build(0, pos[i][j], 1), col[pos[i][j]] = 1;
                    for(ri k = 0; k < 4; k ++)
                    {
                        int tx = i + dx[k], ty = j + dy[k];
                        if(tx < 1 || tx > n || ty < 1 || ty > m || mmp[tx][ty] == 0)
                            continue;
                        build(pos[i][j], pos[tx][ty], 1);
                    }
                }
                else
                    build(pos[i][j], t, 1);
            }
    int nu = 0, x = 0;
    while(bfs())
    {
        memcpy(he, head, sizeof(he));
        while(x = dfs(0, 1000000007))
            nu += x;
    }
    if((nu << 1) == cnt - 1)
    {
        printf("LOSE"); return 0;
    }
    printf("WIN\n");
    cnt = 0, col[t] = 1;
    dfss(0, 1), memset(flag, 0, sizeof(flag)), dfss(t, 0);
    sort(ax + 1, ax + 1 + cnt);
    for(ri i = 1; i <= cnt; i ++)
        if(ax[i] != 0 && ax[i] != t)
            printf("%d %d\n", px[ax[i]], py[ax[i]]);
}

博弈论还可以打表,luogu P2148 [SDOI2009]E&D
这个题目首先你可以暴力去算,对于a+b来说,他一定能够化成c+d的形式,我们按照按照a来算,然后每次枚举被拆成了哪两个数,异或起来,然后这样得到了sg值集合。我们发现,a+b异或起来之后的答案,就是从最低位开始数第一个0出现在第几位上?所以我们可以logn去算?完了?
整个题目就是通过打表找规律实现的?蛇皮操作啊

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<queue>
#include<map>
#define ri register int

using namespace std;

typedef long long lo;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

int t, n, lx, ly, ans;

inline int mex(int x)
{
    int rt = 0;
    while(x & 1)
        x >>= 1, rt ++;
    return rt;
}

int main()
{
    re(t);
    while(t --)
    {
        re(n), ans = 0;
        for(ri i = 1; i <= n; i += 2)
            re(lx), re(ly), ans ^= mex((lx - 1) | (ly - 1));
        printf("%s\n", ans ? "YES" : "NO");
    }
}

博弈论还可以脑洞清奇一点,不使用sg函数。luogu P4101 [HEOI2014] 人人尽说江南好
这个题目脑洞实在是太过于清奇了,清奇的妙不可言
首先可以这么考虑,合并次数为奇数先手必胜,偶数后手必胜,那么两个人都会尽可能向着自己想要的方向去发展,结果就是……这两个人把比赛拉到最长(感性理解下)
之后我们要考虑,如果最长的合并次数是偶数次,那么一定后手能必胜
分开讨论下,当\(n \leq m\)时,这种情况最后肯定能合成一堆qwq,我们管这个比较大的堆叫【大堆】,假如说现在轮到先手操作,先手还没动,这个时候最长合并次数为偶数次,那么先手有两种可能性,把后面一个堆丢进大堆里面,这样后手再丢一个小堆进去,或者把后面两个堆合成一个,那么后手就可以把这个合成的直接丢进去。无论怎么做,后手都能保证每轮完了之后,大堆的石子会增加两个,那么合并次数也会-2,一直保持为偶数。直到最后先手合无可合。
那么\(n > m\)呢?我们假设\(n = m + 1\),并且m是奇数,最后合成次数最多的情况一定是\((m, 1)\)对于前m个,合成次数为偶数,我们就像刚才那么操作,最后一个1直接不用管。所以后手必赢。同样的我们可以把情况扩展到\(n = km + b\)这种情况。前k个\(m\)一样合并,后面也一样。
所以我们只需要算出合并次数就可以了
[BZOJ3609]-[Heoi2014]人人尽说江南好-神分析+博弈 dalao写的好qwq

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<queue>
#include<map>
#define ri register int

using namespace std;

typedef long long lo;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

int t, n, m;

int main()
{
    re(t);
    while(t --)
    {
        re(n), re(m);
        lo x = (n / m) * (m - 1) + (n % m ? n % m - 1 : 0);
        printf("%d\n", (x & 1) ? 0 : 1);
    }
}

【算法】欧拉路&欧拉回路

欧拉路和欧拉回路是数学家欧拉在研究七桥问题的时候提出的,其具体历史这里就不多说了。
简单区分下欧拉回路和欧拉路
欧拉路就是说所有的边走且只走一次,欧拉回路就是在欧拉路的基础上更进一层,不仅所有的边都走一次,最后还回到了起点
其实就是我们熟知的一笔画呗
这个算法其实比较结论。对于有向图和无向图我们只需要找准条件随便写一下就行了,混合图……反正我是没看懂23333333
有向图存在欧拉路:所有顶点出度等于入度(且存在欧拉回路);或者是出两点外,所有点出度等于入度,这两个点中,出度大的为起点,入度大的为终点。
无向图存在欧拉路:所有点度数均为偶数(且存在欧拉回路);或有两个奇数点,他们一定是终点和起点
代码实现的话直接dfs就好啦。反正是一笔画,所以只要能走就一直走就行,记住dfs的时候要修改边权
luogu P1341 无序字母对

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<queue>
#include<map>
#define ri register int

using namespace std;

typedef long long lo;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

int n, dis[1010][1010], d[1010];

char s[1010], ans[1010], ss[1010];

void dfs(int no)
{
    for(ri i = 0; i < 52; i ++)
        if(dis[no][i])
            dis[no][i] = dis[i][no] = 0, dfs(i);
    ans[n --] = ss[no];
}

int main()
{
    re(n);
    for(ri i = 0; i < 26; i ++)
        ss[i] = 'A' + i;
    for(ri i = 26; i < 52; i ++)
        ss[i] = 'a' + i - 26;
    for(ri i = 1; i <= n; i ++)
    {
        scanf("%s", s + 1);
        int lx, ly;
        if(s[1] >= 'A' && s[1] <= 'Z')
            lx = s[1] - 'A';
        else
            lx = s[1] - 'a' + 26;
        if(s[2] >= 'A' && s[2] <= 'Z')
            ly = s[2] - 'A';
        else
            ly = s[2] - 'a' + 26;
        dis[lx][ly] = dis[ly][lx] = 1, d[lx] ++, d[ly] ++;
    }
    int num = 0, sta = -1;
    for(ri i = 0; i < 52; i ++)
        if(d[i] & 1)
        {
            num ++;
            sta = (sta == -1) ? i : sta;
        }   
    bool f = 0;
    if((num != 0 && num != 2) || f != 0)
    {
        printf("No Solution"); return 0;
    }
    if(sta == -1)
    {
        for(ri i = 0; i < 52; i ++)
            if(d[i] != 0)
            {
                sta = i; break;
            }
    }
    dfs(sta);
    if(n > 0)
    {
        printf("No Solution"); return 0;
    }
    printf("%s", ans);
}

【luoguP4370】[Code+#4]组合数问题2

暑假数学班的题目能让我咕咕咕到现在也是可以的,现在列表里面一堆模板还没做,noip前应该也不会做了2333333
这个题目是当时长者给我们讲的,脑洞真是清奇2333333333
首先分析这个题目的话,难点就在于,我们取模以后没法判断数字的大小了。这个时候有一个比较显然的思路,用double类型去存储组合数。但是组合数太大了,double类型只能存20位左右的精准数字,剩下的全部以浮点数的形式存储,那么怎么比较大小呢?
这也就是这个题目的关键,我们利用对数函数来判断数字的大小。我们统一取log2函数值的大小,这样就可以保证精度的同时判断大小了2333333
代码实现并不难,就是有点小恶心,开始wa了发,后来看到网上dalao的偷懒写法写了一发

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<queue>
#include<map>
#define ri register int

using namespace std;

typedef long long lo;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const lo ms = 1e6 + 10;

const lo mo = 1e9 + 7;

lo n, k, ans, fac[ms], inv[ms], fx[4] = {-1, 0, 1, 0}, fy[4] = {0, 1, 0, -1};

double fc[ms];

struct in
{
    lo x, y, c; double lg;
    inline in(lo a, lo b)
    {
        x = a, y = b, lg = fc[a] - fc[b] - fc[a - b];
        c = fac[a] * inv[b] % mo * inv[a - b] % mo;
    }
    inline bool operator < (const in &a) const
    {
        return lg < a.lg;
    }
};

priority_queue <in> qwq;

inline lo ksm(lo x, lo kk)
{
    lo a = x, rt = 1;
    while(kk)
        rt = rt * ((kk & 1) ? a : 1) % mo, a = a * a % mo, kk >>= 1;
    return rt;
}

int main()
{
    re(n), re(k), fac[0] = inv[0] = 1;
    for(ri i = 1; i <= n; i ++)
        fac[i] = (fac[i - 1] * i) % mo, fc[i] = fc[i - 1] + log2(i);
    inv[n] = ksm(fac[n], mo - 2);
    for(ri i = n - 1; i > 0; i --)
        inv[i] = inv[i + 1] * (i + 1) % mo;
    for(ri i = 0; i <= n; i ++)
        qwq.push(in(n, i));
    while(k --)
    {
        in qaq = qwq.top(); qwq.pop();
        ans = (ans + qaq.c) % mo, qaq.x --;
        qwq.push(in(qaq.x, qaq.y));
    }
    printf("%lld", ans);
}

【luoguP4364】 [九省联考2018]IIIDX

这个题目当初我考场就是个智障,连55的贪心都没做对qwq
首先一个比较显然的思路,我们先建树,\(\lfloor{\frac{i}{k}}\rfloor\)即为\(i\)的父亲,然后我们从大到小排个序,按照中序遍历给他赋值
但是这样的贪心在\(d_{i}\)重复的时候就gg了,因为我们发现可以用子树内部较大的值去换掉其他子树上比较小的值,因此我们要换个思路贪心
我们还是排序,每次去寻找一个可以容纳的下子树大小的区间,比如9 9 9 8 7 7 6 6 6 5,这个时候我们要插入一棵大小为7的子树。这个时候我们很显然会找到第7个数——6。然后我们再一直移动到第九个数——还是6。这是为了保证和我同一层的子树根尽可能的大,所以在我这棵子树的根并不会变小的时候,我多留点空为后面的子树。
这个时候我们为了防止别的树占用我们的预定好了的区间,所以我们需要设置一个辅助数组。设mi[i]表示权值大于等于i的数还有多少是可以选择的,那么排序过后mi[i]就可以转化为i左边还剩下多少数字可以使用。这个东西显然是满足单调不降的。那么我们就可以线段树去维护,每次预定的时候就区间减法
然而大多数人都选择了将减法变为加法,将被选取的区间和整个一段取了个补集,让他们进行区间加。这也就是为什么网上这么多题解嘴上都在说减法实际上一个减法都没有的原因。
每次变更父亲的时候提前钦定一波,每个子树选完了就再减回去。

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = (x << 1) + (x << 3) + a - '0', a = getchar();
    if(b == 1)
        x *= -1;
}

const int ms = 5e5 + 10;

int n, a[ms], fa[ms], sz[ms], cnt[ms], pos[ms];

struct in
{
    int l, r, mi, add;
}ter[ms << 2];

double k;

inline bool cmp(int a, int b)
{
    return a > b;
}

inline void up(int w)
{
    ter[w].mi = min(ter[w << 1].mi, ter[w << 1 | 1].mi);
}

inline void build(int w, int l, int r)
{
    ter[w] = (in){l, r, l};
    if(l == r)
        return;
    int mid = (l + r) >> 1;
    build(w << 1, l, mid), build(w << 1 | 1, mid + 1, r);
}

inline void down(int w)
{
    int &add = ter[w].add;
    ter[w << 1].add += add, ter[w << 1 | 1].add += add;
    ter[w << 1].mi += add, ter[w << 1 | 1].mi += add;
    add = 0;
}

void change(int w, int l, int r, int v)
{
    if(ter[w].l == l && ter[w].r == r)
    {
        ter[w].mi += v, ter[w].add += v; return;
    }
    int mid = (ter[w].l + ter[w].r) >> 1; down(w);
    if(r <= mid)
        change(w << 1, l, r, v);
    else if(l > mid)
        change(w << 1 | 1, l, r, v);
    else
        change(w << 1, l, mid, v), change(w << 1 | 1, mid + 1, r, v);
    up(w);
}

int ask(int w, int l)
{
    if(ter[w].l == ter[w].r)
        return ter[w].mi >= l ? ter[w].l : ter[w].l + 1;
    int mid = (ter[w].l + ter[w].r) >> 1; down(w);
    if(ter[w << 1 | 1].mi >= l)
        return ask(w << 1, l);
    return ask(w << 1 | 1, l);
}

int main()
{
    re(n), scanf("%lf", &k);
    for(ri i = 1; i <= n; fa[i] = (int)i / k, sz[i] = 1, re(a[i ++]));
    sort(a + 1, a + 1 + n, cmp);
    for(ri i = n; i >= 1; i --)
        sz[fa[i]] += sz[i];
    for(ri i = n - 1; i >= 0; i --)
        cnt[i] = (a[i] == a[i + 1]) ? cnt[i + 1] + 1 : 0;
    build(1, 1, n);
    for(ri i = 1; i <= n; i ++)
    {
        if(fa[i] && fa[i - 1] != fa[i])
            change(1, pos[fa[i]], n, sz[fa[i]] - 1);
        int x = ask(1, sz[i]);  x += cnt[x], cnt[x] ++, x -= (cnt[x] - 1);
        change(1, x, n, - sz[i]), pos[i] = x;
    }
    for(ri i = 1; i <= n; printf("%d ", a[pos[i ++]]));
}

【luoguP3746】组合数问题

这个题目……开始我yy了一个dp方程为\(dp[i] = dp[i – 1] + C_{nk}^{ik + r}\),然后矩阵一波发现emmmmmm里面有个组合数是会随着i的增长,m不断改变的。然后我就弃疗了。看完题解以后顿时觉得自己是个zz啊
组合数的递推公式\(C_{n}^{m}=C_{n-1}^{m-1}+C_{n-1}^{m}\),那么我这个式子加个\(\\sum\)也不是不行啊
于是式子变成了这样\(\sum C_{n}^{m} = \sum C_{n – 1}^{m – 1} + \sum C_{n – 1}^{m}\)
然后我就可以设置dp[i][j] = dp[i-1][j-1] + dp[i-1][j],dp[i][j]表示n为i的时候且m为a*k+j的时候的答案
这个东西显然可以矩阵快速幂优化
但是其实这个方程这么写是错的
j0的时候j-1不就是-1了吗
那么真正的方程式为dp[i][j] = dp[i – 1][j] + dp[i – 1][(j – 1 + k) % k]

#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<queue>
#include<map>
#define ri register int

using namespace std;

typedef long long lo;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

inline void re(lo &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

lo n, p, k, r;

struct in
{
    lo a[55][55];
}a, ans;

in operator * (in x, in y)
{
    in c; memset(c.a, 0, sizeof(c.a));
    for(ri kk = 0; kk < k; kk ++)
        for(ri i = 0; i < k; i ++)
            for(ri j = 0; j < k; j ++)
                c.a[i][j] = (c.a[i][j] + x.a[i][kk] * y.a[kk][j] % p) % p;
    return c;
}

int main()
{
    re(n), re(p), re(k), re(r), n *= k;
    for(ri i = 0; i < k; i ++)
        a.a[i][i] ++, a.a[(i - 1 + k) % k][i] ++;
    for(ri i = 0; i < k; i ++)
        ans.a[i][i] = 1;
    while(n)
    {
        if(n & 1)
            ans = a * ans;
        a = a * a, n >>= 1;
    }
    printf("%lld", ans.a[0][r]);
}

【洛谷P4882】lty loves 96!

开始的时候看到这个题目有点懵逼,因为没仔细看条件,以为有几个条件是保证对于每一个abc
我们考虑,对于一个数字,只有最后两个加进去的才是有意义的,之前怎么摆最后只需要保留已经满不满足第三个大条件即可。所以我们可以设一个裸的状态dp[i][j][x][y][0/1]表示当前已经放了i位数,有j个9/6,第i位为y,i-1位为x,是否满足第三个条件的方案数。那么转移方程也比较显然了,这里就不再写转移方程,具体分析下情况。我们每次转移的时候枚举最后三位数。如果后三位数能满足条件,则能从非法转移到合法,否则只能从非法转移到非法。原本合法的一定转移过来。
注意,这个题目是认为倒数第三位为模数c的qwq

// luogu-judger-enable-o2
#include<algorithm>
#include<iostream>
#include<cstring>
#include<climits>
#include<cstdio>
#include<vector>
#include<queue>
#include<cmath>
#include<queue>
#include<map>
#define ri register int

using namespace std;

typedef long long lo;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

inline void re(int &x)
{
    x = 0;
    char a; bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b == 1)
        x = - x;
}

const int mo = 1e8;

int n, m, t;

struct in
{
    int a[55];
    void init(int x)
    {
        memset(a, 0, sizeof(a)), a[a[0] = 1] = x;
    }
}dp[2][51][10][10][2];

in operator + (in a, in b)
{
    int d = max(a.a[0], b.a[0]);
    for(ri i = 1; i <= d; i ++)
        a.a[i] += b.a[i], a.a[i + 1] += a.a[i] / mo, a.a[i] %= mo;
    while(a.a[d + 1] > 0)
        d ++, a.a[d + 1] += a.a[d] / mo, a.a[d] %= mo;
    a.a[0] = d; return a;
}

inline void print(in a)
{
    int mx = a.a[0]; printf("%d", a.a[mx]);
    for(ri i = mx - 1; i >= 1; i --)
        printf("%08d", a.a[i]);
    printf("\n");
}

inline bool check(int x)
{
    return x == 9 || x == 6;
}

int main()
{
    //freopen("data.txt", "r", stdin);
    //freopen("qwq.out", "w", stdout);
    re(n), re(m);
    if(n == 1)
    {
        printf(m == 0 ? "9" : "2"); return 0;
    }
    else if(n == 2)
    {
        printf(m == 0 ? "90" : (m == 1 ? "34" : "4")); return 0;
    }
    for(ri i = 0; i < 10; i ++)
        for(ri j = 0; j < 10; j ++)
            dp[t][check(i) + check(j)][i][j][0].init(1);
    for(ri i = 3; i <= n; i ++)
    {
        t ^= 1;
        for(ri j = 0; j < i; j ++)
            for(ri x = 0; x < 10; x ++)
                for(ri y = 0; y < 10; y ++)
                    dp[t][j][x][y][0].init(0), dp[t][j][x][y][1].init(0); 
        for(ri j = 0; j < i; j ++)
            for(ri x = 0; x < 10; x ++)
                for(ri y = 0; y < 10; y ++)
                    for(ri z = 0; z < 10; z ++)
                    {
                        dp[t][check(z) + j][z][y][1] = dp[t][check(z) + j][z][y][1] + dp[t ^ 1][j][y][x][1];
                        if(check(x + y + z)
                        || (x && check((y * y + z * z) % x)))
                            dp[t][check(z) + j][z][y][1] = dp[t][check(z) + j][z][y][1] + dp[t ^ 1][j][y][x][0];
                        else
                            dp[t][check(z) + j][z][y][0] = dp[t][check(z) + j][z][y][0] + dp[t ^ 1][j][y][x][0];
                    }
    }
    in ans; ans.init(0);
    for(ri i = 1; i < 10; i ++)
        for(ri j = 0; j < 10; j ++)
            for(ri k = m; k <= n; k ++)
                ans = ans + dp[t][k][i][j][1];
    print(ans);
    fclose(stdin);
    fclose(stdout);
    return 0;
}