45 #include <sphinxbase/priority_queue.h>
48 #include "lm_trie_quant.h"
/* Forward declaration so that lm_trie_read_bin, which appears earlier in the
 * file than the definition, can call lm_trie_alloc_ngram. */
50 static void lm_trie_alloc_ngram(
lm_trie_t * trie, uint32 * counts,
int order);
/* base_size: byte count for a bit-packed array of (entries + 1) records of
 * total_bits bits each. The "+ 7) / 8" rounds the bit count up to whole
 * bytes, and sizeof(uint64) extra bytes are added on top.
 * NOTE(review): the computation of total_bits is on lines this listing
 * omits; presumably it combines the word bits for max_vocab with
 * remaining_bits, as base_init does -- confirm. */
53 base_size(uint32 entries, uint32 max_vocab, uint8 remaining_bits)
60 return ((1 + entries) * total_bits + 7) / 8 +
sizeof(uint64);
/* middle_size: byte count for one middle-order array, obtained from
 * base_size with the same entries and max_vocab.
 * NOTE(review): the remaining-bits argument continues on a line not shown
 * here; middle_init uses quant_bits plus the next-pointer bits -- confirm
 * that the two agree. */
64 middle_size(uint8 quant_bits, uint32 entries, uint32 max_vocab,
67 return base_size(entries, max_vocab,
/* longest_size: byte count for the highest-order array. This is base_size
 * with remaining_bits = quant_bits, i.e. each record carries only quant_bits
 * of payload besides its word field. */
72 longest_size(uint8 quant_bits, uint32 entries, uint32 max_vocab)
74 return base_size(entries, max_vocab, quant_bits);
/* base_init: fill in the bookkeeping fields of a bit-packed array.
 *   base       structure to initialise
 *   base_mem   caller-provided storage; stored in base->base as a byte pointer
 *   max_vocab  stored unchanged in base->max_vocab
 * NOTE(review): the rest of the parameter list (it must include
 * remaining_bits, which is used below) and the assignment of
 * base->word_bits are on lines this listing omits. */
78 base_init(
base_t * base,
void *base_mem, uint32 max_vocab,
/* Mask with the low word_bits bits set. */
82 base->word_mask = (1
U << base->word_bits) - 1
U;
/* Word fields wider than 25 bits are reported with the message below.
 * NOTE(review): the name of the reporting call is not shown. */
83 if (base->word_bits > 25)
85 (
"Sorry, word indices more than %d are not implemented. Edit util/bit_packing.hh and fix the bit packing functions\n",
/* A record is the word field followed by remaining_bits of payload. */
87 base->total_bits = base->word_bits + remaining_bits;
89 base->base = (uint8 *) base_mem;
/* Nothing has been inserted yet. */
90 base->insert_index = 0;
91 base->max_vocab = max_vocab;
/* middle_init: initialise a middle-order array over base_mem.
 *   middle      structure to initialise
 *   base_mem    storage for the bit-packed records
 *   quant_bits  bits of quantized payload per record
 *   entries     number of n-grams this array will hold
 *   max_vocab   forwarded to base_init
 *   max_next    upper bound checked below; presumably the entry count of the
 *               next-higher order -- confirm against the caller
 * NOTE(review): the final parameter (next_source, used below) and the
 * set-up of middle->next_mask are on lines this listing omits. */
95 middle_init(
middle_t * middle,
void *base_mem, uint8 quant_bits,
96 uint32 entries, uint32 max_vocab, uint32 max_next,
99 middle->quant_bits = quant_bits;
/* Untyped pointer to the array one order up; middle_insert casts it to
 * longest_t or middle_t depending on the order. */
101 middle->next_source = next_source;
/* Counts that reach 2^25 are reported with the message below.
 * NOTE(review): the name of the reporting call is not shown. */
102 if (entries + 1 >= (1
U << 25) || (max_next >= (1
U << 25)))
104 (
"Sorry, this does not support more than %d n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions\n",
/* Per-record payload is the quantized values plus the next pointer. */
106 base_init(&middle->base, base_mem, max_vocab,
107 quant_bits + middle->next_mask.bits);
/* longest_init: initialise the highest-order array over base_mem. It simply
 * delegates to base_init with quant_bits as the per-record payload size
 * (no next-pointer bits are added, unlike middle_init).
 * NOTE(review): the tail of the parameter list (it must include max_vocab)
 * is on a line this listing omits. */
111 longest_init(
longest_t * longest,
void *base_mem, uint8 quant_bits,
114 base_init(&longest->base, base_mem, max_vocab, quant_bits);
/* middle_insert: append one record for `word` to a middle-order array.
 *   middle     array to append to
 *   word       word index; asserted to fit within base.word_mask
 *   order      n-gram order of this array
 *   max_order  highest n-gram order of the model
 * NOTE(review): the statements that write the word and `next` into the
 * buffer, and the return statement, are on lines this listing omits. */
118 middle_insert(
middle_t * middle, uint32 word,
int order,
int max_order)
123 assert(word <= middle->base.word_mask);
124 address.base = middle->base.base;
/* Records are fixed-width, so the free slot starts at
 * insert_index * total_bits (a bit offset). */
125 address.offset = middle->base.insert_index * middle->base.total_bits;
/* Step over the word field; the quantized values start here. */
127 address.offset += middle->base.word_bits;
128 at_pointer = address.offset;
/* Step over the quantized values to reach the next-pointer field. */
129 address.offset += middle->quant_bits;
/* `next` is the current insert_index of the array one order up: the longest
 * array when this is the second-highest order, otherwise another middle. */
130 if (order == max_order - 1) {
131 next = ((
longest_t *) middle->next_source)->base.insert_index;
134 next = ((
middle_t *) middle->next_source)->base.insert_index;
138 middle->base.insert_index++;
/* Rewind to the quantized-value field (presumably this is the address
 * handed back to the caller -- confirm). */
139 address.offset = at_pointer;
/* longest_insert: append one record for word `index` to the highest-order
 * array.
 *   longest  array to append to
 *   index    word index; asserted to fit within base.word_mask
 * NOTE(review): the write of the word and the return statement are on lines
 * this listing omits. */
144 longest_insert(
longest_t * longest, uint32 index)
147 assert(index <= longest->base.word_mask);
148 address.base = longest->base.base;
/* Bit offset of the free slot: fixed-width records. */
149 address.offset = longest->base.insert_index * longest->base.total_bits;
/* Step over the word field; what follows is the quantized payload. */
151 address.offset += longest->base.word_bits;
152 longest->base.insert_index++;
/* middle_finish_loading: position on the next-pointer field of the record
 * one past the last inserted entry.
 *   middle    array being finalised
 *   next_end  presumably the value stored in that field -- the store itself
 *             is on a line this listing omits, confirm
 * The offset below is the end of record number insert_index minus the
 * width of the next-pointer field, i.e. the last field of that record. */
157 middle_finish_loading(
middle_t * middle, uint32 next_end)
160 address.base = middle->base.base;
162 (middle->base.insert_index + 1) * middle->base.total_bits -
163 middle->next_mask.bits;
/* unigram_next: yields an insert_index, either that of the longest array or
 * that of the first middle array.
 * NOTE(review): only the tail of the conditional is visible ("... 2 ?");
 * presumably the test is order == 2, i.e. a bigram model has no middle
 * arrays -- confirm. */
168 unigram_next(
lm_trie_t * trie,
int order)
171 2 ? trie->longest->base.insert_index : trie->middle_begin->base.
/* lm_trie_fix_counts: walk all raw n-grams of order 2..order through a
 * priority queue ordered by ngram_ord_comparator, tracking the current word
 * context, to produce fixed_counts.
 *   raw_ngrams    raw_ngrams[k] holds the n-grams of order k + 2
 *   counts        counts[k] is the number of n-grams of order k + 1
 *   fixed_counts  output; starts as a copy of counts
 *   order         highest n-gram order
 * NOTE(review): the statements that adjust fixed_counts (and the use of
 * to_increment) are on lines this listing omits. */
176 lm_trie_fix_counts(
ngram_raw_t ** raw_ngrams, uint32 * counts,
177 uint32 * fixed_counts,
int order)
/* Queue capacity is order - 1: one slot per order from 2 up to `order`. */
180 priority_queue_create(order - 1, &ngram_ord_comparator);
181 uint32 raw_ngram_ptrs[NGRAM_MAX_ORDER - 1];
182 uint32 words[NGRAM_MAX_ORDER];
/* Every byte of the context is set to 0xFF (presumably a "no word yet"
 * sentinel -- confirm). */
185 memset(words, -1,
sizeof(words));
186 memcpy(fixed_counts, counts, order *
sizeof(*fixed_counts));
/* Seed the queue with the first n-gram of each order.
 * NOTE(review): the action taken when an order has no n-grams is not shown. */
187 for (i = 2; i <= order; i++) {
190 if (counts[i - 1] <= 0)
193 raw_ngram_ptrs[i - 2] = 0;
197 *tmp_ngram = raw_ngrams[i - 2][0];
198 tmp_ngram->order = i;
199 priority_queue_add(ngrams, tmp_ngram);
203 int32 to_increment = TRUE;
205 if (priority_queue_size(ngrams) == 0) {
/* A bigram simply becomes the current two-word context. */
209 if (top->order == 2) {
210 memcpy(words, top->words, 2 *
sizeof(*words));
/* Higher orders: look for the first prefix position that differs from the
 * current context. */
213 for (i = 0; i < top->order - 1; i++) {
214 if (words[i] != top->words[i]) {
216 num = (i == 0) ? 1 : i;
217 memcpy(words, top->words,
218 (num + 1) *
sizeof(*words));
220 to_increment = FALSE;
224 words[top->order - 1] = top->words[top->order - 1];
/* Advance within this order and re-queue its next n-gram while any remain. */
227 raw_ngram_ptrs[top->order - 2]++;
229 if (raw_ngram_ptrs[top->order - 2] < counts[top->order - 1]) {
230 *top = raw_ngrams[top->order - 2][raw_ngram_ptrs[top->order - 2]];
231 priority_queue_add(ngrams, top);
238 assert(priority_queue_size(ngrams) == 0);
239 priority_queue_free(ngrams, NULL);
245 uint32 * counts,
int order)
/* Presumably recursive_insert -- the first signature line is not shown; the
 * name is inferred from the call recursive_insert(trie, raw_ngrams, counts,
 * order) elsewhere in this file.
 * Merges the unigrams and the raw n-grams of every higher order through one
 * priority queue (ordered by ngram_ord_comparator) and inserts each popped
 * entry into the trie array of its order.
 *   counts  counts[k] is the number of n-grams of order k + 1
 *   order   highest n-gram order
 * NOTE(review): the main loop header, several declarations and the frees of
 * the scratch buffers are on lines this listing omits. */
247 uint32 unigram_idx = 0;
250 const uint32 unigram_count = (uint32) counts[0];
/* One queue slot per order, unigrams included. */
252 priority_queue_create(order, &ngram_ord_comparator);
254 uint32 *raw_ngrams_ptr;
/* Current word context and the probabilities along it. */
257 words = (uint32 *)
ckd_calloc(order,
sizeof(*words));
258 probs = (
float *)
ckd_calloc(order - 1,
sizeof(*probs));
/* The unigram queue entry aliases unigram_idx, so its word follows the
 * counter as it is incremented below. */
261 ngram->words = &unigram_idx;
262 priority_queue_add(ngrams, ngram);
264 (uint32 *)
ckd_calloc(order - 1,
sizeof(*raw_ngrams_ptr));
/* Seed the queue with the first n-gram of each order >= 2.
 * NOTE(review): the action taken when an order has no n-grams is not shown. */
265 for (i = 2; i <= order; ++i) {
268 if (counts[i - 1] <= 0)
271 raw_ngrams_ptr[i - 2] = 0;
274 *tmp_ngram = raw_ngrams[i - 2][0];
275 tmp_ngram->order = i;
277 priority_queue_add(ngrams, tmp_ngram);
/* Unigram: record where its children start, then move to the next unigram.
 * The comparison is against unigram_count + 1, so one entry past the last
 * real unigram is also processed. */
284 if (top->order == 1) {
285 trie->unigrams[unigram_idx].next = unigram_next(trie, order);
286 words[0] = unigram_idx;
287 probs[0] = trie->unigrams[unigram_idx].prob;
288 if (++unigram_idx == unigram_count + 1) {
292 priority_queue_add(ngrams, top);
/* Higher order: if the prefix differs from the current context, entries are
 * first inserted into the lower-order middle arrays for the differing
 * positions, with a computed probability (calc_prob).
 * NOTE(review): middle_begin[j - 1] would index before the array if i were
 * 0; presumably that case is excluded on a line not shown -- confirm. */
295 for (i = 0; i < top->order - 1; i++) {
296 if (words[i] != top->words[i]) {
300 for (j = i; j < top->order - 1; j++) {
301 middle_t *middle = &trie->middle_begin[j - 1];
303 middle_insert(middle, top->words[j],
308 trie->unigrams[top->words[j]].bo;
309 probs[j] = calc_prob;
310 lm_trie_quant_mwrite(trie->quant, address, j - 1,
315 memcpy(words, top->words,
316 top->order *
sizeof(*words));
/* Highest order goes to the longest array (probability only) ... */
317 if (top->order == order) {
319 longest_insert(trie->longest,
320 top->words[top->order - 1]);
321 lm_trie_quant_lwrite(trie->quant, address, top->prob);
/* ... every other order goes to its middle array with prob and backoff. */
324 middle_t *middle = &trie->middle_begin[top->order - 2];
326 middle_insert(middle,
327 top->words[top->order - 1],
330 probs[top->order - 1] = top->prob;
331 lm_trie_quant_mwrite(trie->quant, address, top->order - 2,
332 top->prob, top->backoff);
/* Advance within this order and re-queue its next n-gram while any remain. */
334 raw_ngrams_ptr[top->order - 2]++;
335 if (raw_ngrams_ptr[top->order - 2] < counts[top->order - 1]) {
336 *top = raw_ngrams[top->order -
337 2][raw_ngrams_ptr[top->order - 2]];
339 priority_queue_add(ngrams, top);
346 assert(priority_queue_size(ngrams) == 0);
347 priority_queue_free(ngrams, NULL);
/* lm_trie_init: set up a trie object with its caches reset and no n-gram
 * memory attached.
 *   unigram_count  number of unigrams; presumably sizes the unigram table --
 *                  the allocation call itself is on a line not shown, confirm
 * NOTE(review): the allocation of the trie and the return are also omitted
 * from this listing. */
354 lm_trie_init(uint32 unigram_count)
/* All bytes 0xFF: the cached history starts out as all-ones values. */
359 memset(trie->hist_cache, -1,
sizeof(trie->hist_cache));
360 memset(trie->backoff_cache, 0,
sizeof(trie->backoff_cache));
363 sizeof(*trie->unigrams));
/* No n-gram storage yet; the release path checks this pointer. */
364 trie->ngram_mem = NULL;
/* lm_trie_create: build an empty trie via lm_trie_init and, for models of
 * order greater than one, create a quantizer for it.
 *   unigram_count  forwarded to lm_trie_init
 *   order          highest n-gram order
 * NOTE(review): the left-hand side of the quantizer assignment and the
 * return are on lines this listing omits. The else-branch uses 0 where
 * lm_trie_read_bin uses NULL for the same case. */
369 lm_trie_create(uint32 unigram_count,
int order)
371 lm_trie_t *trie = lm_trie_init(unigram_count);
373 (order > 1) ? lm_trie_quant_create(order) : 0;
/* lm_trie_read_bin: load a trie from an open binary stream.
 *   counts  counts[0] is the unigram count; the array is also forwarded to
 *           lm_trie_alloc_ngram
 *   order   highest n-gram order
 *   fp      stream to read from
 * Read order: quantizer (only when order > 1), unigram table, n-gram memory.
 * NOTE(review): neither fread result is checked, so a short or failed read
 * goes unnoticed here. Any guard around the n-gram part and the return are
 * on lines this listing omits. */
378 lm_trie_read_bin(uint32 * counts,
int order, FILE * fp)
380 lm_trie_t *trie = lm_trie_init(counts[0]);
381 trie->quant = (order > 1) ? lm_trie_quant_read_bin(fp, order) : NULL;
/* counts[0] + 1 records: one more than the unigram count. */
382 fread(trie->unigrams,
sizeof(*trie->unigrams), (counts[0] + 1), fp);
/* Allocation sets trie->ngram_mem_size, which the read below relies on. */
384 lm_trie_alloc_ngram(trie, counts, order);
385 fread(trie->ngram_mem, 1, trie->ngram_mem_size, fp);
/* lm_trie_write_bin: dump a trie to an open binary stream in the same order
 * lm_trie_read_bin consumes it: quantizer, unigram table
 * (unigram_count + 1 records), then the raw n-gram memory.
 *   trie           trie to write
 *   unigram_count  number of unigrams
 *   fp             stream to write to
 * NOTE(review): the fwrite results are not checked in the visible lines; any
 * conditions guarding these writes are on lines this listing omits. */
391 lm_trie_write_bin(
lm_trie_t * trie, uint32 unigram_count, FILE * fp)
395 lm_trie_quant_write_bin(trie->quant, fp);
396 fwrite(trie->unigrams,
sizeof(*trie->unigrams), (unigram_count + 1),
399 fwrite(trie->ngram_mem, 1, trie->ngram_mem_size, fp);
/* Release path for a trie (presumably lm_trie_free -- the signature is not
 * shown in this listing). N-gram storage is handled only when ngram_mem is
 * non-NULL, and the quantizer is freed through lm_trie_quant_free.
 * NOTE(review): the individual free calls are on lines this listing omits. */
405 if (trie->ngram_mem) {
411 lm_trie_quant_free(trie->quant);
/* lm_trie_alloc_ngram: compute the total size of one buffer
 * (trie->ngram_mem) for all middle arrays plus the longest array, then
 * initialise each array over its slice of that buffer.
 *   trie    trie to set up; trie->quant supplies the per-record bit sizes
 *   counts  counts[k] is the number of n-grams of order k + 1
 *   order   highest n-gram order
 * NOTE(review): the allocations of ngram_mem, middle_begin and longest, and
 * the release of middle_starts, are on lines this listing omits. */
417 lm_trie_alloc_ngram(
lm_trie_t * trie, uint32 * counts,
int order)
421 uint8 **middle_starts;
423 trie->ngram_mem_size = 0;
/* Bytes for every middle array (orders 2 .. order - 1) ... */
424 for (i = 1; i < order - 1; i++) {
425 trie->ngram_mem_size +=
426 middle_size(lm_trie_quant_msize(trie->quant), counts[i],
427 counts[0], counts[i + 1]);
/* ... plus the bytes for the longest array. */
429 trie->ngram_mem_size +=
430 longest_size(lm_trie_quant_lsize(trie->quant), counts[order - 1],
434 sizeof(*trie->ngram_mem));
435 mem_ptr = trie->ngram_mem;
/* There are order - 2 middle arrays. */
438 trie->middle_end = trie->middle_begin + (order - 2);
440 (uint8 **)
ckd_calloc(order - 2,
sizeof(*middle_starts));
/* Remember where each middle array starts inside the buffer; same sizes as
 * in the first loop, with the index shifted by one. */
441 for (i = 2; i < order; i++) {
442 middle_starts[i - 2] = mem_ptr;
444 middle_size(lm_trie_quant_msize(trie->quant), counts[i - 1],
445 counts[0], counts[i]);
/* Initialise from the highest middle order downwards. Each array's
 * next_source is either the longest array or the middle one order up. */
449 for (i = order - 1; i >= 2; --i) {
450 middle_t *middle_ptr = &trie->middle_begin[i - 2];
451 middle_init(middle_ptr, middle_starts[i - 2],
452 lm_trie_quant_msize(trie->quant), counts[i - 1],
453 counts[0], counts[i],
456 1) ? (
void *) trie->longest : (
void *) &trie->
457 middle_begin[i - 1]);
/* The longest array takes whatever follows the middle arrays. */
460 longest_init(trie->longest, mem_ptr, lm_trie_quant_lsize(trie->quant),
/* Body of the trie build routine (presumably lm_trie_build -- the signature
 * is not shown in this listing). Steps, in order:
 *   1. compute adjusted counts and allocate n-gram storage from them;
 *   2. train the quantizer on the raw n-grams;
 *   3. insert everything into the trie;
 *   4. finalise each middle array.
 * Allocation uses out_counts, while training and insertion use counts. */
470 lm_trie_fix_counts(raw_ngrams, counts, out_counts, order);
471 lm_trie_alloc_ngram(trie, out_counts, order);
474 E_INFO(
"Training quantizer\n");
/* Middle orders are trained one by one; the highest order is trained
 * separately on probabilities. */
475 for (i = 2; i < order; i++) {
476 lm_trie_quant_train(trie->quant, i, counts[i - 1],
479 lm_trie_quant_train_prob(trie->quant, order, counts[order - 1],
480 raw_ngrams[order - 2]);
482 E_INFO(
"Building LM trie\n");
483 recursive_insert(trie, raw_ngrams, counts, order);
/* Only when at least one middle array exists. */
486 if (trie->middle_begin != trie->middle_end) {
/* Every middle array but the last is closed with the insert_index of the
 * next middle array ... */
488 for (middle_ptr = trie->middle_begin;
489 middle_ptr != trie->middle_end - 1; ++middle_ptr) {
490 middle_t *next_middle_ptr = middle_ptr + 1;
491 middle_finish_loading(middle_ptr,
492 next_middle_ptr->base.insert_index);
/* ... and the last one with the insert_index of the longest array. */
494 middle_ptr = trie->middle_end - 1;
495 middle_finish_loading(middle_ptr,
496 trie->longest->base.insert_index);
/* Unigram lookup body (presumably unigram_find -- the signature is not shown
 * in this listing). The child range of an entry runs from its own `next`
 * value up to the `next` value of the following entry. */
504 next->begin = ptr->next;
505 next->end = (ptr + 1)->next;
/* calc_pivot: scale `off` from a value span of (range + 1) onto an index
 * span of `width`, i.e. off * width / (range + 1), using integer division.
 * NOTE(review): the product off * width is evaluated in 32-bit unsigned
 * arithmetic before the cast to size_t, so it wraps for large operands;
 * range + 1 also wraps to 0 when range is the maximum uint32 -- confirm the
 * callers' value ranges. */
510 calc_pivot(uint32 off, uint32 range, uint32 width)
512 return (
size_t) ((off * width) / (range + 1));
/* uniform_find: search the bit-packed records at `base` for `key`.
 *   base        start of the packed array
 *   total_bits  width of one record in bits
 *   key_bits    width of the key field (its use is on lines not shown)
 *   key_mask    mask for the key field (its use is on lines not shown)
 *   before_it, after_it  index bounds; the loop stops once they are adjacent
 *   before_v,  after_v   key values associated with those bounds
 *   key         value searched for
 *   out         presumably receives the matching index -- the store is on a
 *               line not shown, confirm
 * The pivot is estimated from where key lies between before_v and after_v
 * (see calc_pivot). NOTE(review): this presumes keys increase with the index
 * -- confirm. The branch bodies and the return are omitted from this listing. */
516 uniform_find(
void *base, uint8 total_bits, uint8 key_bits, uint32 key_mask,
517 uint32 before_it, uint32 before_v,
518 uint32 after_it, uint32 after_v, uint32 key, uint32 * out)
522 while (after_it - before_it > 1) {
526 calc_pivot(key - before_v, after_v - before_v,
527 after_it - before_it - 1));
/* Bit offset of the pivot record. */
529 address.offset = pivot * (uint32) total_bits;
535 else if (mid > key) {
555 ((
void *) middle->base.base, middle->base.total_bits,
556 middle->base.word_bits, middle->base.word_mask, range->begin - 1,
557 0, range->end, middle->base.max_vocab, word, &at_pointer)) {
/* Lookup in a middle array (presumably middle_find -- the signature and the
 * name of the search call above are not shown in this listing). The search
 * covers the index interval between range->begin - 1 and range->end, with
 * key bounds 0 and max_vocab, and leaves the hit index in at_pointer. */
563 address.base = middle->base.base;
/* Convert the record index to the bit offset of its quantized values. */
564 at_pointer *= middle->base.total_bits;
565 at_pointer += middle->base.word_bits;
/* The next-pointer field follows the quantized values. It is read with
 * next_mask.mask for this record and again one record further on
 * (presumably giving the child range's begin and end -- the reads
 * themselves are on lines not shown, confirm). */
566 address.offset = at_pointer + middle->quant_bits;
569 middle->next_mask.mask);
570 address.offset += middle->base.total_bits;
573 middle->next_mask.mask);
/* Leave the address at the quantized values of the hit. */
574 address.offset = at_pointer;
587 ((
void *) longest->base.base, longest->base.total_bits,
588 longest->base.word_bits, longest->base.word_mask,
589 range->begin - 1, 0, range->end, longest->base.max_vocab, word,
/* Lookup in the longest array (presumably longest_find -- the signature, the
 * name of the search call above and its last argument are not shown in this
 * listing). On a hit the address is placed just past the word field of the
 * matching record. */
595 address.base = longest->base.base;
597 at_pointer * longest->base.total_bits + longest->base.word_bits;
/* get_available_prob: probability of `wid` using the longest context that
 * can be matched in the trie.
 *   trie       trie to query
 *   wid        word being scored
 *   hist       history words, consumed one per loop iteration from hist[0]
 *   max_order  highest n-gram order of the model
 *   n_hist     number of entries in hist
 *   n_used     output; set to order_minus_2 + 2 after a middle-array hit
 * Starts from the unigram probability and extends the context one history
 * word at a time.
 * NOTE(review): the statements following each `if` in the loop (presumably
 * early exits) and the return are on lines this listing omits. */
602 get_available_prob(
lm_trie_t * trie, int32 wid, int32 * hist,
603 int max_order, int32 n_hist, int32 * n_used)
609 uint8 independent_left;
610 int32 *hist_iter, *hist_end;
613 prob = unigram_find(trie->unigrams, wid, &node)->prob;
/* An empty child range means no longer n-gram extends this one. */
620 independent_left = (node.begin == node.end);
622 hist_end = hist + n_hist;
623 for (;; order_minus_2++, hist_iter++) {
/* History exhausted. */
624 if (hist_iter == hist_end)
/* Nothing longer exists in the trie. */
626 if (independent_left)
/* Next step would be the highest order, which lives in the longest array. */
628 if (order_minus_2 == max_order - 2)
632 middle_find(&trie->middle_begin[order_minus_2], *hist_iter,
634 independent_left = (address.base == NULL)
635 || (node.begin == node.end);
/* A NULL base signals that the lookup missed. */
638 if (address.base == NULL)
640 prob = lm_trie_quant_mpread(trie->quant, address, order_minus_2);
641 *n_used = order_minus_2 + 2;
/* Highest order: look in the longest array. */
644 address = longest_find(trie->longest, *hist_iter, &node);
645 if (address.base != NULL) {
646 prob = lm_trie_quant_lpread(trie->quant, address);
/* get_available_backoff: accumulate backoff weights for a history.
 *   trie   trie to query
 *   start  first order to consider; the walk begins at hist[start - 1]
 *   hist   history words
 * Starts from hist[0]'s unigram node, adds its backoff (first_hist->bo), then
 * looks up further history words in the middle arrays.
 * NOTE(review): the end of the parameter list (n_hist is used below), any
 * condition guarding the unigram backoff, the action on a failed lookup and
 * the return are on lines this listing omits. */
653 get_available_backoff(
lm_trie_t * trie, int32 start, int32 * hist,
656 float backoff = 0.0f;
660 unigram_t *first_hist = unigram_find(trie->unigrams, hist[0], &node);
662 backoff += first_hist->bo;
/* Middle arrays are indexed by order - 2. */
665 order_minus_2 = start - 2;
666 for (hist_iter = hist + start - 1; hist_iter < hist + n_hist;
667 hist_iter++, order_minus_2++) {
669 middle_find(&trie->middle_begin[order_minus_2], *hist_iter,
/* A NULL base signals that the lookup missed. */
671 if (address.base == NULL)
674 lm_trie_quant_mboread(trie->quant, address, order_minus_2);
/* lm_trie_nobo_score: score `wid` without the cached backoffs. It takes the
 * best available probability and adds the backoff weights from
 * get_available_backoff, starting at order *n_used.
 *   n_used  output of get_available_prob; also passed as the start order
 * NOTE(review): the statement executed when n_hist < *n_used (presumably an
 * early return of prob) is on a line this listing omits. */
680 lm_trie_nobo_score(
lm_trie_t * trie, int32 wid, int32 * hist,
681 int max_order, int32 n_hist, int32 * n_used)
684 get_available_prob(trie, wid, hist, max_order, n_hist, n_used);
685 if (n_hist < *n_used)
687 return prob + get_available_backoff(trie, *n_used, hist, n_hist);
/* lm_trie_hist_score: score `wid` for a history whose backoff weights are
 * already in trie->backoff_cache.
 *   trie    trie to query
 *   wid     word being scored
 *   hist    history words
 *   n_hist  number of entries in hist
 * Walks the middle arrays with hist[0 .. n_hist - 2], then the longest array
 * with hist[n_hist - 1].
 * NOTE(review): the end of the parameter list, the return after the inner
 * loop and any n_used updates are on lines this listing omits. */
691 lm_trie_hist_score(
lm_trie_t * trie, int32 wid, int32 * hist, int32 n_hist,
700 prob = unigram_find(trie->unigrams, wid, &node)->prob;
703 for (i = 0; i < n_hist - 1; i++) {
704 address = middle_find(&trie->middle_begin[i], hist[i], &node);
/* Miss at this order: add the cached backoffs for positions i .. n_hist - 1
 * to the best probability found so far. */
705 if (address.base == NULL) {
706 for (j = i; j < n_hist; j++) {
707 prob += trie->backoff_cache[j];
713 prob = lm_trie_quant_mpread(trie->quant, address, i);
716 address = longest_find(trie->longest, hist[n_hist - 1], &node);
/* Miss at the highest order: back off once, using the last cached weight. */
717 if (address.base == NULL) {
718 return prob + trie->backoff_cache[n_hist - 1];
/* Full-order hit: its probability is returned as is. */
722 return lm_trie_quant_lpread(trie->quant, address);
/* history_matches: compare the first n_hist entries of hist and prev_hist
 * element by element.
 * NOTE(review): the return statements are on lines this listing omits;
 * presumably FALSE on the first difference and TRUE otherwise -- confirm. */
727 history_matches(int32 * hist, int32 * prev_hist, int32 n_hist)
730 for (i = 0; i < n_hist; i++) {
731 if (hist[i] != prev_hist[i]) {
/* update_backoff: recompute trie->backoff_cache for a new history and
 * remember that history in trie->hist_cache.
 *   trie    trie whose caches are refreshed
 *   hist    history words
 *   n_hist  number of entries in hist
 * backoff_cache[0] comes from hist[0]'s unigram entry; backoff_cache[i] for
 * i >= 1 comes from middle array i - 1.
 * NOTE(review): the action taken when a middle lookup misses is on a line
 * this listing omits; entries not written stay 0 from the memset. */
739 update_backoff(
lm_trie_t * trie, int32 * hist, int32 n_hist)
745 memset(trie->backoff_cache, 0,
sizeof(trie->backoff_cache));
746 trie->backoff_cache[0] = unigram_find(trie->unigrams, hist[0], &node)->bo;
747 for (i = 1; i < n_hist; i++) {
748 address = middle_find(&trie->middle_begin[i - 1], hist[i], &node);
749 if (address.base == NULL) {
752 trie->backoff_cache[i] =
753 lm_trie_quant_mboread(trie->quant, address, i - 1);
/* Remember the history these weights belong to. */
755 memcpy(trie->hist_cache, hist, n_hist *
sizeof(*hist));
/* lm_trie_score: entry point for scoring `wid` given a history.
 *   trie    trie to query
 *   order   highest n-gram order of the model
 *   wid     word being scored
 *   hist    history words
 *   n_hist  number of entries in hist; asserted to be order - 1 on the
 *           cached path
 *   n_used  output, forwarded to the scoring helpers
 * Histories shorter than order - 1 take the uncached path
 * (lm_trie_nobo_score). Full-length histories refresh the backoff cache only
 * when the history differs from the cached one, then use lm_trie_hist_score. */
759 lm_trie_score(
lm_trie_t * trie,
int order, int32 wid, int32 * hist,
760 int32 n_hist, int32 * n_used)
762 if (n_hist < order - 1) {
763 return lm_trie_nobo_score(trie, wid, hist, order, n_hist, n_used);
766 assert(n_hist == order - 1);
767 if (!history_matches(hist, (int32 *) trie->hist_cache, n_hist)) {
768 update_backoff(trie, hist, n_hist);
770 return lm_trie_hist_score(trie, wid, hist, n_hist, n_used);
778 int n_hist,
int order,
int max_order)
780 if (n_hist > 0 && range.begin == range.end) {
785 for (i = 0; i < counts[0]; i++) {
787 unigram_find(trie->unigrams, i, &node);
789 lm_trie_fill_raw_ngram(trie, raw_ngrams, raw_ngram_idx, counts,
790 node, hist, 1, order, max_order);
793 else if (n_hist < order - 1) {
798 middle_t *middle = &trie->middle_begin[n_hist - 1];
799 for (ptr = range.begin; ptr < range.end; ptr++) {
800 address.base = middle->base.base;
801 address.offset = ptr * middle->base.total_bits;
804 middle->base.word_mask);
805 hist[n_hist] = new_word;
806 address.offset += middle->base.word_bits + middle->quant_bits;
809 middle->next_mask.mask);
811 (ptr + 1) * middle->base.total_bits +
812 middle->base.word_bits + middle->quant_bits;
815 middle->next_mask.mask);
816 lm_trie_fill_raw_ngram(trie, raw_ngrams, raw_ngram_idx, counts,
817 node, hist, n_hist + 1, order, max_order);
825 assert(n_hist == order - 1);
826 for (ptr = range.begin; ptr < range.end; ptr++) {
827 ngram_raw_t *raw_ngram = &raw_ngrams[*raw_ngram_idx];
828 if (order == max_order) {
830 address.base = longest->base.base;
831 address.offset = ptr * longest->base.total_bits;
834 longest->base.word_mask);
835 address.offset += longest->base.word_bits;
836 prob = lm_trie_quant_lpread(trie->quant, address);
839 middle_t *middle = &trie->middle_begin[n_hist - 1];
840 address.base = middle->base.base;
841 address.offset = ptr * middle->base.total_bits;
844 middle->base.word_mask);
845 address.offset += middle->base.word_bits;
847 lm_trie_quant_mpread(trie->quant, address, n_hist - 1);
849 lm_trie_quant_mboread(trie->quant, address,
851 raw_ngram->backoff = backoff;
853 raw_ngram->prob = prob;
855 (uint32 *)
ckd_calloc(order,
sizeof(*raw_ngram->words));
856 for (i = 0; i <= n_hist; i++) {
857 raw_ngram->words[i] = hist[n_hist - i];