题目:数一
思路:
代码如下:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;
public class Main {
//快速输入
static BufferedReader br=new BufferedReader(new InputStreamReader(System.in));
//快速输出
static PrintWriter pw=new PrintWriter(new OutputStreamWriter(System.out));
static int mod=1_000_000_009;
public static void main(String[] args) throws IOException {
String[] s=br.readLine().split(" ");
int k=Integer.parseInt(s[0]);
int b=Integer.parseInt(s[1]);
//记录遍历到i位时余数为j二进制中1的总数
long[][] dp=new long[b][k];
//记录遍历到i位时余数为j的个数
long[][] tp=new long[b][k];
int[] pow=new int[b+1];
pow[0]=1;
for(int i=1;i<=b;i++) {
//提前计算2^i%k
pow[i]=pow[i-1]%k*2%k;
}
long ans=k==1?1:0;
tp[0][0]=k==1?2:1;
dp[0][0]=k==1?1:0;
if(k>1) {
dp[0][1]=1;
tp[0][1]=1;
}
for(int i=1;i<b;i++) {
for(int j=0;j<k;j++) {
int t=0;
if(j>pow[i]) {
t=j-pow[i];
}
else {
t=(j-pow[i]+k)%k;
}
//当前位填入1,计算余数为j的二进制中1的总数
dp[i][j]=(dp[i][j]+tp[i-1][t]+dp[i-1][t])%mod;
//当前位填入1,计算余数为j的个数
tp[i][j]=(tp[i][j]+tp[i-1][t])%mod;
if(j==0) {
ans=(ans+dp[i][j])%mod;
}
//当前位填入0,余数为j
dp[i][j]=(dp[i][j]+dp[i-1][j])%mod;
tp[i][j]=(tp[i][j]+tp[i-1][j])%mod;
}
}
pw.println(ans);
pw.flush();
}
}