SphinxBase  5prealpha
lm_trie.c
1 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
2 /* ====================================================================
3  * Copyright (c) 2015 Carnegie Mellon University. All rights
4  * reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions
8  * are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright
11  * notice, this list of conditions and the following disclaimer.
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright
14  * notice, this list of conditions and the following disclaimer in
15  * the documentation and/or other materials provided with the
16  * distribution.
17  *
18  * This work was supported in part by funding from the Defense Advanced
19  * Research Projects Agency and the National Science Foundation of the
20  * United States of America, and the CMU Sphinx Speech Consortium.
21  *
22  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
23  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
24  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
25  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
26  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33  *
34  * ====================================================================
35  *
36  */
37 
38 #include <string.h>
39 #include <stdio.h>
40 #include <assert.h>
41 
42 #include <sphinxbase/prim_type.h>
43 #include <sphinxbase/ckd_alloc.h>
44 #include <sphinxbase/err.h>
45 #include <sphinxbase/priority_queue.h>
46 
47 #include "lm_trie.h"
48 #include "lm_trie_quant.h"
49 
50 static void lm_trie_alloc_ngram(lm_trie_t * trie, uint32 * counts, int order);
51 
52 static uint32
53 base_size(uint32 entries, uint32 max_vocab, uint8 remaining_bits)
54 {
55  uint8 total_bits = bitarr_required_bits(max_vocab) + remaining_bits;
56  /* Extra entry for next pointer at the end.
57  * +7 then / 8 to round up bits and convert to bytes
58  * +sizeof(uint64) so that ReadInt57 etc don't go segfault.
59  * Note that this waste is O(order), not O(number of ngrams).*/
60  return ((1 + entries) * total_bits + 7) / 8 + sizeof(uint64);
61 }
62 
63 uint32
64 middle_size(uint8 quant_bits, uint32 entries, uint32 max_vocab,
65  uint32 max_ptr)
66 {
67  return base_size(entries, max_vocab,
68  quant_bits + bitarr_required_bits(max_ptr));
69 }
70 
71 uint32
72 longest_size(uint8 quant_bits, uint32 entries, uint32 max_vocab)
73 {
74  return base_size(entries, max_vocab, quant_bits);
75 }
76 
77 static void
78 base_init(base_t * base, void *base_mem, uint32 max_vocab,
79  uint8 remaining_bits)
80 {
81  base->word_bits = bitarr_required_bits(max_vocab);
82  base->word_mask = (1U << base->word_bits) - 1U;
83  if (base->word_bits > 25)
84  E_ERROR
85  ("Sorry, word indices more than %d are not implemented. Edit util/bit_packing.hh and fix the bit packing functions\n",
86  (1U << 25));
87  base->total_bits = base->word_bits + remaining_bits;
88 
89  base->base = (uint8 *) base_mem;
90  base->insert_index = 0;
91  base->max_vocab = max_vocab;
92 }
93 
94 void
95 middle_init(middle_t * middle, void *base_mem, uint8 quant_bits,
96  uint32 entries, uint32 max_vocab, uint32 max_next,
97  void *next_source)
98 {
99  middle->quant_bits = quant_bits;
100  bitarr_mask_from_max(&middle->next_mask, max_next);
101  middle->next_source = next_source;
102  if (entries + 1 >= (1U << 25) || (max_next >= (1U << 25)))
103  E_ERROR
104  ("Sorry, this does not support more than %d n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions\n",
105  (1U << 25));
106  base_init(&middle->base, base_mem, max_vocab,
107  quant_bits + middle->next_mask.bits);
108 }
109 
110 void
111 longest_init(longest_t * longest, void *base_mem, uint8 quant_bits,
112  uint32 max_vocab)
113 {
114  base_init(&longest->base, base_mem, max_vocab, quant_bits);
115 }
116 
117 static bitarr_address_t
118 middle_insert(middle_t * middle, uint32 word, int order, int max_order)
119 {
120  uint32 at_pointer;
121  uint32 next;
122  bitarr_address_t address;
123  assert(word <= middle->base.word_mask);
124  address.base = middle->base.base;
125  address.offset = middle->base.insert_index * middle->base.total_bits;
126  bitarr_write_int25(address, middle->base.word_bits, word);
127  address.offset += middle->base.word_bits;
128  at_pointer = address.offset;
129  address.offset += middle->quant_bits;
130  if (order == max_order - 1) {
131  next = ((longest_t *) middle->next_source)->base.insert_index;
132  }
133  else {
134  next = ((middle_t *) middle->next_source)->base.insert_index;
135  }
136 
137  bitarr_write_int25(address, middle->next_mask.bits, next);
138  middle->base.insert_index++;
139  address.offset = at_pointer;
140  return address;
141 }
142 
143 static bitarr_address_t
144 longest_insert(longest_t * longest, uint32 index)
145 {
146  bitarr_address_t address;
147  assert(index <= longest->base.word_mask);
148  address.base = longest->base.base;
149  address.offset = longest->base.insert_index * longest->base.total_bits;
150  bitarr_write_int25(address, longest->base.word_bits, index);
151  address.offset += longest->base.word_bits;
152  longest->base.insert_index++;
153  return address;
154 }
155 
156 static void
157 middle_finish_loading(middle_t * middle, uint32 next_end)
158 {
159  bitarr_address_t address;
160  address.base = middle->base.base;
161  address.offset =
162  (middle->base.insert_index + 1) * middle->base.total_bits -
163  middle->next_mask.bits;
164  bitarr_write_int25(address, middle->next_mask.bits, next_end);
165 }
166 
167 static uint32
168 unigram_next(lm_trie_t * trie, int order)
169 {
170  return order ==
171  2 ? trie->longest->base.insert_index : trie->middle_begin->base.
172  insert_index;
173 }
174 
175 void
176 lm_trie_fix_counts(ngram_raw_t ** raw_ngrams, uint32 * counts,
177  uint32 * fixed_counts, int order)
178 {
179  priority_queue_t *ngrams =
180  priority_queue_create(order - 1, &ngram_ord_comparator);
181  uint32 raw_ngram_ptrs[NGRAM_MAX_ORDER - 1];
182  uint32 words[NGRAM_MAX_ORDER];
183  int i;
184 
185  memset(words, -1, sizeof(words));
186  memcpy(fixed_counts, counts, order * sizeof(*fixed_counts));
187  for (i = 2; i <= order; i++) {
188  ngram_raw_t *tmp_ngram;
189 
190  if (counts[i - 1] <= 0)
191  continue;
192 
193  raw_ngram_ptrs[i - 2] = 0;
194 
195  tmp_ngram =
196  (ngram_raw_t *) ckd_calloc(1, sizeof(*tmp_ngram));
197  *tmp_ngram = raw_ngrams[i - 2][0];
198  tmp_ngram->order = i;
199  priority_queue_add(ngrams, tmp_ngram);
200  }
201 
202  for (;;) {
203  int32 to_increment = TRUE;
204  ngram_raw_t *top;
205  if (priority_queue_size(ngrams) == 0) {
206  break;
207  }
208  top = (ngram_raw_t *) priority_queue_poll(ngrams);
209  if (top->order == 2) {
210  memcpy(words, top->words, 2 * sizeof(*words));
211  }
212  else {
213  for (i = 0; i < top->order - 1; i++) {
214  if (words[i] != top->words[i]) {
215  int num;
216  num = (i == 0) ? 1 : i;
217  memcpy(words, top->words,
218  (num + 1) * sizeof(*words));
219  fixed_counts[num]++;
220  to_increment = FALSE;
221  break;
222  }
223  }
224  words[top->order - 1] = top->words[top->order - 1];
225  }
226  if (to_increment) {
227  raw_ngram_ptrs[top->order - 2]++;
228  }
229  if (raw_ngram_ptrs[top->order - 2] < counts[top->order - 1]) {
230  *top = raw_ngrams[top->order - 2][raw_ngram_ptrs[top->order - 2]];
231  priority_queue_add(ngrams, top);
232  }
233  else {
234  ckd_free(top);
235  }
236  }
237 
238  assert(priority_queue_size(ngrams) == 0);
239  priority_queue_free(ngrams, NULL);
240 }
241 
242 
243 static void
244 recursive_insert(lm_trie_t * trie, ngram_raw_t ** raw_ngrams,
245  uint32 * counts, int order)
246 {
247  uint32 unigram_idx = 0;
248  uint32 *words;
249  float *probs;
250  const uint32 unigram_count = (uint32) counts[0];
251  priority_queue_t *ngrams =
252  priority_queue_create(order, &ngram_ord_comparator);
253  ngram_raw_t *ngram;
254  uint32 *raw_ngrams_ptr;
255  int i;
256 
257  words = (uint32 *) ckd_calloc(order, sizeof(*words));
258  probs = (float *) ckd_calloc(order - 1, sizeof(*probs));
259  ngram = (ngram_raw_t *) ckd_calloc(1, sizeof(*ngram));
260  ngram->order = 1;
261  ngram->words = &unigram_idx;
262  priority_queue_add(ngrams, ngram);
263  raw_ngrams_ptr =
264  (uint32 *) ckd_calloc(order - 1, sizeof(*raw_ngrams_ptr));
265  for (i = 2; i <= order; ++i) {
266  ngram_raw_t *tmp_ngram;
267 
268  if (counts[i - 1] <= 0)
269  continue;
270 
271  raw_ngrams_ptr[i - 2] = 0;
272  tmp_ngram =
273  (ngram_raw_t *) ckd_calloc(1, sizeof(*tmp_ngram));
274  *tmp_ngram = raw_ngrams[i - 2][0];
275  tmp_ngram->order = i;
276 
277  priority_queue_add(ngrams, tmp_ngram);
278  }
279 
280  for (;;) {
281  ngram_raw_t *top =
282  (ngram_raw_t *) priority_queue_poll(ngrams);
283 
284  if (top->order == 1) {
285  trie->unigrams[unigram_idx].next = unigram_next(trie, order);
286  words[0] = unigram_idx;
287  probs[0] = trie->unigrams[unigram_idx].prob;
288  if (++unigram_idx == unigram_count + 1) {
289  ckd_free(top);
290  break;
291  }
292  priority_queue_add(ngrams, top);
293  }
294  else {
295  for (i = 0; i < top->order - 1; i++) {
296  if (words[i] != top->words[i]) {
297  /* need to insert dummy suffixes to make ngram of higher order reachable */
298  int j;
299  assert(i > 0); /* unigrams are not pruned without removing ngrams that contains them */
300  for (j = i; j < top->order - 1; j++) {
301  middle_t *middle = &trie->middle_begin[j - 1];
302  bitarr_address_t address =
303  middle_insert(middle, top->words[j],
304  j + 1, order);
305  /* calculate prob for blank */
306  float calc_prob =
307  probs[j - 1] +
308  trie->unigrams[top->words[j]].bo;
309  probs[j] = calc_prob;
310  lm_trie_quant_mwrite(trie->quant, address, j - 1,
311  calc_prob, 0.0f);
312  }
313  }
314  }
315  memcpy(words, top->words,
316  top->order * sizeof(*words));
317  if (top->order == order) {
318  bitarr_address_t address =
319  longest_insert(trie->longest,
320  top->words[top->order - 1]);
321  lm_trie_quant_lwrite(trie->quant, address, top->prob);
322  }
323  else {
324  middle_t *middle = &trie->middle_begin[top->order - 2];
325  bitarr_address_t address =
326  middle_insert(middle,
327  top->words[top->order - 1],
328  top->order, order);
329  /* write prob and backoff */
330  probs[top->order - 1] = top->prob;
331  lm_trie_quant_mwrite(trie->quant, address, top->order - 2,
332  top->prob, top->backoff);
333  }
334  raw_ngrams_ptr[top->order - 2]++;
335  if (raw_ngrams_ptr[top->order - 2] < counts[top->order - 1]) {
336  *top = raw_ngrams[top->order -
337  2][raw_ngrams_ptr[top->order - 2]];
338 
339  priority_queue_add(ngrams, top);
340  }
341  else {
342  ckd_free(top);
343  }
344  }
345  }
346  assert(priority_queue_size(ngrams) == 0);
347  priority_queue_free(ngrams, NULL);
348  ckd_free(raw_ngrams_ptr);
349  ckd_free(words);
350  ckd_free(probs);
351 }
352 
353 static lm_trie_t *
354 lm_trie_init(uint32 unigram_count)
355 {
356  lm_trie_t *trie;
357 
358  trie = (lm_trie_t *) ckd_calloc(1, sizeof(*trie));
359  memset(trie->hist_cache, -1, sizeof(trie->hist_cache)); /* prepare request history */
360  memset(trie->backoff_cache, 0, sizeof(trie->backoff_cache));
361  trie->unigrams =
362  (unigram_t *) ckd_calloc((unigram_count + 1),
363  sizeof(*trie->unigrams));
364  trie->ngram_mem = NULL;
365  return trie;
366 }
367 
368 lm_trie_t *
369 lm_trie_create(uint32 unigram_count, int order)
370 {
371  lm_trie_t *trie = lm_trie_init(unigram_count);
372  trie->quant =
373  (order > 1) ? lm_trie_quant_create(order) : 0;
374  return trie;
375 }
376 
377 lm_trie_t *
378 lm_trie_read_bin(uint32 * counts, int order, FILE * fp)
379 {
380  lm_trie_t *trie = lm_trie_init(counts[0]);
381  trie->quant = (order > 1) ? lm_trie_quant_read_bin(fp, order) : NULL;
382  fread(trie->unigrams, sizeof(*trie->unigrams), (counts[0] + 1), fp);
383  if (order > 1) {
384  lm_trie_alloc_ngram(trie, counts, order);
385  fread(trie->ngram_mem, 1, trie->ngram_mem_size, fp);
386  }
387  return trie;
388 }
389 
390 void
391 lm_trie_write_bin(lm_trie_t * trie, uint32 unigram_count, FILE * fp)
392 {
393 
394  if (trie->quant)
395  lm_trie_quant_write_bin(trie->quant, fp);
396  fwrite(trie->unigrams, sizeof(*trie->unigrams), (unigram_count + 1),
397  fp);
398  if (trie->ngram_mem)
399  fwrite(trie->ngram_mem, 1, trie->ngram_mem_size, fp);
400 }
401 
402 void
403 lm_trie_free(lm_trie_t * trie)
404 {
405  if (trie->ngram_mem) {
406  ckd_free(trie->ngram_mem);
407  ckd_free(trie->middle_begin);
408  ckd_free(trie->longest);
409  }
410  if (trie->quant)
411  lm_trie_quant_free(trie->quant);
412  ckd_free(trie->unigrams);
413  ckd_free(trie);
414 }
415 
416 static void
417 lm_trie_alloc_ngram(lm_trie_t * trie, uint32 * counts, int order)
418 {
419  int i;
420  uint8 *mem_ptr;
421  uint8 **middle_starts;
422 
423  trie->ngram_mem_size = 0;
424  for (i = 1; i < order - 1; i++) {
425  trie->ngram_mem_size +=
426  middle_size(lm_trie_quant_msize(trie->quant), counts[i],
427  counts[0], counts[i + 1]);
428  }
429  trie->ngram_mem_size +=
430  longest_size(lm_trie_quant_lsize(trie->quant), counts[order - 1],
431  counts[0]);
432  trie->ngram_mem =
433  (uint8 *) ckd_calloc(trie->ngram_mem_size,
434  sizeof(*trie->ngram_mem));
435  mem_ptr = trie->ngram_mem;
436  trie->middle_begin =
437  (middle_t *) ckd_calloc(order - 2, sizeof(*trie->middle_begin));
438  trie->middle_end = trie->middle_begin + (order - 2);
439  middle_starts =
440  (uint8 **) ckd_calloc(order - 2, sizeof(*middle_starts));
441  for (i = 2; i < order; i++) {
442  middle_starts[i - 2] = mem_ptr;
443  mem_ptr +=
444  middle_size(lm_trie_quant_msize(trie->quant), counts[i - 1],
445  counts[0], counts[i]);
446  }
447  trie->longest = (longest_t *) ckd_calloc(1, sizeof(*trie->longest));
448  /* Crazy backwards thing so we initialize using pointers to ones that have already been initialized */
449  for (i = order - 1; i >= 2; --i) {
450  middle_t *middle_ptr = &trie->middle_begin[i - 2];
451  middle_init(middle_ptr, middle_starts[i - 2],
452  lm_trie_quant_msize(trie->quant), counts[i - 1],
453  counts[0], counts[i],
454  (i ==
455  order -
456  1) ? (void *) trie->longest : (void *) &trie->
457  middle_begin[i - 1]);
458  }
459  ckd_free(middle_starts);
460  longest_init(trie->longest, mem_ptr, lm_trie_quant_lsize(trie->quant),
461  counts[0]);
462 }
463 
464 void
465 lm_trie_build(lm_trie_t * trie, ngram_raw_t ** raw_ngrams, uint32 * counts, uint32 *out_counts,
466  int order)
467 {
468  int i;
469 
470  lm_trie_fix_counts(raw_ngrams, counts, out_counts, order);
471  lm_trie_alloc_ngram(trie, out_counts, order);
472 
473  if (order > 1)
474  E_INFO("Training quantizer\n");
475  for (i = 2; i < order; i++) {
476  lm_trie_quant_train(trie->quant, i, counts[i - 1],
477  raw_ngrams[i - 2]);
478  }
479  lm_trie_quant_train_prob(trie->quant, order, counts[order - 1],
480  raw_ngrams[order - 2]);
481 
482  E_INFO("Building LM trie\n");
483  recursive_insert(trie, raw_ngrams, counts, order);
484  /* Set ending offsets so the last entry will be sized properly */
485  /* Last entry for unigrams was already set. */
486  if (trie->middle_begin != trie->middle_end) {
487  middle_t *middle_ptr;
488  for (middle_ptr = trie->middle_begin;
489  middle_ptr != trie->middle_end - 1; ++middle_ptr) {
490  middle_t *next_middle_ptr = middle_ptr + 1;
491  middle_finish_loading(middle_ptr,
492  next_middle_ptr->base.insert_index);
493  }
494  middle_ptr = trie->middle_end - 1;
495  middle_finish_loading(middle_ptr,
496  trie->longest->base.insert_index);
497  }
498 }
499 
500 unigram_t *
501 unigram_find(unigram_t * u, uint32 word, node_range_t * next)
502 {
503  unigram_t *ptr = &u[word];
504  next->begin = ptr->next;
505  next->end = (ptr + 1)->next;
506  return ptr;
507 }
508 
/**
 * Interpolation step for uniform_find(): scale off/(range + 1) to width.
 *
 * @param off   distance of the key from the lower bound value
 * @param range distance between the upper and lower bound values
 * @param width number of candidate positions between the bounds
 * @return floor(off * width / (range + 1))
 *
 * The product is formed in 64 bits: off and width can each exceed 16 bits,
 * so a 32-bit product could wrap and skew the pivot towards the lower
 * bound.
 */
static size_t
calc_pivot(uint32 off, uint32 range, uint32 width)
{
    uint64 scaled = (uint64) off * (uint64) width;

    return (size_t) (scaled / ((uint64) range + 1));
}
514 
515 static uint8
516 uniform_find(void *base, uint8 total_bits, uint8 key_bits, uint32 key_mask,
517  uint32 before_it, uint32 before_v,
518  uint32 after_it, uint32 after_v, uint32 key, uint32 * out)
519 {
520  bitarr_address_t address;
521  address.base = base;
522  while (after_it - before_it > 1) {
523  uint32 mid;
524  uint32 pivot =
525  before_it + (1 +
526  calc_pivot(key - before_v, after_v - before_v,
527  after_it - before_it - 1));
528  /* access by pivot */
529  address.offset = pivot * (uint32) total_bits;
530  mid = bitarr_read_int25(address, key_bits, key_mask);
531  if (mid < key) {
532  before_it = pivot;
533  before_v = mid;
534  }
535  else if (mid > key) {
536  after_it = pivot;
537  after_v = mid;
538  }
539  else {
540  *out = pivot;
541  return TRUE;
542  }
543  }
544  return FALSE;
545 }
546 
547 static bitarr_address_t
548 middle_find(middle_t * middle, uint32 word, node_range_t * range)
549 {
550  uint32 at_pointer;
551  bitarr_address_t address;
552 
553  /* finding BitPacked with uniform find */
554  if (!uniform_find
555  ((void *) middle->base.base, middle->base.total_bits,
556  middle->base.word_bits, middle->base.word_mask, range->begin - 1,
557  0, range->end, middle->base.max_vocab, word, &at_pointer)) {
558  address.base = NULL;
559  address.offset = 0;
560  return address;
561  }
562 
563  address.base = middle->base.base;
564  at_pointer *= middle->base.total_bits;
565  at_pointer += middle->base.word_bits;
566  address.offset = at_pointer + middle->quant_bits;
567  range->begin =
568  bitarr_read_int25(address, middle->next_mask.bits,
569  middle->next_mask.mask);
570  address.offset += middle->base.total_bits;
571  range->end =
572  bitarr_read_int25(address, middle->next_mask.bits,
573  middle->next_mask.mask);
574  address.offset = at_pointer;
575 
576  return address;
577 }
578 
579 static bitarr_address_t
580 longest_find(longest_t * longest, uint32 word, node_range_t * range)
581 {
582  uint32 at_pointer;
583  bitarr_address_t address;
584 
585  /* finding BitPacked with uniform find */
586  if (!uniform_find
587  ((void *) longest->base.base, longest->base.total_bits,
588  longest->base.word_bits, longest->base.word_mask,
589  range->begin - 1, 0, range->end, longest->base.max_vocab, word,
590  &at_pointer)) {
591  address.base = NULL;
592  address.offset = 0;
593  return address;
594  }
595  address.base = longest->base.base;
596  address.offset =
597  at_pointer * longest->base.total_bits + longest->base.word_bits;
598  return address;
599 }
600 
601 static float
602 get_available_prob(lm_trie_t * trie, int32 wid, int32 * hist,
603  int max_order, int32 n_hist, int32 * n_used)
604 {
605  float prob;
606  node_range_t node;
607  bitarr_address_t address;
608  int order_minus_2;
609  uint8 independent_left;
610  int32 *hist_iter, *hist_end;
611 
612  *n_used = 1;
613  prob = unigram_find(trie->unigrams, wid, &node)->prob;
614  if (n_hist == 0) {
615  return prob;
616  }
617 
618  /* find ngrams of higher order if any */
619  order_minus_2 = 0;
620  independent_left = (node.begin == node.end);
621  hist_iter = hist;
622  hist_end = hist + n_hist;
623  for (;; order_minus_2++, hist_iter++) {
624  if (hist_iter == hist_end)
625  return prob;
626  if (independent_left)
627  return prob;
628  if (order_minus_2 == max_order - 2)
629  break;
630 
631  address =
632  middle_find(&trie->middle_begin[order_minus_2], *hist_iter,
633  &node);
634  independent_left = (address.base == NULL)
635  || (node.begin == node.end);
636 
637  /* didn't find entry */
638  if (address.base == NULL)
639  return prob;
640  prob = lm_trie_quant_mpread(trie->quant, address, order_minus_2);
641  *n_used = order_minus_2 + 2;
642  }
643 
644  address = longest_find(trie->longest, *hist_iter, &node);
645  if (address.base != NULL) {
646  prob = lm_trie_quant_lpread(trie->quant, address);
647  *n_used = max_order;
648  }
649  return prob;
650 }
651 
652 static float
653 get_available_backoff(lm_trie_t * trie, int32 start, int32 * hist,
654  int32 n_hist)
655 {
656  float backoff = 0.0f;
657  int order_minus_2;
658  int32 *hist_iter;
659  node_range_t node;
660  unigram_t *first_hist = unigram_find(trie->unigrams, hist[0], &node);
661  if (start <= 1) {
662  backoff += first_hist->bo;
663  start = 2;
664  }
665  order_minus_2 = start - 2;
666  for (hist_iter = hist + start - 1; hist_iter < hist + n_hist;
667  hist_iter++, order_minus_2++) {
668  bitarr_address_t address =
669  middle_find(&trie->middle_begin[order_minus_2], *hist_iter,
670  &node);
671  if (address.base == NULL)
672  break;
673  backoff +=
674  lm_trie_quant_mboread(trie->quant, address, order_minus_2);
675  }
676  return backoff;
677 }
678 
679 static float
680 lm_trie_nobo_score(lm_trie_t * trie, int32 wid, int32 * hist,
681  int max_order, int32 n_hist, int32 * n_used)
682 {
683  float prob =
684  get_available_prob(trie, wid, hist, max_order, n_hist, n_used);
685  if (n_hist < *n_used)
686  return prob;
687  return prob + get_available_backoff(trie, *n_used, hist, n_hist);
688 }
689 
690 static float
691 lm_trie_hist_score(lm_trie_t * trie, int32 wid, int32 * hist, int32 n_hist,
692  int32 * n_used)
693 {
694  float prob;
695  int i, j;
696  node_range_t node;
697  bitarr_address_t address;
698 
699  *n_used = 1;
700  prob = unigram_find(trie->unigrams, wid, &node)->prob;
701  if (n_hist == 0)
702  return prob;
703  for (i = 0; i < n_hist - 1; i++) {
704  address = middle_find(&trie->middle_begin[i], hist[i], &node);
705  if (address.base == NULL) {
706  for (j = i; j < n_hist; j++) {
707  prob += trie->backoff_cache[j];
708  }
709  return prob;
710  }
711  else {
712  (*n_used)++;
713  prob = lm_trie_quant_mpread(trie->quant, address, i);
714  }
715  }
716  address = longest_find(trie->longest, hist[n_hist - 1], &node);
717  if (address.base == NULL) {
718  return prob + trie->backoff_cache[n_hist - 1];
719  }
720  else {
721  (*n_used)++;
722  return lm_trie_quant_lpread(trie->quant, address);
723  }
724 }
725 
726 static uint8
727 history_matches(int32 * hist, int32 * prev_hist, int32 n_hist)
728 {
729  int i;
730  for (i = 0; i < n_hist; i++) {
731  if (hist[i] != prev_hist[i]) {
732  return FALSE;
733  }
734  }
735  return TRUE;
736 }
737 
738 static void
739 update_backoff(lm_trie_t * trie, int32 * hist, int32 n_hist)
740 {
741  int i;
742  node_range_t node;
743  bitarr_address_t address;
744 
745  memset(trie->backoff_cache, 0, sizeof(trie->backoff_cache));
746  trie->backoff_cache[0] = unigram_find(trie->unigrams, hist[0], &node)->bo;
747  for (i = 1; i < n_hist; i++) {
748  address = middle_find(&trie->middle_begin[i - 1], hist[i], &node);
749  if (address.base == NULL) {
750  break;
751  }
752  trie->backoff_cache[i] =
753  lm_trie_quant_mboread(trie->quant, address, i - 1);
754  }
755  memcpy(trie->hist_cache, hist, n_hist * sizeof(*hist));
756 }
757 
758 float
759 lm_trie_score(lm_trie_t * trie, int order, int32 wid, int32 * hist,
760  int32 n_hist, int32 * n_used)
761 {
762  if (n_hist < order - 1) {
763  return lm_trie_nobo_score(trie, wid, hist, order, n_hist, n_used);
764  }
765  else {
766  assert(n_hist == order - 1);
767  if (!history_matches(hist, (int32 *) trie->hist_cache, n_hist)) {
768  update_backoff(trie, hist, n_hist);
769  }
770  return lm_trie_hist_score(trie, wid, hist, n_hist, n_used);
771  }
772 }
773 
774 void
775 lm_trie_fill_raw_ngram(lm_trie_t * trie,
776  ngram_raw_t * raw_ngrams, uint32 * raw_ngram_idx,
777  uint32 * counts, node_range_t range, uint32 * hist,
778  int n_hist, int order, int max_order)
779 {
780  if (n_hist > 0 && range.begin == range.end) {
781  return;
782  }
783  if (n_hist == 0) {
784  uint32 i;
785  for (i = 0; i < counts[0]; i++) {
786  node_range_t node;
787  unigram_find(trie->unigrams, i, &node);
788  hist[0] = i;
789  lm_trie_fill_raw_ngram(trie, raw_ngrams, raw_ngram_idx, counts,
790  node, hist, 1, order, max_order);
791  }
792  }
793  else if (n_hist < order - 1) {
794  uint32 ptr;
795  node_range_t node;
796  bitarr_address_t address;
797  uint32 new_word;
798  middle_t *middle = &trie->middle_begin[n_hist - 1];
799  for (ptr = range.begin; ptr < range.end; ptr++) {
800  address.base = middle->base.base;
801  address.offset = ptr * middle->base.total_bits;
802  new_word =
803  bitarr_read_int25(address, middle->base.word_bits,
804  middle->base.word_mask);
805  hist[n_hist] = new_word;
806  address.offset += middle->base.word_bits + middle->quant_bits;
807  node.begin =
808  bitarr_read_int25(address, middle->next_mask.bits,
809  middle->next_mask.mask);
810  address.offset =
811  (ptr + 1) * middle->base.total_bits +
812  middle->base.word_bits + middle->quant_bits;
813  node.end =
814  bitarr_read_int25(address, middle->next_mask.bits,
815  middle->next_mask.mask);
816  lm_trie_fill_raw_ngram(trie, raw_ngrams, raw_ngram_idx, counts,
817  node, hist, n_hist + 1, order, max_order);
818  }
819  }
820  else {
821  bitarr_address_t address;
822  uint32 ptr;
823  float prob, backoff;
824  int i;
825  assert(n_hist == order - 1);
826  for (ptr = range.begin; ptr < range.end; ptr++) {
827  ngram_raw_t *raw_ngram = &raw_ngrams[*raw_ngram_idx];
828  if (order == max_order) {
829  longest_t *longest = trie->longest;
830  address.base = longest->base.base;
831  address.offset = ptr * longest->base.total_bits;
832  hist[n_hist] =
833  bitarr_read_int25(address, longest->base.word_bits,
834  longest->base.word_mask);
835  address.offset += longest->base.word_bits;
836  prob = lm_trie_quant_lpread(trie->quant, address);
837  }
838  else {
839  middle_t *middle = &trie->middle_begin[n_hist - 1];
840  address.base = middle->base.base;
841  address.offset = ptr * middle->base.total_bits;
842  hist[n_hist] =
843  bitarr_read_int25(address, middle->base.word_bits,
844  middle->base.word_mask);
845  address.offset += middle->base.word_bits;
846  prob =
847  lm_trie_quant_mpread(trie->quant, address, n_hist - 1);
848  backoff =
849  lm_trie_quant_mboread(trie->quant, address,
850  n_hist - 1);
851  raw_ngram->backoff = backoff;
852  }
853  raw_ngram->prob = prob;
854  raw_ngram->words =
855  (uint32 *) ckd_calloc(order, sizeof(*raw_ngram->words));
856  for (i = 0; i <= n_hist; i++) {
857  raw_ngram->words[i] = hist[n_hist - i];
858  }
859  (*raw_ngram_idx)++;
860  }
861  }
862 }
priority_queue_s
Definition: priority_queue.c:46
bitarr_write_int25
SPHINXBASE_EXPORT void bitarr_write_int25(bitarr_address_t address, uint8 length, uint32 value)
Write specified value into bit array.
Definition: bitarr.c:112
prim_type.h
Basic type definitions used in Sphinx.
E_INFO
#define E_INFO(...)
Print logging information to standard error stream.
Definition: err.h:114
bitarr_address_s
Structure that stores address of certain value in bit array.
Definition: bitarr.h:73
lm_trie_s
Definition: lm_trie.h:79
longest_s
Definition: lm_trie.h:74
ckd_free
SPHINXBASE_EXPORT void ckd_free(void *ptr)
Test and free a 1-D array.
Definition: ckd_alloc.c:244
node_range_s
Definition: lm_trie.h:53
U
Definition: dtoa.c:178
bitarr_required_bits
SPHINXBASE_EXPORT uint8 bitarr_required_bits(uint32 max_value)
Computes the number of bits required to store integers up to the value provided.
Definition: bitarr.c:131
bitarr_mask_from_max
SPHINXBASE_EXPORT void bitarr_mask_from_max(bitarr_mask_t *bit_mask, uint32 max_value)
Fills mask for certain int range according to provided max value.
Definition: bitarr.c:125
err.h
Implementation of logging routines.
base_s
Definition: lm_trie.h:58
ngram_raw_s
Definition: ngrams_raw.h:47
E_ERROR
#define E_ERROR(...)
Print error message to error log.
Definition: err.h:104
bitarr_read_int25
SPHINXBASE_EXPORT uint32 bitarr_read_int25(bitarr_address_t address, uint8 length, uint32 mask)
Read uint32 value from bit array.
Definition: bitarr.c:100
middle_s
Definition: lm_trie.h:67
ckd_alloc.h
Sphinx's memory allocation/deallocation routines.
ckd_calloc
#define ckd_calloc(n, sz)
Macros to simplify the use of above functions.
Definition: ckd_alloc.h:248
unigram_s
Definition: lm_trie.h:47