AOJ 2437 : DNA
問題リンク : http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=2437&lang=jp
問題概要 :
日本語なので略
解法:
DP + 半分全列挙で解いた
まず、i文字目(1<=i<=Na+Nt+Ng+Nc)にくる文字が何かを計算する
与えられる文法に含まれる非終端記号間の関係をグラフにするとDAGなので、これは再帰などでその通り計算すると良い
ただし、計算途中で文字数がNa+Nt+Ng+Ncを超えるようなら、答えは0なので枝刈りすること ( そうしないとTLEする )
文字数がNa+Nt+Ng+Ncを超えるようなら、枝刈するので、計算にはそのくらいの時間しかかからない
次に、区間[0,(Na+Nt+Ng+Nc)/2)と[(Na+Nt+Ng+Nc)/2,(Na+Nt+Ng+Nc))のそれぞれについて、
その区間で作れる文字列とそれが何通りあるかを計算する
これは動的計画法で求めると良い dp[i文字目][Aの数][Tの数][Gの数][Cの数] := A, T, G, Cがそれぞれ(Aの数), (Tの数), (Gの数), (Cの数)だけ出現する文字列が何通りあるか
(i文字目)はその最大値分だけ配列を用意するとMLEするので、2にして使い回すこと
O((Na+Nt+Ng+Nc) * Na * Nt * Ng * Nc)
半分に分けたので2.5 * 10^9くらい
AOJなら3 * 10^9は一瞬なのでOK
最後に、2つの区間について半分全列挙で答えを求める
O(Na * Nt * Ng * Nc)
コード (デバッグコード付き debug = 1にするとデバッグモード) :
#include<bits/stdc++.h>

#define REP(i,s,n) for(int i=s;i<n;++i)
#define rep(i,n) REP(i,0,n)

// Generic template helpers. EPS / equals / ALL / LT / LTE are not used by
// the code below; they are kept so the template part stays intact.
#define EPS (1e-8)
#define equals(a,b) (fabs((a)-(b))<EPS)
#define pb push_back
#define ALL(x) x.begin(),x.end()

using namespace std;

typedef long long ll;

bool LT(double a,double b) { return !equals(a,b) && a < b; }
bool LTE(double a,double b) { return equals(a,b) || a < b; }

// Bit positions / indices of the four letters.
#define A_ 0
#define T_ 1
#define G_ 2
#define C_ 3

// Kinds of symbols on the right-hand side of a rule.
#define NONTERMINAL 0
#define TERMINAL 1

// One symbol on the right-hand side of a grammar rule.
struct Data {
  int type; // NONTERMINAL or TERMINAL
  int dst;  // for NONTERMINAL: id of the referenced nonterminal (-1 otherwise)
  int bit;  // for TERMINAL: bitmask of allowed letters (-1 otherwise)
};

#define MAX 60

// Each dp dimension has DP_DIM cells; a transition writes index count+1,
// so the largest count that fits is DP_DIM - 2.
const int DP_DIM = 52;
const int MAX_COUNT = DP_DIM - 2;

vector<int> N;        // required number of A, T, G, C (indexed by A_, T_, G_, C_)
int m;                // number of rules to read
vector<Data> G[MAX];  // G[id] = right-hand side of the rule for nonterminal `id`
const ll mod = 1000000007LL;
bool debug = 0;       // set to 1 to print diagnostic output

// Maps a letter to its index. Unknown letters trip the assert; when asserts
// are compiled out, -1 is returned instead of running off the end.
inline int toIDX(char c) {
  switch( c ) {
    case 'A': return A_;
    case 'T': return T_;
    case 'G': return G_;
    case 'C': return C_;
    default: break;
  }
  assert(false);
  return -1;
}

// Inverse of toIDX. Unknown indices trip the assert; when asserts are
// compiled out, '?' is returned instead of running off the end.
inline char toCHR(int i) {
  switch( i ) {
    case A_: return 'A';
    case T_: return 'T';
    case G_: return 'G';
    case C_: return 'C';
    default: break;
  }
  assert(false);
  return '?';
}

bool fail;          // set once the expansion is known to be too long
vector<int> chars;  // chars[i] = bitmask of letters allowed at position i

// Expands nonterminal `cur` depth-first, appending one bitmask per position
// to `chars`. Returns the length of the expansion of `cur`, or -1 after
// setting `fail` when that length exceeds N[0]+N[1]+N[2]+N[3] (the answer is
// then 0, and stopping early keeps the expansion from blowing up).
int enum_char_dfs(int cur) {
  if( fail ) return -1;
  const int limit = N[0] + N[1] + N[2] + N[3];
  int ret = 0;
  for(const Data &d : G[cur]) {
    if( d.type == NONTERMINAL ) {
      const int sub = enum_char_dfs(d.dst);
      if( fail ) return -1; // do not mix the -1 sentinel into the length
      ret += sub;
    } else {
      chars.pb(d.bit);
      ++ret;
    }
    // Checked after terminals as well, so one over-long rule stops at once.
    if( ret > limit ) {
      fail = true;
      return -1;
    }
  }
  return ret;
}

// dp[layer][a][t][g][c] = number of ways (mod `mod`) to fill the positions
// processed so far using exactly a/t/g/c letters. Two layers are reused
// alternately because one layer per position would need too much memory.
int dp[2][DP_DIM][DP_DIM][DP_DIM][DP_DIM];

// Runs the counting DP over positions [L,R) of `chars`.
// Returns the number of processed positions; the result is in dp[ret & 1].
int solve_DP(int L,int R) { //[L,R)
  memset(dp,0,sizeof dp);
  dp[0][0][0][0][0] = 1;
  int idx = 0;
  REP(i,L,R) {
    const int cur = idx & 1;
    const int nxt = cur ^ 1;
    rep(w,N[0]+1) rep(x,N[1]+1) rep(y,N[2]+1) rep(z,N[3]+1) dp[nxt][w][x][y][z] = 0;

    const int mask = chars[i];
    for(int a = 0; a <= N[0]; ++a) {
      for(int t = 0; t <= N[1]; ++t) {
        for(int g = 0; g <= N[2]; ++g) {
          // Every step places exactly one letter, so reachable states of
          // this layer satisfy a+t+g+c == idx: c is determined.
          const int c = idx - a - t - g;
          if( c < 0 ) break;          // larger g only makes c smaller
          if( c > N[3] ) continue;
          const int score = dp[cur][a][t][g][c];
          if( score == 0 ) continue;
          if( (mask>>A_) & 1 ) ( dp[nxt][a+1][t][g][c] += score ) %= mod;
          if( (mask>>T_) & 1 ) ( dp[nxt][a][t+1][g][c] += score ) %= mod;
          if( (mask>>G_) & 1 ) ( dp[nxt][a][t][g+1][c] += score ) %= mod;
          if( (mask>>C_) & 1 ) ( dp[nxt][a][t][g][c+1] += score ) %= mod;
        }
      }
    }
    ++idx;
  }
  return idx;
}

// Base-100 packing of the four counts into one map key.
const ll ccoef = 1000000;
const ll gcoef = 10000;
const ll tcoef = 100;

inline ll toHash(int a,int t,int g,int c) {
  return a + t * tcoef + g * gcoef + c * ccoef;
}

// Inverse of toHash.
inline void fromHash(ll hs,int &a,int &t,int &g,int &c) {
  c = static_cast<int>(hs / ccoef); hs %= ccoef;
  g = static_cast<int>(hs / gcoef); hs %= gcoef;
  t = static_cast<int>(hs / tcoef); hs %= tcoef;
  a = static_cast<int>(hs);
}

// Copies every non-zero entry of dp[idx] (counts within [0,N[k]]) into `mp`,
// keyed by toHash of the counts.
void toMap(int idx, map<ll,ll> &mp) {
  rep(a,N[0]+1) {
    rep(t,N[1]+1) {
      rep(g,N[2]+1) {
        rep(c,N[3]+1) {
          const ll score = dp[idx][a][t][g][c];
          if( score == 0 ) continue;
          mp[toHash(a,t,g,c)] = score;
        }
      }
    }
  }
}

// Debug helper: prints every (counts, ways) entry of one half's table.
static void dumpMap(const char *title, const map<ll,ll> &mp) {
  cout << title << endl;
  for(const auto &v : mp) {
    int a, t, g, c;
    fromHash(v.first,a,t,g,c);
    cout << " (A,T,G,C) = (" << a << "," << t << "," << g << "," << c << ")" << endl;
    cout << " score = " << v.second << endl;
  }
}

// Expands the grammar from nonterminal 0, counts the fillings of each half
// of the resulting string per letter-count tuple, combines the two halves
// and prints the number of strings using exactly N[k] of each letter.
void solve() {
  if( debug ) {
    cout << "* Graph ---" << endl;
    rep(i,m) {
      cout << " - i = " << i << "-th:" << endl;
      rep(j,(int)G[i].size()) {
        Data &d = G[i][j];
        if( d.type == NONTERMINAL ) {
          cout << " [NONTERMINAL] : to " << d.dst << endl;
        } else {
          assert( d.bit != -1 );
          bitset<4> BIT(d.bit);
          cout << " [ TERMINAL] : bit = " << BIT << endl;
        }
      }
      puts("");
    }
  }

  // enumerate chars
  chars.clear();
  fail = false;
  int len = enum_char_dfs(0);
  if( fail ) {
    puts("0");
    return ;
  }

  if( debug ) {
    assert( len == (int)chars.size() );
    cout << "* enumerate chars ---" << endl;
    cout << " - len = " << len << endl;
    cout << " ";
    rep(i,len) {
      cout << "[";
      rep(j,4) {
        if( (chars[i]>>j) & 1 ) {
          cout << toCHR(j);
        }
      }
      cout << "] ";
    }
    puts("");
  }

  // Dynamic Programming!!
  if( len > N[0] + N[1] + N[2] + N[3] ) {
    puts("0");
    return ;
  }

  int sz = len / 2;
  if( debug ) {
    cout << "[SPLIT] -- [" << 0 << "," << sz << ") and [" << sz << "," << len << ")" << endl;
  }

  map<ll,ll> mp[2];
  int lst;

  lst = solve_DP(0,sz);
  toMap(lst&1,mp[0]);
  if( debug ) dumpMap("* mp[0] Info ---",mp[0]);

  lst = solve_DP(sz,len);
  toMap(lst&1,mp[1]);
  if( debug ) dumpMap("* mp[1] Info ---",mp[1]);

  // hanbun-zennrekkyo (meet in the middle): pair each first-half tuple with
  // the second-half tuple that supplies exactly the remaining letters.
  ll answer = 0;
  for(const auto &v : mp[0]) {
    int a, t, g, c;
    fromHash(v.first,a,t,g,c);
    if( a > N[0] || t > N[1] || g > N[2] || c > N[3] ) continue;
    const ll next_hs = toHash(N[0]-a,N[1]-t,N[2]-g,N[3]-c);
    const map<ll,ll>::const_iterator it = mp[1].find(next_hs);
    if( it != mp[1].end() ) {
      ( answer += ( ( v.second * it->second ) % mod ) ) %= mod;
    }
  }
  cout << answer << endl;
}

// Returns the id of nonterminal `name`, assigning the next free id on first
// sight. Ids index G[], so they must stay below MAX.
static int internSymbol(const string &name, map<string,int> &mp, int &id) {
  map<string,int>::iterator it = mp.find(name);
  if( it == mp.end() ) {
    assert( id < MAX );
    it = mp.insert(make_pair(name,id++)).first;
  }
  return it->second;
}

// Parses one rule line of the form "name: sym sym ..." and appends its
// symbols to G[id of name]. A symbol starting with '[' is a terminal whose
// letters up to ']' form the allowed set; anything else is a nonterminal.
// `mp` maps nonterminal names to ids and `id` is the next unused id.
// Lines without any token are ignored.
void parse(string s, map<string,int> &mp, int &id) {
  rep(i,(int)s.size()) if( s[i] == ':' ) { s[i] = ' '; break; }

  stringstream ss;
  ss << s;
  vector<string> vec;
  while( ss >> s ) vec.pb(s);
  if( vec.empty() ) return; // blank line: nothing to define

  const int cur = internSymbol(vec[0],mp,id);
  REP(i,1,(int)vec.size()) {
    const string &t = vec[i];
    if( t[0] == '[' ) {
      int bit = 0;
      for(int ptr = 1; ptr < (int)t.size() && t[ptr] != ']'; ++ptr) {
        const int k = toIDX(t[ptr]);
        if( k < 0 ) continue; // only reachable when asserts are disabled
        bit |= (1<<k);
      }
      G[cur].pb(Data{TERMINAL,-1,bit});
    } else {
      const int nex = internSymbol(t,mp,id);
      G[cur].pb(Data{NONTERMINAL,nex,-1});
    }
  }
}

// Reads the four letter counts, the rule count m and m rule lines from
// standard input, then prints the answer.
int main() {
  N.resize(4);
  rep(i,4) {
    cin >> N[i];
    assert( 0 <= N[i] && N[i] <= MAX_COUNT ); // dp[] is sized for this range
  }
  cin >> m;

  string line;
  map<string,int> mp;
  int id = 0;
  // The remainder of the line holding m, and any other whitespace-only
  // line (e.g. a stray "\r"), is skipped instead of being parsed as a rule.
  for(int parsed = 0; parsed < m && getline(cin,line); ) {
    if( line.find_first_not_of(" \t\r\n") == string::npos ) continue;
    parse(line,mp,id);
    ++parsed;
  }
  solve();
  return 0;
}