土下座しながら探索中

主に競技プログラミング

AOJ 2437 : DNA

問題リンク : http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=2437&lang=jp

問題概要 :
日本語なので略

解法:
DP + 半分全列挙で解いた

まず、i文字目(1<=i<=Na+Nt+Ng+Nc)にくる文字が何かを計算する
与えられる文法に含まれる非終端記号間の関係をグラフにするとDAGなので、これは再帰などでその通り計算すると良い
が計算途中で文字数がNa+Nt+Ng+Ncを超えるようなら、答えは0なので枝刈すること ( 出ないとTLEする )
文字数がNa+Nt+Ng+Ncを超えるようなら、枝刈するので、計算にはそのくらいの時間しかかからない

次に、区間[0,(Na+Nt+Ng+Nc)/2)と[(Na+Nt+Ng+Nc)/2,(Na+Nt+Ng+Nc))のそれぞれについて、
その区間で作れる文字列とそれが何通りあるかを計算する
これは動的計画法で求めると良い dp[i文字目][Aの数][Tの数][Gの数][Cの数] := A, T, G, Cがそれぞれ(Aの数), (Tの数), (Gの数), (Cの数)だけ出現する文字列が何通りあるか
(i文字目)はその最大値分だけ配列を用意するとMLEするので、2にして使い回すこと
O((Na+Nt+Ng+Nc) * Na * Nt * Ng * Nc)
半分に分けたので2.5 * 10^9くらい
AOJなら3 * 10^9は一瞬なのでOK

最後に、2つの区間について半分全列挙で答えを求める
O(Na * Nt * Ng * Nc)

コード (デバッグコード付き debug = 1にするとデバッグモード) :

#include<bits/stdc++.h>

#define REP(i,s,n) for(int i=s;i<n;++i)
#define rep(i,n) REP(i,0,n)
#define EPS (1e-8)
#define equals(a,b) (fabs((a)-(b))<EPS)
#define pb push_back
#define ALL(x) x.begin(),x.end()

using namespace std;

typedef long long ll;

bool LT(double a,double b) { return !equals(a,b) && a < b; }
bool LTE(double a,double b) { return equals(a,b) || a < b; }

#define A_ 0
#define T_ 1
#define G_ 2
#define C_ 3
#define NONTERMINAL 0
#define TERMINAL 1

struct Data {
  int type; // NONTERMINAL or TERMINAL
  int dst; // for NONTERMINAL
  int bit; // for TERMINAL
};

#define MAX 60
vector<int> N;
int m;
vector<Data> G[MAX];

const ll mod = 1000000007LL;

bool debug = 0;

inline int toIDX(char c) { if( c == 'A' ) return A_; if( c == 'T' ) return T_; if( c == 'G' ) return G_; if( c == 'C' ) return C_; assert(false); }
inline char toCHR(int  i) { if( i == A_ ) return 'A'; if( i == T_ ) return 'T'; if( i == G_ ) return 'G'; if( i == C_ ) return 'C'; assert(false); }


bool fail;
vector<int> chars;
int enum_char_dfs(int cur) {
  if( fail ) return -1;
  int ret = 0;
  rep(i,(int)G[cur].size()) {
    Data &d = G[cur][i];
    if( d.type == NONTERMINAL ) {
      ret += enum_char_dfs(d.dst);
      if( ret >= N[0] + N[1] + N[2] + N[3] + 1 ) { fail = true; return -1; }
    } else {
      chars.pb(d.bit);
      //++len;
      ++ret;
    }
  }
  return ret;
}

int dp[2][52][52][52][52];
int solve_DP(int L,int R) { //[L,R)
  memset(dp,0,sizeof dp);
  dp[0][0][0][0][0] = 1;
  int idx = 0;
  REP(i,L,R) {
    rep(w,N[0]+1) rep(x,N[1]+1) rep(y,N[2]+1) rep(z,N[3]+1) dp[(idx+1)&1][w][x][y][z] = 0;
    rep(a,N[0]+1) {
      rep(t,N[1]+1) {
	rep(g,N[2]+1) {
	  rep(c,N[3]+1) {
	    int &score = dp[idx&1][a][t][g][c];
	    if( score == 0 ) continue;
	    if( (chars[i]>>A_) & 1 ) ( dp[(idx+1)&1][a+1][t][g][c] += score ) %= mod;
	    if( (chars[i]>>T_) & 1 ) ( dp[(idx+1)&1][a][t+1][g][c] += score ) %= mod;
	    if( (chars[i]>>G_) & 1 ) ( dp[(idx+1)&1][a][t][g+1][c] += score ) %= mod;
	    if( (chars[i]>>C_) & 1 ) ( dp[(idx+1)&1][a][t][g][c+1] += score ) %= mod;
	  }
	}
      }
    }
    ++idx;
  }
  return idx;
}

const ll ccoef = 1000000;
const ll gcoef = 10000;
const ll tcoef = 100;
inline ll toHash(int a,int t,int g,int c) { return a + t * tcoef + g * gcoef + c * ccoef; }

void toMap(int idx, map<ll,ll> &mp) {
  rep(a,N[0]+1) {
    rep(t,N[1]+1) {
      rep(g,N[2]+1) {
	rep(c,N[3]+1) {
	  ll hs = toHash(a,t,g,c);
	  ll score = dp[idx][a][t][g][c];
	  if( score == 0 ) continue;
	  mp[hs] = score;
	}
      }
    }
  }
}

void solve() {
  if( debug ) {
    cout << "* Graph ---" << endl;
    rep(i,m) {
      cout << "  - i = " << i << "-th:" << endl;
      rep(j,(int)G[i].size()) {
	Data &d = G[i][j];
	if( d.type == NONTERMINAL ) {
	  cout << "  [NONTERMINAL] : to " << d.dst << endl;
	} else {
	  assert( d.bit != -1 );
	  bitset<4> BIT(d.bit);
	  cout << "  [   TERMINAL] : bit = " << BIT << endl;
	}
      }
      puts("");
    }
  }
  
  // enumerate chars
  chars.clear();
  fail = false;
  int len = enum_char_dfs(0);
  if( fail ) { puts("0"); return ; }
  if( debug ) {
    assert( len == (int)chars.size() );
    cout << "* enumerate chars ---" << endl;
    cout << "  - len = " << len << endl;
    cout << "    ";
    rep(i,len) {
      cout << "[";
      rep(j,4) {
	if( (chars[i]>>j) & 1 ) {
	  cout << toCHR(j);
	}
      }
      cout << "] ";
    }
    puts("");
  }

  // Dynamic Programming!!
  if( len > N[0] + N[1] + N[2] + N[3] ) { puts("0"); return ; }
  int sz = len / 2;
  if( debug ) {
    cout << "[SPLIT] -- [" << 0 << "," << sz << ") and [" << sz << "," << len << ")" << endl;
  }
  map<ll,ll> mp[2];
  int lst;
  lst = solve_DP(0,sz);
  toMap(lst&1,mp[0]);



  if( debug ) {
    cout << "* mp[0] Info ---" << endl;
    for(auto v : mp[0]) {
      ll hs = v.first;
      int c = hs / ccoef;
      hs -= ( c * ccoef );
      int g = hs / gcoef;
      hs -= ( g * gcoef );
      int t = hs / tcoef;
      hs -= ( t * tcoef );
      int a = hs;
      ll score = v.second;
      cout << "  (A,T,G,C) = (" << a << "," << t << "," << g << "," << c << ")" << endl;
      cout << "      score = " << score << endl;
    }
  }

  lst = solve_DP(sz,len);
  toMap(lst&1,mp[1]);
  if( debug ) {
    cout << "* mp[1] Info ---" << endl;
    for(auto v : mp[1]) {
      ll hs = v.first;
      int c = hs / ccoef;
      hs -= ( c * ccoef );
      int g = hs / gcoef;
      hs -= ( g * gcoef );
      int t = hs / tcoef;
      hs -= ( t * tcoef );
      int a = hs;
      ll score = v.second;
      cout << "  (A,T,G,C) = (" << a << "," << t << "," << g << "," << c << ")" << endl;
      cout << "      score = " << score << endl;
    }
  }

  // hanbun-zennrekkyo
  ll answer = 0;
  for(auto v : mp[0]) {
    ll hs = v.first;
    int c = hs / ccoef;
    hs -= ( c * ccoef );
    int g = hs / gcoef;
    hs -= ( g * gcoef );
    int t = hs / tcoef;
    hs -= ( t * tcoef );
    int a = hs;
    ll score = v.second;
    if( a > N[0] || t > N[1] || g > N[2] || c > N[3] ) continue;
    int rem_a = N[0] - a;
    int rem_t = N[1] - t;
    int rem_g = N[2] - g;
    int rem_c = N[3] - c;
    ll next_hs = toHash(rem_a,rem_t,rem_g,rem_c);
    if( mp[1].count(next_hs) ) {
      ( answer += ( ( score * mp[1][next_hs] ) % mod ) ) %= mod;
    }
  }
  cout << answer << endl;
}

void parse(string s, map<string,int> &mp, int &id) {
  rep(i,(int)s.size()) if( s[i] == ':' ) { s[i] = ' '; break; }
  stringstream ss;
  ss << s;
  vector<string> vec;
  while( ss >> s ) vec.pb(s);
  if( !mp.count(vec[0]) ) { mp[vec[0]] = id++; }
  int cur = mp[vec[0]];
  REP(i,1,(int)vec.size()) {
    string &t = vec[i];
    if( t[0] == '[' ) {
      int ptr = 1, bit = 0;
      while( ptr < (int)t.size() && t[ptr] != ']' ) bit = bit | (1<<toIDX(t[ptr++]));
      G[cur].pb((Data){TERMINAL,-1,bit});
    } else {
      if( !mp.count(vec[i]) ) { mp[vec[i]] = id++; }
      int nex = mp[vec[i]];
      G[cur].pb((Data){NONTERMINAL,nex,-1});
    }
  }
}

int main() {
  N.resize(4);
  rep(i,4) cin >> N[i];
  cin >> m;
  cin.ignore();
  string line;
  map<string,int> mp;
  int id=0;
  rep(i,m) {
    getline(cin,line);
    parse(line,mp,id);
  }
  solve();
  return 0;
}