首先考虑n^2的做法那就是
for(int i = 1; i <= 3; i++){
for(int j = 1; j <= 3 * n; j++){
for(int k = 1; k <= n; k++){
dp[i][j + k] += dp[i - 1][j];
我们从中可以看出 当 j = 1的时候他对所有 j + 1 ~ j + n 那么 j = 2 就对 j + 2 ~ j + n + 1是有贡献的
所以我们只需求一个前缀和 即可
#include<iostream>using namespace std;const int N = 3e6 + 10;
typedef long long ll;ll dp[4][N],dp2[4][N];int main(){ll n,m;cin >> n >> m;dp[0][0] = 1;for(int j = 1; j <= n; j++){dp[1][j] = 1;}for(int j = 1; j <= 3 * n; j++){dp2[1][j] = dp2[1][j - 1] + dp[1][j];}for(int i = 2; i <= 3; i++){for(int j = 1; j <= 3 * n; j++){dp[i][j] += dp2[i - 1][j - 1];if(j >= n + 1){dp[i][j] -= dp2[i - 1][j - n - 1];}}for(int j = 1; j <= 3 * n; j++){dp2[i][j] = dp2[i][j - 1] + dp[i][j];}}ll x;for(int i = 3; i <= 3 * n; i++){if(m <= dp2[3][i]){m -= dp2[3][i - 1];x = i;break;}}for(int i = 1; i <= n; i++){ll minn = max(1ll,x - i - n);ll maxn = min(n,x - i - 1);if(minn > maxn) continue;if(m > maxn - minn + 1){m -= maxn - minn + 1;continue;}ll y = minn + m - 1;ll z = x - i - y;cout << i << " " << y << " " << z << endl;return 0;}return 0;
}