【算法】excrt&exlucas

其实这俩我早就会过,当时还是个傻子不会写数学公式,结果昨晚复习发现自己想不起来怎么做了,然后就滚过来复习下

\begin{cases} x \equiv a_1 (mod \; m_1)\\ x \equiv a_2 (mod \; m_2)\\ x \equiv a_3 (mod \; m_3)\\ ……\\ x \equiv a_n (mod \; m_n)\\ \end{cases}

现在m之间两两不互质,求这个方程组的一组解

假设我们已经求出前i-1组方程的一组解x,设M = lcm(m_1, m_2, m_3, ……, m_{i-1})

那么x+kM是前i-1个方程的通解

那么现在我们就求

x+kM \equiv a_i (mod \; m_i)

然后这个东西就可以通过n次exgcd求出来了。

然后记一个小扩展

假如我们目前知道一个方程

ax \equiv b (mod \; c)

这个东西乍一看不符合excrt的初始条件,但我们可以通过变换形式来构造初始方程

首先我们可以很轻松的用一次exgcd求出这个方程的一组解x_i

然后就有这么一个通解形式

x = x_i + k\frac{c}{(a,b)}

那么我们把这个式子放在mod \; \frac{c}{(a,b)}进行

那么就得到了

x \equiv x_i (mod \; \frac{c}{(a,b)})

这就变成了符合excrt的形式

现在我们要求C_n^m (mod \; P),并且P不是质数

首先质因数分解,P = \prod_{i}p_i^{k_i}

因此我们需要分别计算C_n^m (mod \; p_i^{k_i}),并利用excrt合并即可

现在重点是怎么计算组合数

拆成\frac{n!}{m!(n-m)!}

之后考虑计算阶乘及其逆元

首先对于p的倍数,先把他们拿出来,计算(\lfloor \frac{n}{p} \rfloor)!对于剩下的,不会超过p^k就会有一个循环节

luoguP2183[国家集训队]礼物

首先一眼可以得出式子

C_n^{w_1}C_{n-w_1}^{w_2}……C_{w_n}^{w_n}

这就是个扩展卢卡斯的式子,直接算就是了,代码挺不好写的就是了

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#include <bitset>
#include <queue>
#include <cmath>
#include <map>
#include <set>
#define ri register int

using namespace std;

inline char gch()
{
    static char buf[100010], *h = buf, *t = buf;
    return h == t && (t = (h = buf) + fread(buf, 1, 100000, stdin), h == t) ? EOF : *h ++;
}

typedef long long lo;

template <typename int_qwq>

inline void re(int_qwq &x)
{
    x = 0;
    char a;
    bool b = 0;
    while(!isdigit(a = getchar()))
        b = a == '-';
    while(isdigit(a))
        x = x * 10 + a - '0', a = getchar();
    if(b)
        x = - x;
}

lo P, prime[220], ta, n, m, w[20], ci[220], c[220], pri[220];

inline lo ksm(lo x, lo k, lo mo)
{
    lo rt = 1, a = x;
    for(; k; a = a * a % mo, k >>= 1)
        if(k & 1)
            rt = rt * a % mo;
    return rt;
}

inline lo mul(lo a, lo b, lo mo)
{
    lo rt = ((long double)a / mo * b + 1.0e-8);
    rt = a * b - rt * mo;
    return rt < 0 ? rt + mo : rt;
}

void exgcd(lo a, lo b, lo &x, lo &y)
{
    if(b == 0)
    {
        x = 1, y = 0; return;
    }
    exgcd(b, a % b, x, y);
    lo p = x, q = y;
    x = y, y = p - a / b * q;
}

inline lo gcd(lo a, lo b)
{
    lo C;
    while(b)
        C = a, a = b, b = C % b;
    return a;
}

inline lo inv(lo a, lo mo)
{
    lo lx, ly; exgcd(a, mo, lx, ly); return (lx % mo + mo) % mo;
}

inline lo askc(lo x, lo y, lo mo)
{
    if(y > x)
        return 0;
    lo rt = 1, a, b;
    for(ri i = 1; i <= y; i ++)
        a = (x - i + 1) % mo, b = inv(i % mo, mo), rt = rt * a % mo * b % mo;
    return rt;
}

lo lucas(lo x, lo y, lo mo)
{
    if(y == 0)
        return 1;
    return askc(x % mo, y % mo, mo) * lucas(x / mo, y / mo, mo) % mo;
}

lo fac(lo x, lo a, lo mo)
{
    if(x == 0)
        return 1;
    lo la = 1, rt;
    for(ri i = 1; i <= mo; i ++)
        if(i % a)
            la = la * i % mo;
    rt = ksm(la, x / mo, mo);
    for(ri i = x / mo * mo + 1; i <= x; i ++)
        if(i % a)
            rt = rt * i % mo;
    return rt * fac(x / a, a, mo) % mo;
}

inline lo exlucas(lo x, lo y, lo a, lo mo)
{
    lo t1, t2, t3, s = 0, tmp;
    for(ri i = x; i > 0; i /= a)
        s += i / a;
    for(ri i = y; i > 0; i /= a)
        s -= i / a;
    for(ri i = x - y; i > 0; i /= a)
        s -= i / a;
    tmp = ksm(a, s, mo), t1 = fac(x, a, mo), t2 = fac(y, a, mo), t3 = fac(x - y, a, mo);
    return tmp * t1 % mo * inv(t2, mo) % mo * inv(t3, mo) % mo;
}

inline lo calc(lo x, lo y)
{
    for(ri i = 1; i <= ta; i ++)
        c[i] = (ci[i] == 1) ? lucas(x, y, prime[i]) : exlucas(x, y, prime[i], pri[i]);  
    lo rt = c[1], M = pri[1];
    for(ri i = 2; i <= ta; i ++)
    {
        lo lx, ly, lz = (c[i] - rt) % pri[i]; lz += (rt < 0) ? pri[i] : 0;
        exgcd(M, pri[i], lx, ly), lx = lx * lz % pri[i];
        rt += lx * M, M = M * pri[i] / gcd(M, pri[i]), rt = (rt + M) % M;
    }
    return rt;
}

int main()
{
    re(P), re(n), re(m);
    lo tmps = 0;
    for(ri i = 1; i <= m; i ++)
        re(w[i]), tmps += w[i];
    if(tmps > n)
    {
        puts("Impossible");
        return 0;
    }
    tmps = P;
    for(ri i = 2; i * i <= tmps; i ++)
        if(tmps % i == 0)
        {
            prime[++ ta] = i, pri[ta] = 1;
            while(tmps % i == 0)
                ci[ta] ++, tmps /= i, pri[ta] *= i;
        }
    if(tmps > 1)
        prime[++ ta] = tmps, ci[ta] = 1, pri[ta] = tmps;
    lo las = 1;
    for(ri i = 1; i <= m; i ++)
        tmps = calc(n, w[i]), n -= w[i], las = las * tmps % P;// cout << las << '\n';
    printf("%lld\n", las);
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注