使用 std::map 的问题在于会丧失（内存访问的）局部性。除了缩小可能的键空间之外，使用一个真正的字母表并没有其他好处。
逐个字符解析会让内存访问来回跳跃。如果每次解析三个字符，那么按顺序加载给定字符串时的局部性会更好，而且在数据仍驻留在缓存中时，也更有机会以较低代价加载其相邻数据。
这里还用到了一些较新的 C++ 语言特性，其中大部分也可以通过手动填充一个 256 元素的 char 查找表轻松实现。也许还有更好的实现方式。
#include <array>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
// MAP
// mkmap<count, tuple_sz, tuple_cnt, alpha...>::map(table) walks the character
// pack `alpha` in groups of `tuple_sz` characters and records, for every
// character, the 1-based number of the group it belongs to:
//     table[ch] = group number
// Characters that never appear in `alpha` keep whatever value `table` already
// held (cvtmap value-initialises it, so they read back as 0).
template<char ... Args>
class mkmap;
// Walk the tuple: record `grab` in the current group, then continue with the
// next character of the same group.
template<char count, char tuple_sz, char tuple_cnt, char grab, char ... alpha>
class mkmap<count, tuple_sz, tuple_cnt, grab, alpha...> {
public:
    static constexpr void map(std::array<char, 256>& map) {
        // Index through unsigned char: on platforms where char is signed, a
        // character >= 0x80 would otherwise become a huge (out-of-range) index.
        map[static_cast<unsigned char>(grab)] = count;
        mkmap<count, tuple_sz, tuple_cnt - 1, alpha...>::map(map);
    }
};
// Next tuple: `grab` is the last character of the current group, so the
// following character starts group `count + 1`.
template<char count, char tuple_sz, char grab, char ... alpha>
class mkmap<count, tuple_sz, 1, grab, alpha...> {
public:
    static constexpr void map(std::array<char, 256>& map) {
        map[static_cast<unsigned char>(grab)] = count;
        mkmap<count + 1, tuple_sz, tuple_sz, alpha...>::map(map);
    }
};
// End recursion: no characters left to record.
template<char count, char tuple_sz, char tuple_cnt>
class mkmap<count, tuple_sz, tuple_cnt> {
public:
    static constexpr void map(std::array<char, 256>&) {
    }
};
// Character -> symbol lookup table for the alphabet described by `alpha`,
// where every `tuple_sz` consecutive characters form one symbol (for example
// 'a','A').  Symbols are numbered from 1; characters not in `alpha` map to 0.
template<int tuple_sz, char ... alpha>
class cvtmap {
public:
    constexpr cvtmap() : map{} {
        mkmap<1, tuple_sz, tuple_sz, alpha...>::map(map);
    }
    // Returns the symbol number of `input`, or 0 if `input` is not in the alphabet.
    constexpr char operator[](char input) const {
        // Convert through unsigned char so a negative char (e.g. a UTF-8 byte
        // on a signed-char platform) cannot index far outside `map`.
        return map[static_cast<unsigned char>(input)];
    }
    std::array<char, 256> map;
};
// UNMAP
// mkunmap<count, tuple_sz, tuple_cnt, alpha...>::map(table) is the inverse of
// mkmap: it walks `alpha` in groups of `tuple_sz` characters and stores the
// characters of group `count` in table[count], in the order they appear
// (table[count][0] is the group's first character).
template<char ... Args>
class mkunmap;
// Walk the tuple: store `grab` at its slot inside the current group.
template<char count, char tuple_sz, char tuple_cnt, char grab, char ... alpha>
class mkunmap<count, tuple_sz, tuple_cnt, grab, alpha...> {
public:
    static constexpr void map(std::array<std::array<char, tuple_sz>, 256>& map) {
        // tuple_cnt counts down from tuple_sz, so this is the 0-based slot.
        map[count][tuple_sz - tuple_cnt] = grab;
        mkunmap<count, tuple_sz, tuple_cnt - 1, alpha...>::map(map);
    }
};
// Next tuple: `grab` fills the last slot of the current group; the next
// character opens group `count + 1`.
template<char count, char tuple_sz, char grab, char ... alpha>
class mkunmap<count, tuple_sz, 1, grab, alpha...> {
public:
    static constexpr void map(std::array<std::array<char, tuple_sz>, 256>& map) {
        map[count][tuple_sz - 1] = grab;
        mkunmap<count + 1, tuple_sz, tuple_sz, alpha...>::map(map);
    }
};
// End recursion: no characters left to store.
template<char count, char tuple_sz, char tuple_cnt>
class mkunmap<count, tuple_sz, tuple_cnt> {
public:
    static constexpr void map(std::array<std::array<char, tuple_sz>, 256>&) {
    }
};
// Symbol -> characters table: operator[](i) returns the `tuple_sz` characters
// that make up symbol i.  Index 0, and any index no group was written to,
// yields an all-'\0' array because the table is value-initialised.
template<int tuple_sz, char ... alpha>
class cvtunmap {
public:
    constexpr cvtunmap() : map{} {
        mkunmap<1, tuple_sz, tuple_sz, alpha...>::map(map);
    }
    constexpr std::array<char, tuple_sz> operator[](char input) const {
        // Convert through unsigned char so a negative char cannot turn into a
        // huge out-of-range index.
        return map[static_cast<unsigned char>(input)];
    }
    std::array<std::array<char, tuple_sz>, 256> map;
};
// cvt<tuple_sz, alpha...> is a compact alphabet symbol.  Constructing it from
// a char stores that character's symbol number: the 1-based group number
// within `alpha`, or 0 for characters not in `alpha`.  Every `tuple_sz`
// consecutive characters of `alpha` form one group, which is how
// ASCII_ignore_case folds 'a' and 'A' into the same symbol.
template<int tuple_sz, char ... alpha>
class cvt
{
public:
    // Character used to request the sentinel symbol.  It converts to symbol 0
    // as long as '\0' is not itself part of `alpha`.
    enum consts : char { SENTINAL = 0 };
    // Number of distinct symbol values, counting the sentinel value 0.
    static constexpr int size() { return sizeof...(alpha) / tuple_sz + 1; }
    // Deliberately not explicit: copy-initialisation such as
    // `cvt a = cvt::SENTINAL;` relies on this implicit conversion.
    cvt(char c) : a{ map[c] } {
    }
    // Returns the first character of this symbol's group (for
    // ASCII_ignore_case, the lowercase letter); '\0' for symbol 0.
    char to_char() const
    {
        return unmap[a][0];
    }
    // The symbol number, in [0, size()).
    unsigned short value() const {
        return a;
    }
private:
    char a;
    // Shared lookup tables, built once per alphabet instantiation.
    static const cvtmap<tuple_sz, alpha...> map;
    static const cvtunmap<tuple_sz, alpha...> unmap;
};
template<int tuple_sz, char ... alpha>
const cvtmap<tuple_sz, alpha...> cvt<tuple_sz, alpha...>::map;
template<int tuple_sz, char ... alpha>
const cvtunmap<tuple_sz, alpha...> cvt<tuple_sz, alpha...>::unmap;
// Case-insensitive ASCII letters: each lowercase/uppercase pair is one symbol
// (a/A -> 1, ..., z/Z -> 26), so size() == 27 including the sentinel.
using ASCII_ignore_case = cvt <2,
    'a', 'A', 'b', 'B', 'c', 'C', 'd', 'D', 'e', 'E', 'f', 'F', 'g', 'G', 'h', 'H', 'i', 'I', 'j', 'J', 'k', 'K', 'l', 'L', 'm', 'M', 'n', 'N', 'o', 'O', 'p', 'P', 'q', 'Q', 'r', 'R', 's', 'S', 't', 'T', 'u', 'U', 'v', 'V', 'w', 'W', 'x', 'X', 'y', 'Y', 'z', 'Z'
>;
// One level of the trigram trie.  Edges are labelled by a packed key of three
// alphabet symbols (see make_key).  For each key a node tracks two independent
// facts:
//   * children - the owned sub-trie reached by consuming those three symbols;
//   * finals   - whether a stored string ends with that key at this node.
// Keeping them apart lets a key be both final and an interior edge, e.g. after
// adding "hit" and "hitch" the key (h,i,t) ends "hit" and also leads to the
// node holding "ch".
template <class alphabet>
class Node {
public:
    enum consts { SHIFT = 32 };
    // Each symbol occupies 5 bits of the key (SHIFT == 32), so every symbol value
    // must be < SHIFT; otherwise distinct trigrams would collide.
    static_assert(alphabet::size() <= SHIFT,
                  "alphabet has more symbol values than Node::SHIFT can encode");

    // Packs three symbols into one key: a * SHIFT^2 + b * SHIFT + c.
    // With every value < 32 the largest key is 31*1024 + 31*32 + 31 == 32767,
    // which fits in a short (ASCII_ignore_case tops out at 26 -> 27482).
    static short make_key(alphabet a, alphabet b, alphabet c) {
        return static_cast<short>(a.value() * SHIFT * SHIFT
                                  + b.value() * SHIFT
                                  + c.value());
    }
    // Inverse of make_key: returns the three symbol *values* (not characters).
    static std::array<char, 3> to_string(short x) {
        char a = (x / (SHIFT * SHIFT)) & 0xFF;
        x -= a * SHIFT * SHIFT;
        char b = (x / SHIFT) & 0xFF;
        x -= b * SHIFT;
        char c = x & 0xFF;
        return { a, b, c };
    }
    // Returns the child reached through `key`, creating it if necessary.
    // The child is owned by this node; the returned pointer is non-owning.
    Node* add(short key) {
        // operator[] performs a single lookup and inserts an empty slot if absent.
        std::unique_ptr<Node>& child = children[key];
        if (!child) {
            child = std::make_unique<Node>();
        }
        return child.get();
    }
    // Shared placeholder node.  Final keys are now tracked in `finals`, so this
    // node is never linked into a trie; it is kept only for source compatibility.
    static Node* sentinal() {
        static Node fixed;
        return &fixed;
    }
    // Marks `key` as ending a stored string at this node (idempotent).
    void add_final(short key) {
        finals.insert(key);
    }
    // Returns the child reached through `key`, or nullptr if there is none.
    // Never creates nodes.
    const Node* get(short key) const {
        auto it = children.find(key);
        return it != children.end() ? it->second.get() : nullptr;
    }
    // True iff a stored string ends with `key` at this node.
    bool is_final(short key) const {
        return finals.count(key) != 0;
    }
private:
    std::map<short, std::unique_ptr<Node>> children;  // owning edges to sub-tries
    std::set<short> finals;                            // keys that terminate a string here
};
template <class alphabet>
class TriTrie {
public:
void add(std::string& str) {
std::string::iterator i = str.begin();
const std::string::iterator e = str.end();
Node<alphabet>* where = ⊤
for (;;) {
std::size_t len = e - i;
alphabet a = alphabet::SENTINAL;
alphabet b = alphabet::SENTINAL;
alphabet c = alphabet::SENTINAL;
switch (len)
{
default: [[likely]] {
a = alphabet(*(i++));
b = alphabet(*(i++));
c = alphabet(*(i++));
short key = Node<alphabet>::make_key(a,b,c);
where = where->add(key);
}
break;
case 3:
c = alphabet(*(i + 2));
[[fallthrough]];
case 2:
b = alphabet(*(i + 1));
[[fallthrough]];
case 1:
a = alphabet(*i);
[[fallthrough]];
case 0: {
short key = Node<alphabet>::make_key(a, b, c);
where->add_final(key);
return;
}
}
}
}
bool contains(std::string& str) const {
std::string::iterator i = str.begin();
const std::string::iterator e = str.end();
const Node<alphabet>* where = ⊤
while (where) {
std::size_t len = e - i;
alphabet a = alphabet::SENTINAL;
alphabet b = alphabet::SENTINAL;
alphabet c = alphabet::SENTINAL;
switch (len) {
default: [[likely]] {
a = alphabet(*(i++));
b = alphabet(*(i++));
c = alphabet(*(i++));
short key = Node<alphabet>::make_key(a,b,c);
where = where->get(key);
}
break;
case 3:
c = alphabet(*(i + 2));
[[fallthrough]];
case 2:
b = alphabet(*(i + 1));
[[fallthrough]];
case 1:
a = alphabet(*i);
[[fallthrough]];
case 0: {
short key = Node<alphabet>::make_key(a, b, c);
return where->is_final(key);
}
}
}
return false;
}
private:
Node<alphabet> top;
};
using ASCII_TriTrie = TriTrie<ASCII_ignore_case>;
int main()
{
ASCII_TriTrie tt;
for (std::string s : {
"hello", "goodbye", "big", "little", "hi", "hit", "hitch", "him"
}) {
std::cout << s << ":" << std::endl;
if (tt.contains(s)) {
return -1;
}
tt.add(s);
if (!tt.contains(s)) {
return -2;
}
}
return 0;
}