Facebook Hacker Cup2015 Round1 25:Autocomplete

Facebook Hacker Cup2015 Round1に参加しました.
結果は10,25,25の60点ということでRound2には進出できませんでした orz

公式の想定解法は Hacker Cup 2015 Round 1 Solutions に書いています.25:Autocompleteの想定解法はTrie木のようです.
他の方針で解いたのでメモしときます(Tire木を学習すると言ってまだ出来てないので,もしかしたら本質的には同じだったりして).

[25 "Autocomplete"]

問題文

あなたはオートコンプリート機能付きの携帯を買いました.あなたが何を書きたいかが正確に分かる時だけオートコンプリートします.また,あなたがオートコンプリート機能を使いたいすべての単語を辞書に登録する必要があります.
あなたは順番に送信したい異なるN個の単語を持っています.それぞれの単語を送る前に,送信する単語を携帯の辞書に登録します.そして,あなたは送信する単語を携帯がオートコンプリート出来る最小の空でない接頭辞を書きます.その接頭辞は単語全体か,または,辞書に登録されている任意の接頭辞とは異なるとします.
N個の単語をすべてを送信する時のあなたが携帯に打つ最小の文字数はいくつでしょうか.

 1 \le N \le 100,000 ,N個の単語の総文字数は1,000,000以下だとします.

解法

送信する単語をsとして,単語sを登録する前に辞書に含まれている単語の集合をDとします.sを送信する時に携帯に打たなければならない文字数を考えます.愚直な方法としては,Dに含まれているすべての単語との共通部分を調べます.この時共通部分が最も長い単語がsを送信する時に携帯に打つ文字数を決定します.しかし,この愚直な方法だと間に合いません.
ここで,比較するために辞書の中から選ぶ単語を少なく出来ないかを考えます.携帯に打つ文字数を決定する単語は辞書式順序でsに近い単語となっています.したがって,辞書Dをソートして単語sを挿入した時に前後に含まれている単語のみを調べれば良いことになります.

実装では辞書Dをstd::setとしました.std::set内では単語が辞書式順序でソートされています.
挿入した時に返される値は pair<iterator,bool> でiteratorは挿入した要素を指すイテレータで,boolは挿入が行われたかを判断する真偽値です.iteratorを使用して前後の単語を探しています.
計算量は大雑把に  O(N \log N + 1,000,000) です(O記法に定数を入れてすみません).

具体例

Nを5として,送信する単語を順番に

  • hi
  • hello
  • lol
  • hills
  • hill

とします.
辞書をDとして初めは空です.sを送信する単語,sの前の単語をs1,sの後ろの単語をs2とします.

[1番目 s = "hi"]
D <- "hi"
D = {"hi"}
s1 = "", s2 = ""
携帯に打つ文字数 = 1

[2番目 s = "hello"]
D <- "hello"
D = {"hello", "hi"}
s1 = "", s2 = "hi"
携帯に打つ文字数 = 2

[3番目 s = "lol"]
D <- "lol"
D = {"hello", "hi", "lol"}
s1 = "hi", s2 = ""
携帯に打つ文字数 = 1

[4番目 s = "hills"]
D <- "hills"
D = {"hello", "hi", "hills", "lol"}
s1 = "hi", s2 = "lol"
携帯に打つ文字数 = 3

[5番目 s = "hill"]
D <- "hill"
D = {"hello", "hi", "hill", "hills", "lol"}
s1 = "hi", s2 = "hills"
携帯に打つ文字数 = 4

答えは 1 + 2 + 1 + 3 + 4 = 11

ソースコード
#include <bits/stdc++.h>

using namespace std;

typedef long long  ll;

// 送信する単語をb, 登録されている単語をaとした時に打つ文字数
int Count(string a, string b)
{
    int res = 1, n = min(a.size(), b.size());

    for (int i = 0; i < n; ++i) {
        if (a[i] != b[i])
            break;
        ++res;
    }

    return min((int)b.size(), res);
}

int Solve()
{
    int N, res = 1;
    string s1, s2, s;
    set<string> memo;

    cin >> N;
    cin >> s1;

    if (N == 1)
        return res;

    cin >> s2;
    res += Count(s1, s2);
    if (N == 2)
        return res;

    memo.insert(s1);
    memo.insert(s2);

    for (int i = 2; i < N; ++i) {
        int tmp1 = 0, tmp2 = 0;
        cin >> s;
        auto it = (memo.insert(s)).first;

        if (it != memo.begin()) {
            advance(it, -1);
            s1 = *it;

            tmp1 = Count(s1, s);

            advance(it, 1);
        }
        if (++it != memo.end()) {
            s2 = *it;
            tmp2 = Count(s2, s);
        }

        res += max(tmp1, tmp2);
    }

    return res;
}

int main()
{
    cin.tie(0);
    ios::sync_with_stdio(false);

    int T;
    cin >> T;

    for (int i = 1; i <= T; ++i) {
        cout << "Case #" << i << ": " << Solve() << "\n";
    }

    return 0;
}