跳转至

矩阵优化

因为矩阵满足交换律和结合律(自行搜索),所以矩阵快速幂和快速幂一样,只不过快速幂中的数乘(res*=a,a*=a变成了两个矩阵相乘)。而因为矩阵可以实现并行计算,并且将下一个阶段的结果存储在结果矩阵中,我们通常使用矩阵来加速递推式的推导。

例题 #1 Decoding Genome

题目描述

最近,对火星的一次秘密研究拉开帷幕,在这次探寻中,科学家们发现了火星上的DNA。

他们发现,火星DNA包含 \(n\) 个核苷酸,且由 \(m\) 种不同核苷酸构成(编号为 \(1\)~\(m\))。

他们还发现出于某些特殊的原因,存在 \(k\) 对核苷酸,一对这样的核苷酸不会连续出现在火星DNA中。

例如存在一对这样的核苷酸 (1,2) ,则核苷酸2不能紧连着核苷酸1(即DNA中不会连续出现 1,2);但是可以连续出现 2,1

科学家们想知道有多少种不同的火星DNA(如果两个DNA存在一个位置,使得两个DNA的该位置的核苷酸不同,则认为这两个DNA不同)。

输入

第一行包含 \(3\) 个整数 \(n,m,k\) ,含义见题目描述;

接下来 \(k\) 行,每行包含一个长度为 \(2\) 的字符串,字符串由小写字母和大写字母构成——字母'a'-'z'表示核苷酸 \(1\)\(26\),'A'-'Z'表示核苷酸 \(27\)\(52\);表示题目描述中的 \(k\) 对核苷酸。

输出

一个整数,表示有多少种不同的DNA,答案对 \((10^9+7)\) 取模。

数据规模

\(1\leq n\leq 10^{15},1\leq m\leq 52,0\leq k\leq m^2\)


首先我们需要写出基本的dp转移

\(f_{i,j}=\sum f_{i-1,k}[used_{j,k}==1]\)

\(used_{j,k}\)表示(j,k)顺序相邻对是否冲突。

for(itn i=2;i<=n;i++){
     for(itn j=1;j<=m;j++){
         for(int k=1;k<=m;k++){
             f[i][j]+=f[i-1][k]*(used[k][j]==0);
             f[i][j]%=MOD;
         }
     }
 }

然后我们发现它可以写成递推式的形式,于是我们可以使用矩阵来加速它。

/*                                                                                
                      Keyblinds Guide
                    ###################
      @Ntsc 2024

      - Ctrl+Alt+G then P : Enter luogu problem details
      - Ctrl+Alt+B : Run all cases in CPH
      - ctrl+D : choose this and dump to the next
      - ctrl+Shift+L : choose all like this
      - ctrl+K then ctrl+W: close all
      - Alt+la/ra : move mouse to pre/nxt pos'

*/
#include <bits/stdc++.h>
#include <queue>
using namespace std;

#define rep(i, l, r) for (int i = l, END##i = r; i <= END##i; ++i)
#define per(i, r, l) for (int i = r, END##i = l; i >= END##i; --i)
#define pb push_back
#define mp make_pair
#define int long long
#define pii pair<int, int>
#define ps second
#define pf first

// #define innt int
#define itn int
// #define inr intw
// #define mian main
// #define iont int

#define rd read()
int read(){
    int xx = 0, ff = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            ff = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
      xx = xx * 10 + (ch - '0'), ch = getchar();
    return xx * ff;
}
void write(int out) {
    if (out < 0)
        putchar('-'), out = -out;
    if (out > 9)
        write(out / 10);
    putchar(out % 10 + '0');
}

#define ell dbg('\n')
const char el='\n';
const bool enable_dbg = 1;
template <typename T,typename... Args>
void dbg(T s,Args... args) {
    if constexpr (enable_dbg){
    cerr << s;
    if(1)cerr<<' ';
        if constexpr (sizeof...(Args))
            dbg(args...);
    }
}

#define zerol = 1
#ifdef zerol
#define cdbg(x...) do { cerr << #x << " -> "; err(x); } while (0)
void err() { cerr << endl; }
template<template<typename...> class T, typename t, typename... A>
void err(T<t> a, A... x) { for (auto v: a) cerr << v << ' '; err(x...); }
template<typename T, typename... A>
void err(T a, A... x) { cerr << a << ' '; err(x...); }
#else
#define dbg(...)
#endif


const int N = 3e5 + 5;
const int INF = 1e18;
const int M = 1e7;
const int MOD = 1e9 + 7;


int n,m,K;



struct mxu{
    int a[66][66];


    void base(){
        for(int i=1;i<=m;i++){
            a[i][i]=1;
        }
    }

    void clear(){

        for(int i=0;i<=60;i++){
            for(int j=0;j<=60;j++){
                a[i][j]=0;
            }
        }
    }
}used,f;


mxu add(mxu a,mxu b){
    mxu r;
    r.clear();//!不清会挂
    for(int i=1;i<=60;i++){
        for(itn j=1;j<=60;j++){
            for(int k=1;k<=60;k++){
                (r.a[i][j]+=a.a[i][k]*b.a[k][j])%=MOD;
            }
        }
    }
    return r;
}

// 设 f_i,j 表示确定了前 i个长度的核苷酸,其中第 i个的核苷酸的种类是 j的方案数

inline int C(char c){
    if(c>='a')return c-'a'+1;
    return c-'A'+27;
}

mxu ksm(mxu a,int b){
    mxu r;
    r.base();
    while(b){
        if(b&1)r=add(r,a);
        a=add(a,a);
        b>>=1;    
    }
    return r;
}



void solve(){
    n=rd,m=rd,K=rd;


    for(int i=1;i<=m;i++){
        for(int j=1;j<=m;j++){
            used.a[i][j]=1;
        }
    }
    for(int i=1;i<=K;i++){
        string s;
        cin>>s;
        used.a[C(s[0])][C(s[1])]=0;
    }

    for(int i=1;i<=m;i++){
        f.a[1][i]=1;
    }

    used=ksm(used,n-1);// 变换n-1次
    f=add(f,used);


    int ans=0;
    for(itn i=1;i<=m;i++){
        ans+=f.a[1][i];
        ans%=MOD;
    }


    cout<<ans<<endl;
}

signed main() {
    // freopen(".in","r",stdin);
    // freopen(".in","w",stdout);

    int T=1;
    while(T--){
        solve();
    }
    return 0;
}

例题 #2 Product Oriented Recurrence

题目描述

\(x \geq 4\) 时,\(f_x = c^{2x - 6} \cdot f_{x - 1} \cdot f_{x - 2} \cdot f_{x - 3}\)

现在已知 \(n,f_1,f_2,f_3,c\) 的值,求 \(f_n\) 的值,对 \(10^9 + 7\) 取模。

\((4 \leq n \leq 10^{18},1 \leq f_1,f_2,f_3,c \leq 10^9)\)


这道题其实不是dp,但是一道矩阵加速的好题。

我们发现这道题的转移式中包含了不知一个版本的信息,那么我们就不太好正解进行转移。

我们先模拟一下,试着把\(f_7\)写成\(f_1^if_2^jf_3^kc^a\)的形式。

不知道你有没有发现,在模拟的过程中,我们可以使用矩阵加速这个模拟的过程。也就是说本题的矩阵是用来加速计算最终答案的指数的。(也就是上面的i,j,k,a)。

最后有一个细节,就是i,j,k,a可能会爆LL,我们需要使用欧拉定理将其约束在LL范围内。

/*                                                                                
                      Keyblinds Guide
                    ###################
      @Ntsc 2024

      - Ctrl+Alt+G then P : Enter luogu problem details
      - Ctrl+Alt+B : Run all cases in CPH
      - ctrl+D : choose this and dump to the next
      - ctrl+Shift+L : choose all like this
      - ctrl+K then ctrl+W: close all
      - Alt+la/ra : move mouse to pre/nxt pos'

*/
#include <bits/stdc++.h>
#include <queue>
using namespace std;

#define rep(i, l, r) for (int i = l, END##i = r; i <= END##i; ++i)
#define per(i, r, l) for (int i = r, END##i = l; i >= END##i; --i)
#define pb push_back
#define mp make_pair
#define int long long
#define pii pair<int, int>
#define ps second
#define pf first

// #define innt int
#define itn int
// #define inr intw
// #define mian main
// #define iont int

#define rd read()
int read(){
    int xx = 0, ff = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            ff = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
      xx = xx * 10 + (ch - '0'), ch = getchar();
    return xx * ff;
}
void write(int out) {
    if (out < 0)
        putchar('-'), out = -out;
    if (out > 9)
        write(out / 10);
    putchar(out % 10 + '0');
}

#define ell dbg('\n')
const char el='\n';
const bool enable_dbg = 1;
template <typename T,typename... Args>
void dbg(T a,Args... args) {
    if constexpr (enable_dbg){
    cerr << a;
    if(1)cerr<<' ';
        if constexpr (sizeof...(Args))
            dbg(args...);
    }
}

#define zerol = 1
#ifdef zerol
#define cdbg(x...) do { cerr << #x << " -> "; err(x); } while (0)
void err() { cerr << endl; }
template<template<typename...> class T, typename t, typename... A>
void err(T<t> a, A... x) { for (auto v: a) cerr << v << ' '; err(x...); }
template<typename T, typename... A>
void err(T a, A... x) { cerr << a << ' '; err(x...); }
#else
#define dbg(...)
#endif


const int N = 3e5 + 5;
const int INF = 1e18;
const int M = 1e7;
const int MOD = 1e9 + 7;

int n;
int c;
int f1, f2, f3;


struct matrix {
    int a[6][6];
    matrix() {memset(a, 0, sizeof a);};
    matrix operator * (const matrix &b) const {
        matrix res;
        for(int i = 1;i <= 5;i++) {
            for(int j = 1;j <= 5;j++) {
                for(int k = 1;k <= 5;k++) {
                    res.a[i][j] += a[i][k] * b.a[k][j];
                    res.a[i][j] %= (MOD - 1);
                }
            }
        }
        return res;
    }
}mx, mx2, w, x, y, z;



void init() {
    mx.a[1][1] = mx.a[1][2] = mx.a[2][1] = 1;
    mx.a[2][3] = mx.a[3][1] = mx.a[4][1] = 1;
    mx.a[4][4] = 1, mx.a[5][4] = 1, mx.a[5][5] = 1;
    mx2.a[1][1] = mx2.a[1][2] = mx2.a[2][1] = 1;
    mx2.a[3][1] = mx2.a[2][3] = 1;


    w.a[1][1] = 0, 
    w.a[1][2] = 0, 
    w.a[1][3] = 0, 
    w.a[1][4] = 2, 
    w.a[1][5] = 2;


    x.a[1][3] = 1;
    y.a[1][2] = 1;
    z.a[1][1] = 1;
}


void ksmmx(int b) {
    while(b) {
        if(b & 1) w = w * mx;
        mx = mx * mx;
        b >>= 1;
    }
}



void ksmmx2(int b) {
    while(b) {
        if(b & 1) {
            x = x * mx2;
            y = y * mx2;
            z = z * mx2;
        }
        mx2 = mx2 * mx2;
        b >>= 1;
    }
}

int ksm(int a,int b){
    int r=1;
    while(b){
        if(b&1)r=(r*a)%MOD;
        a=(a*a)%MOD;
        b>>=1;
    }
    return r;
}


void solve(){
    n=rd,f1=rd,f2=rd,f3=rd,c=rd;
    init();
    if(n == 1) cout << f1 << endl;
    else if(n == 2) cout << f2 << endl;
    else if(n == 3) cout << f3 << endl;
    else {
        ksmmx(n - 3);// 转移c的幂
        ksmmx2(n - 3);// 转移f1,f2,f3的幂
        int p1 = w.a[1][1], p2 = x.a[1][1], p3 = y.a[1][1], p4 = z.a[1][1];
        cout << (ksm(c, p1) % MOD * ksm(f1, p2) % MOD * ksm(f2, p3) % MOD * ksm(f3, p4) % MOD) % MOD << endl;
    }
}





signed main() {
    // freopen(".in","r",stdin);
    // freopen(".in","w",stdout);

    int T=1;
    while(T--){
        solve();
    }
    return 0;
}