
src/libsphinxbase/lm/ngram_model_arpa.c

00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
00002 /* ====================================================================
00003  * Copyright (c) 1999-2007 Carnegie Mellon University.  All rights
00004  * reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions
00008  * are met:
00009  *
00010  * 1. Redistributions of source code must retain the above copyright
00011  *    notice, this list of conditions and the following disclaimer. 
00012  *
00013  * 2. Redistributions in binary form must reproduce the above copyright
00014  *    notice, this list of conditions and the following disclaimer in
00015  *    the documentation and/or other materials provided with the
00016  *    distribution.
00017  *
00018  * This work was supported in part by funding from the Defense Advanced 
00019  * Research Projects Agency and the National Science Foundation of the 
00020  * United States of America, and the CMU Sphinx Speech Consortium.
00021  *
00022  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
00023  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
00024  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00025  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
00026  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00027  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
00028  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
00029  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
00030  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
00031  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
00032  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00033  *
00034  * ====================================================================
00035  *
00036  */
00037 /**
00038  * \file ngram_model_arpa.c ARPA format language models
00039  *
00040  * Author: David Huggins-Daines <dhuggins@cs.cmu.edu>
00041  */
00042 
00043 #include "ckd_alloc.h"
00044 #include "ngram_model_arpa.h"
00045 #include "err.h"
00046 #include "pio.h"
00047 #include "listelem_alloc.h"
00048 #include "strfuncs.h"
00049 
00050 #include <string.h>
00051 #include <limits.h>
00052 
00053 static ngram_funcs_t ngram_model_arpa_funcs;
00054 
00055 #define TSEG_BASE(m,b)          ((m)->lm3g.tseg_base[(b)>>LOG_BG_SEG_SZ])
00056 #define FIRST_BG(m,u)           ((m)->lm3g.unigrams[u].bigrams)
00057 #define FIRST_TG(m,b)           (TSEG_BASE((m),(b))+((m)->lm3g.bigrams[b].trigrams))
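
The three macros above implement the compact trigram addressing used throughout this file: each bigram_t stores only a 16-bit offset in its trigrams field, and tseg_base[] holds one absolute base index per segment of bigrams, so the absolute index of the first trigram of bigram b is tseg_base[b >> LOG_BG_SEG_SZ] + bigrams[b].trigrams. A minimal worked sketch of that arithmetic follows; it assumes the conventional lm3g segment size of 512 bigrams (LOG_BG_SEG_SZ = 9, defined elsewhere in the lm3g headers), and all index values are purely illustrative.

#include <stdio.h>

int main(void)
{
    const int log_bg_seg_sz = 9;                 /* assumed: 2^9 = 512 bigrams per segment */
    const int tseg_base[] = { 0, 40000, 81000 }; /* hypothetical per-segment trigram bases */
    const int bigram_index = 1000;               /* hypothetical absolute bigram index */
    const int tg_offset = 123;                   /* hypothetical 16-bit bigram_t.trigrams value */

    int seg = bigram_index >> log_bg_seg_sz;     /* 1000 >> 9 = 1 */
    int first_tg = tseg_base[seg] + tg_offset;   /* 40000 + 123 = 40123 */

    printf("bigram %d lies in segment %d; its first trigram has index %d\n",
           bigram_index, seg, first_tg);
    return 0;
}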
00058 
00059 /*
00060  * Initialize sorted list with the 0-th entry = INT_MIN, which may be needed
00061  * to replace spurious values in the ARPA LM file.
00062  */
00063 static void
00064 init_sorted_list(sorted_list_t * l)
00065 {
00066     /* FIXME FIXME FIXME: Fixed size array!??! */
00067     l->list = ckd_calloc(MAX_SORTED_ENTRIES,
00068                          sizeof(sorted_entry_t));
00069     l->list[0].val.l = INT_MIN;
00070     l->list[0].lower = 0;
00071     l->list[0].higher = 0;
00072     l->free = 1;
00073 }
00074 
00075 static void
00076 free_sorted_list(sorted_list_t * l)
00077 {
00078     ckd_free(l->list);  /* allocated with ckd_calloc() in init_sorted_list() */
00079 }
00080 
00081 static lmprob_t *
00082 vals_in_sorted_list(sorted_list_t * l)
00083 {
00084     lmprob_t *vals;
00085     int32 i;
00086 
00087     vals = ckd_calloc(l->free, sizeof(lmprob_t));
00088     for (i = 0; i < l->free; i++)
00089         vals[i] = l->list[i].val;
00090     return (vals);
00091 }
00092 
00093 static int32
00094 sorted_id(sorted_list_t * l, int32 *val)
00095 {
00096     int32 i = 0;
00097 
00098     for (;;) {
00099         if (*val == l->list[i].val.l)
00100             return (i);
00101         if (*val < l->list[i].val.l) {
00102             if (l->list[i].lower == 0) {
00103                 if (l->free >= MAX_SORTED_ENTRIES) {
00104                     /* Make the best of a bad situation. */
00105                     E_WARN("sorted list overflow (%d => %d)\n",
00106                            *val, l->list[i].val.l);
00107                     return i;
00108                 }
00109 
00110                 l->list[i].lower = l->free;
00111                 (l->free)++;
00112                 i = l->list[i].lower;
00113                 l->list[i].val.l = *val;
00114                 return (i);
00115             }
00116             else
00117                 i = l->list[i].lower;
00118         }
00119         else {
00120             if (l->list[i].higher == 0) {
00121                 if (l->free >= MAX_SORTED_ENTRIES) {
00122                     /* Make the best of a bad situation. */
00123                     E_WARN("sorted list overflow (%d => %d)\n",
00124                            *val, l->list[i].val.l);
00125                     return i;
00126                 }
00127 
00128                 l->list[i].higher = l->free;
00129                 (l->free)++;
00130                 i = l->list[i].higher;
00131                 l->list[i].val.l = *val;
00132                 return (i);
00133             }
00134             else
00135                 i = l->list[i].higher;
00136         }
00137     }
00138 }
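
sorted_id() treats the array as a binary search tree (the lower/higher fields are child indices) and returns the existing slot when a value has already been inserted, so each distinct quantized log-probability is stored exactly once and n-gram entries can refer to it by a small index. Below is a standalone sketch of the same idea; it uses simplified hypothetical types rather than the real sorted_entry_t/sorted_list_t from ngram_model_arpa.h and omits the overflow check.

#include <stdio.h>

/* Simplified stand-ins for sorted_entry_t / sorted_list_t (hypothetical). */
typedef struct { int val; int lower, higher; } entry_ex;
typedef struct { entry_ex list[16]; int n_used; } slist_ex;

static void slist_init(slist_ex *l)
{
    l->list[0].val = -999999;               /* sentinel smaller than any real value */
    l->list[0].lower = l->list[0].higher = 0;
    l->n_used = 1;
}

/* Return the index of val, inserting it first if it is not already present. */
static int slist_id(slist_ex *l, int val)
{
    int i = 0;
    for (;;) {
        if (val == l->list[i].val)
            return i;                       /* duplicate value: reuse existing slot */
        int *child = (val < l->list[i].val) ? &l->list[i].lower : &l->list[i].higher;
        if (*child == 0) {                  /* no child yet: append a new slot */
            *child = l->n_used;
            i = l->n_used++;
            l->list[i].val = val;
            l->list[i].lower = l->list[i].higher = 0;
            return i;
        }
        i = *child;
    }
}

int main(void)
{
    slist_ex l;
    slist_init(&l);
    int probs[] = { -5000, -3000, -5000, -1000 };
    for (int i = 0; i < 4; i++)
        printf("value %d -> id %d\n", probs[i], slist_id(&l, probs[i]));
    /* prints ids 1, 2, 1, 3: the repeated -5000 maps back to the same slot */
    return 0;
}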
00139 
00140 /*
00141  * Read and return #unigrams, #bigrams, #trigrams as stated in input file.
00142  */
00143 static int
00144 ReadNgramCounts(FILE * fp, int32 * n_ug, int32 * n_bg, int32 * n_tg)
00145 {
00146     char string[256];
00147     int32 ngram, ngram_cnt;
00148 
00149     /* skip file until past the '\data\' marker */
00150     do
00151         fgets(string, sizeof(string), fp);
00152     while ((strcmp(string, "\\data\\\n") != 0) && (!feof(fp)));
00153 
00154     if (strcmp(string, "\\data\\\n") != 0) {
00155         E_ERROR("No \\data\\ mark in LM file\n");
00156         return -1;
00157     }
00158 
00159     *n_ug = *n_bg = *n_tg = 0;
00160     while (fgets(string, sizeof(string), fp) != NULL) {
00161         if (sscanf(string, "ngram %d=%d", &ngram, &ngram_cnt) != 2)
00162             break;
00163         switch (ngram) {
00164         case 1:
00165             *n_ug = ngram_cnt;
00166             break;
00167         case 2:
00168             *n_bg = ngram_cnt;
00169             break;
00170         case 3:
00171             *n_tg = ngram_cnt;
00172             break;
00173         default:
00174             E_ERROR("Unknown ngram (%d)\n", ngram);
00175             return -1;
00176         }
00177     }
00178 
00179     /* Position file to just after the unigrams header '\1-grams:' */
00180     while ((strcmp(string, "\\1-grams:\n") != 0) && (!feof(fp)))
00181         fgets(string, sizeof(string), fp);
00182 
00183     /* Check counts;  NOTE: #trigrams *CAN* be 0 */
00184     if ((*n_ug <= 0) || (*n_bg <= 0) || (*n_tg < 0)) {
00185         E_ERROR("Bad or missing ngram count\n");
00186         return -1;
00187     }
00188     return 0;
00189 }
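
For reference, the header section that ReadNgramCounts() parses looks like the following (the counts are purely illustrative); the function reads up to and including the '\1-grams:' line, leaving the file positioned at the first unigram entry:

\data\
ngram 1=5000
ngram 2=120000
ngram 3=180000

\1-grams: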
00190 
00191 /*
00192  * Read in the unigrams from given file into the LM structure model.  On
00193  * entry to this procedure, the file pointer is positioned just after the
00194  * header line '\1-grams:'.
00195  */
00196 static int
00197 ReadUnigrams(FILE * fp, ngram_model_arpa_t * model)
00198 {
00199     ngram_model_t *base = &model->base;
00200     char string[256];
00201     int32 wcnt;
00202     float p1;
00203 
00204     E_INFO("Reading unigrams\n");
00205 
00206     wcnt = 0;
00207     while ((fgets(string, sizeof(string), fp) != NULL) &&
00208            (strcmp(string, "\\2-grams:\n") != 0)) {
00209         char *wptr[3], *name;
00210         float32 bo_wt = 0.0f;
00211         int n;
00212 
00213         if ((n = str2words(string, wptr, 3)) < 2) {
00214             if (string[0] != '\n')
00215                 E_WARN("Format error; unigram ignored: %s\n", string);
00216             continue;
00217         }
00218         else {
00219             p1 = (float)atof_c(wptr[0]);
00220             name = wptr[1];
00221             if (n == 3)
00222                 bo_wt = (float)atof_c(wptr[2]);
00223         }
00224 
00225         if (wcnt >= base->n_counts[0]) {
00226             E_ERROR("Too many unigrams\n");
00227             return -1;
00228         }
00229 
00230         /* Associate name with word id */
00231         base->word_str[wcnt] = ckd_salloc(name);
00232         if ((hash_table_enter(base->wid, base->word_str[wcnt], (void *)(long)wcnt))
00233             != (void *)(long)wcnt) {
00234                 E_WARN("Duplicate word in dictionary: %s\n", base->word_str[wcnt]);
00235         }
00236         model->lm3g.unigrams[wcnt].prob1.l = logmath_log10_to_log(base->lmath, p1);
00237         model->lm3g.unigrams[wcnt].bo_wt1.l = logmath_log10_to_log(base->lmath, bo_wt);
00238         wcnt++;
00239     }
00240 
00241     if (base->n_counts[0] != wcnt) {
00242         E_WARN("lm_t.ucount(%d) != #unigrams read(%d)\n",
00243                base->n_counts[0], wcnt);
00244         base->n_counts[0] = wcnt;
00245     }
00246     return 0;
00247 }
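
Each unigram line consists of a base-10 log probability, the word, and an optional base-10 log backoff weight; ReadUnigrams() converts both values into the decoder's log base with logmath_log10_to_log(). Illustrative entries (all values are made up):

\1-grams:
-2.8734 </s>
-2.1282 <s>      -1.3651
-3.5190 hello    -0.4771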
00248 
00249 /*
00250  * Read bigrams from given file into given model structure.
00251  */
00252 static int
00253 ReadBigrams(FILE * fp, ngram_model_arpa_t * model)
00254 {
00255     ngram_model_t *base = &model->base;
00256     char string[1024];
00257     int32 w1, w2, prev_w1, bgcount;
00258     bigram_t *bgptr;
00259 
00260     E_INFO("Reading bigrams\n");
00261 
00262     bgcount = 0;
00263     bgptr = model->lm3g.bigrams;
00264     prev_w1 = -1;
00265 
00266     while (fgets(string, sizeof(string), fp) != NULL) {
00267         float32 p, bo_wt = 0.0f;
00268         int32 p2, bo_wt2;
00269         char *wptr[4], *word1, *word2;
00270         int n;
00271 
00272         wptr[3] = NULL;
00273         if ((n = str2words(string, wptr, 4)) < 3) {
00274             if (string[0] != '\n')
00275                 break;
00276             continue;
00277         }
00278         else {
00279             p = (float32)atof_c(wptr[0]);
00280             word1 = wptr[1];
00281             word2 = wptr[2];
00282             if (wptr[3])
00283                 bo_wt = (float32)atof_c(wptr[3]);
00284         }
00285 
00286         if ((w1 = ngram_wid(base, word1)) == NGRAM_INVALID_WID) {
00287             E_ERROR("Unknown word: %s, skipping bigram (%s %s)\n",
00288                     word1, word1, word2);
00289             continue;
00290         }
00291         if ((w2 = ngram_wid(base, word2)) == NGRAM_INVALID_WID) {
00292             E_ERROR("Unknown word: %s, skipping bigram (%s %s)\n",
00293                     word2, word1, word2);
00294             continue;
00295         }
00296 
00297         /* FIXME: Should use logmath_t quantization here. */
00298         /* HACK!! to quantize probs to 4 decimal digits */
00299         p = (float32)((int32)(p * 10000)) / 10000;
00300         bo_wt = (float32)((int32)(bo_wt * 10000)) / 10000;
00301 
00302         p2 = logmath_log10_to_log(base->lmath, p);
00303         bo_wt2 = logmath_log10_to_log(base->lmath, bo_wt);
00304 
00305         if (bgcount >= base->n_counts[1]) {
00306             E_ERROR("Too many bigrams\n");
00307             return -1;
00308         }
00309 
00310         bgptr->wid = w2;
00311         bgptr->prob2 = sorted_id(&model->sorted_prob2, &p2);
00312         if (base->n_counts[2] > 0)
00313             bgptr->bo_wt2 = sorted_id(&model->sorted_bo_wt2, &bo_wt2);
00314 
00315         if (w1 != prev_w1) {
00316             if (w1 < prev_w1) {
00317                 E_ERROR("Bigrams not in unigram order\n");
00318                 return -1;
00319             }
00320 
00321             for (prev_w1++; prev_w1 <= w1; prev_w1++)
00322                 model->lm3g.unigrams[prev_w1].bigrams = bgcount;
00323             prev_w1 = w1;
00324         }
00325 
00326         bgcount++;
00327         bgptr++;
00328 
00329         if ((bgcount & 0x0000ffff) == 0) {
00330             E_INFOCONT(".");
00331         }
00332     }
00333     if ((strcmp(string, "\\end\\") != 0)
00334         && (strcmp(string, "\\3-grams:") != 0)) {
00335         E_ERROR("Bad bigram: %s\n", string);
00336         return -1;
00337     }
00338 
00339     for (prev_w1++; prev_w1 <= base->n_counts[0]; prev_w1++)
00340         model->lm3g.unigrams[prev_w1].bigrams = bgcount;
00341 
00342     return 0;
00343 }
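
Bigram lines follow the same pattern with two words; the trailing backoff weight is only meaningful when the bigram has trigram successors. Note that ReadBigrams() does not store these values directly: each probability and backoff weight goes through sorted_id(), so a bigram_t holds only small indices into the shared prob2/bo_wt2 tables. Illustrative entries (values are made up):

\2-grams:
-0.3010 <s> hello      -0.1761
-1.2218 hello world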
00344 
00345 /*
00346  * Read trigrams from given file into given model structure; very similar to ReadBigrams.
00347  */
00348 static int
00349 ReadTrigrams(FILE * fp, ngram_model_arpa_t * model)
00350 {
00351     ngram_model_t *base = &model->base;
00352     char string[1024];
00353     int32 i, w1, w2, w3, prev_w1, prev_w2, tgcount, prev_bg, bg, endbg;
00354     int32 seg, prev_seg, prev_seg_lastbg;
00355     trigram_t *tgptr;
00356     bigram_t *bgptr;
00357 
00358     E_INFO("Reading trigrams\n");
00359 
00360     tgcount = 0;
00361     tgptr = model->lm3g.trigrams;
00362     prev_w1 = -1;
00363     prev_w2 = -1;
00364     prev_bg = -1;
00365     prev_seg = -1;
00366 
00367     while (fgets(string, sizeof(string), fp) != NULL) {
00368         float32 p;
00369         int32 p3;
00370         char *wptr[4], *word1, *word2, *word3;
00371 
00372         if (str2words(string, wptr, 4) != 4) {
00373             if (string[0] != '\n')
00374                 break;
00375             continue;
00376         }
00377         else {
00378             p = (float32)atof_c(wptr[0]);
00379             word1 = wptr[1];
00380             word2 = wptr[2];
00381             word3 = wptr[3];
00382         }
00383 
00384         if ((w1 = ngram_wid(base, word1)) == NGRAM_INVALID_WID) {
00385             E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00386                     word1, word1, word2, word3);
00387             continue;
00388         }
00389         if ((w2 = ngram_wid(base, word2)) == NGRAM_INVALID_WID) {
00390             E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00391                     word2, word1, word2, word3);
00392             continue;
00393         }
00394         if ((w3 = ngram_wid(base, word3)) == NGRAM_INVALID_WID) {
00395             E_ERROR("Unknown word: %s, skipping trigram (%s %s %s)\n",
00396                     word3, word1, word2, word3);
00397             continue;
00398         }
00399 
00400         /* FIXME: Should use logmath_t quantization here. */
00401         /* HACK!! to quantize probs to 4 decimal digits */
00402         p = (float32)((int32)(p * 10000)) / 10000;
00403         p3 = logmath_log10_to_log(base->lmath, p);
00404 
00405         if (tgcount >= base->n_counts[2]) {
00406             E_ERROR("Too many trigrams\n");
00407             return -1;
00408         }
00409 
00410         tgptr->wid = w3;
00411         tgptr->prob3 = sorted_id(&model->sorted_prob3, &p3);
00412 
00413         if ((w1 != prev_w1) || (w2 != prev_w2)) {
00414             /* Trigram for a new bigram; update tg info for all previous bigrams */
00415             if ((w1 < prev_w1) || ((w1 == prev_w1) && (w2 < prev_w2))) {
00416                 E_ERROR("Trigrams not in bigram order\n");
00417                 return -1;
00418             }
00419 
00420             bg = (w1 !=
00421                   prev_w1) ? model->lm3g.unigrams[w1].bigrams : prev_bg + 1;
00422             endbg = model->lm3g.unigrams[w1 + 1].bigrams;
00423             bgptr = model->lm3g.bigrams + bg;
00424             for (; (bg < endbg) && (bgptr->wid != w2); bg++, bgptr++);
00425             if (bg >= endbg) {
00426                 E_ERROR("Missing bigram for trigram: %s", string);
00427                 return -1;
00428             }
00429 
00430             /* bg = bigram entry index for <w1,w2>.  Update tseg_base */
00431             seg = bg >> LOG_BG_SEG_SZ;
00432             for (i = prev_seg + 1; i <= seg; i++)
00433                 model->lm3g.tseg_base[i] = tgcount;
00434 
00435             /* Update trigram pointers for all bigrams up to bg */
00436             if (prev_seg < seg) {
00437                 int32 tgoff = 0;
00438 
00439                 if (prev_seg >= 0) {
00440                     tgoff = tgcount - model->lm3g.tseg_base[prev_seg];
00441                     if (tgoff > 65535) {
00442                         E_ERROR("Offset from tseg_base > 65535\n");
00443                         return -1;
00444                     }
00445                 }
00446 
00447                 prev_seg_lastbg = ((prev_seg + 1) << LOG_BG_SEG_SZ) - 1;
00448                 bgptr = model->lm3g.bigrams + prev_bg;
00449                 for (++prev_bg, ++bgptr; prev_bg <= prev_seg_lastbg;
00450                      prev_bg++, bgptr++)
00451                     bgptr->trigrams = tgoff;
00452 
00453                 for (; prev_bg <= bg; prev_bg++, bgptr++)
00454                     bgptr->trigrams = 0;
00455             }
00456             else {
00457                 int32 tgoff;
00458 
00459                 tgoff = tgcount - model->lm3g.tseg_base[prev_seg];
00460                 if (tgoff > 65535) {
00461                     E_ERROR("Offset from tseg_base > 65535\n");
00462                     return -1;
00463                 }
00464 
00465                 bgptr = model->lm3g.bigrams + prev_bg;
00466                 for (++prev_bg, ++bgptr; prev_bg <= bg; prev_bg++, bgptr++)
00467                     bgptr->trigrams = tgoff;
00468             }
00469 
00470             prev_w1 = w1;
00471             prev_w2 = w2;
00472             prev_bg = bg;
00473             prev_seg = seg;
00474         }
00475 
00476         tgcount++;
00477         tgptr++;
00478 
00479         if ((tgcount & 0x0000ffff) == 0) {
00480             E_INFOCONT(".");
00481         }
00482     }
00483     if (strcmp(string, "\\end\\") != 0) {
00484         E_ERROR("Bad trigram: %s\n", string);
00485         return -1;
00486     }
00487 
00488     for (prev_bg++; prev_bg <= base->n_counts[1]; prev_bg++) {
00489         if ((prev_bg & (BG_SEG_SZ - 1)) == 0)
00490             model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ] = tgcount;
00491         if ((tgcount - model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ]) > 65535) {
00492             E_ERROR("Offset from tseg_base > 65535\n");
00493             return -1;
00494         }
00495         model->lm3g.bigrams[prev_bg].trigrams =
00496             tgcount - model->lm3g.tseg_base[prev_bg >> LOG_BG_SEG_SZ];
00497     }
00498     return 0;
00499 }
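
Trigram lines carry a probability and three words but no backoff weight (trigrams are the highest order handled here), and the section is terminated by the '\end\' marker. Illustrative entries (values are made up):

\3-grams:
-0.4771 <s> hello world
-1.0000 hello world again

\end\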
00500 
00501 static unigram_t *
00502 new_unigram_table(int32 n_ug)
00503 {
00504     unigram_t *table;
00505     int32 i;
00506 
00507     table = ckd_calloc(n_ug, sizeof(unigram_t));
00508     for (i = 0; i < n_ug; i++) {
00509         table[i].prob1.l = INT_MIN;
00510         table[i].bo_wt1.l = INT_MIN;
00511     }
00512     return table;
00513 }
00514 
00515 ngram_model_t *
00516 ngram_model_arpa_read(cmd_ln_t *config,
00517                       const char *file_name,
00518                       logmath_t *lmath)
00519 {
00520     FILE *fp;
00521     int32 is_pipe;
00522     int32 n_unigram;
00523     int32 n_bigram;
00524     int32 n_trigram;
00525     int32 n;
00526     ngram_model_arpa_t *model;
00527     ngram_model_t *base;
00528 
00529     if ((fp = fopen_comp(file_name, "r", &is_pipe)) == NULL) {
00530         E_ERROR("File %s not found\n", file_name);
00531         return NULL;
00532     }
00533  
00534     /* Read #unigrams, #bigrams, #trigrams from file */
00535     if (ReadNgramCounts(fp, &n_unigram, &n_bigram, &n_trigram) == -1) {
00536         fclose_comp(fp, is_pipe);
00537         return NULL;
00538     }
00539     E_INFO("ngrams 1=%d, 2=%d, 3=%d\n", n_unigram, n_bigram, n_trigram);
00540 
00541     /* Allocate space for LM, including initial OOVs and placeholders; initialize it */
00542     model = ckd_calloc(1, sizeof(*model));
00543     base = &model->base;
00544     if (n_trigram > 0)
00545         n = 3;
00546     else if (n_bigram > 0)
00547         n = 2;
00548     else
00549         n = 1;
00550     /* Initialize base model. */
00551     ngram_model_init(base, &ngram_model_arpa_funcs, lmath, n, n_unigram);
00552     base->n_counts[0] = n_unigram;
00553     base->n_counts[1] = n_bigram;
00554     base->n_counts[2] = n_trigram;
00555     base->writable = TRUE;
00556 
00557     /*
00558      * Allocate one extra unigram and bigram entry: sentinels to terminate
00559      * followers (bigrams and trigrams, respectively) of previous entry.
00560      */
00561     model->lm3g.unigrams = new_unigram_table(n_unigram + 1);
00562     model->lm3g.bigrams =
00563         ckd_calloc(n_bigram + 1, sizeof(bigram_t));
00564     if (n_trigram > 0)
00565         model->lm3g.trigrams =
00566             ckd_calloc(n_trigram, sizeof(trigram_t));
00567 
00568     if (n_trigram > 0) {
00569         model->lm3g.tseg_base =
00570             ckd_calloc((n_bigram + 1) / BG_SEG_SZ + 1,
00571                        sizeof(int32));
00572     }
00573     if (ReadUnigrams(fp, model) == -1) {
00574         fclose_comp(fp, is_pipe);
00575         ngram_model_free(base);
00576         return NULL;
00577     }
00578     E_INFO("%8d = #unigrams created\n", base->n_counts[0]);
00579 
00580     init_sorted_list(&model->sorted_prob2);
00581     if (base->n_counts[2] > 0)
00582         init_sorted_list(&model->sorted_bo_wt2);
00583 
00584     if (ReadBigrams(fp, model) == -1) {
00585         fclose_comp(fp, is_pipe);
00586         ngram_model_free(base);
00587         return NULL;
00588     }
00589 
00590     base->n_counts[1] = FIRST_BG(model, base->n_counts[0]);
00591     model->lm3g.n_prob2 = model->sorted_prob2.free;
00592     model->lm3g.prob2 = vals_in_sorted_list(&model->sorted_prob2);
00593     free_sorted_list(&model->sorted_prob2);
00594     E_INFO("%8d = #bigrams created\n", base->n_counts[1]);
00595     E_INFO("%8d = #prob2 entries\n", model->lm3g.n_prob2);
00596 
00597     if (base->n_counts[2] > 0) {
00598         /* Create trigram bo-wts array */
00599         model->lm3g.n_bo_wt2 = model->sorted_bo_wt2.free;
00600         model->lm3g.bo_wt2 = vals_in_sorted_list(&model->sorted_bo_wt2);
00601         free_sorted_list(&model->sorted_bo_wt2);
00602         E_INFO("%8d = #bo_wt2 entries\n", model->lm3g.n_bo_wt2);
00603 
00604         init_sorted_list(&model->sorted_prob3);
00605 
00606         if (ReadTrigrams(fp, model) == -1) {
00607             fclose_comp(fp, is_pipe);
00608             ngram_model_free(base);
00609             return NULL;
00610         }
00611 
00612         base->n_counts[2] = FIRST_TG(model, base->n_counts[1]);
00613         model->lm3g.n_prob3 = model->sorted_prob3.free;
00614         model->lm3g.prob3 = vals_in_sorted_list(&model->sorted_prob3);
00615         E_INFO("%8d = #trigrams created\n", base->n_counts[2]);
00616         E_INFO("%8d = #prob3 entries\n", model->lm3g.n_prob3);
00617 
00618         free_sorted_list(&model->sorted_prob3);
00619 
00620         /* Initialize tginfo */
00621         model->lm3g.tginfo = ckd_calloc(n_unigram, sizeof(tginfo_t *));
00622         model->lm3g.le = listelem_alloc_init(sizeof(tginfo_t));
00623     }
00624 
00625     fclose_comp(fp, is_pipe);
00626     return base;
00627 }
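
ngram_model_arpa_read() is normally reached indirectly through the generic ngram_model_read() entry point declared in ngram_model.h. A minimal usage sketch, assuming the standard SphinxBase public API; the header prefix, file name, and scored words are placeholders, and error handling is abbreviated:

#include <stdio.h>
#include <sphinxbase/logmath.h>
#include <sphinxbase/ngram_model.h>

int main(void)
{
    logmath_t *lmath = logmath_init(1.0001, 0, 0);  /* log base used for scores */
    ngram_model_t *lm = ngram_model_read(NULL, "model.arpa", NGRAM_ARPA, lmath);
    if (lm == NULL) {
        logmath_free(lmath);
        return 1;
    }

    /* Score P(world | <s> hello): history words are passed most-recent-first,
     * terminated by NULL. */
    int32 score = ngram_score(lm, "world", "hello", "<s>", NULL);
    printf("log score: %d\n", (int)score);

    ngram_model_free(lm);
    logmath_free(lmath);
    return 0;
}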
00628 
00629 int
00630 ngram_model_arpa_write(ngram_model_t *model,
00631                        const char *file_name)
00632 {
00633     return -1; /* Writing ARPA-format models is not implemented. */
00634 }
00635 
00636 static int
00637 ngram_model_arpa_apply_weights(ngram_model_t *base, float32 lw,
00638                               float32 wip, float32 uw)
00639 {
00640     ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00641     lm3g_apply_weights(base, &model->lm3g, lw, wip, uw);
00642     return 0;
00643 }
00644 
00645 /* Locate a specific bigram within a bigram list */
00646 #define BINARY_SEARCH_THRESH    16
00647 static int32
00648 find_bg(bigram_t * bg, int32 n, int32 w)
00649 {
00650     int32 i, b, e;
00651 
00652     /* Binary search until segment size < threshold */
00653     b = 0;
00654     e = n;
00655     while (e - b > BINARY_SEARCH_THRESH) {
00656         i = (b + e) >> 1;
00657         if ((int32)bg[i].wid < w)
00658             b = i + 1;
00659         else if ((int32)bg[i].wid > w)
00660             e = i;
00661         else
00662             return i;
00663     }
00664 
00665     /* Linear search within narrowed segment */
00666     for (i = b; (i < e) && (bg[i].wid != w); i++);
00667     return ((i < e) ? i : -1);
00668 }
00669 
00670 static int32
00671 lm3g_bg_score(ngram_model_arpa_t *model, int32 lw1,
00672               int32 lw2, int32 *n_used)
00673 {
00674     int32 i, n, b, score;
00675     bigram_t *bg;
00676 
00677     if (lw1 < 0) {
00678         *n_used = 1;
00679         return model->lm3g.unigrams[lw2].prob1.l;
00680     }
00681 
00682     b = FIRST_BG(model, lw1);
00683     n = FIRST_BG(model, lw1 + 1) - b;
00684     bg = model->lm3g.bigrams + b;
00685 
00686     if ((i = find_bg(bg, n, lw2)) >= 0) {
00687         /* Access mode = bigram */
00688         *n_used = 2;
00689         score = model->lm3g.prob2[bg[i].prob2].l;
00690     }
00691     else {
00692         /* Access mode = unigram */
00693         *n_used = 1;
00694         score = model->lm3g.unigrams[lw1].bo_wt1.l + model->lm3g.unigrams[lw2].prob1.l;
00695     }
00696 
00697     return (score);
00698 }
00699 
00700 static void
00701 load_tginfo(ngram_model_arpa_t *model, int32 lw1, int32 lw2)
00702 {
00703     int32 i, n, b, t;
00704     bigram_t *bg;
00705     tginfo_t *tginfo;
00706 
00707     /* First allocate space for tg information for bg lw1,lw2 */
00708     tginfo = (tginfo_t *) listelem_malloc(model->lm3g.le);
00709     tginfo->w1 = lw1;
00710     tginfo->tg = NULL;
00711     tginfo->next = model->lm3g.tginfo[lw2];
00712     model->lm3g.tginfo[lw2] = tginfo;
00713 
00714     /* Locate bigram lw1,lw2 */
00715     b = model->lm3g.unigrams[lw1].bigrams;
00716     n = model->lm3g.unigrams[lw1 + 1].bigrams - b;
00717     bg = model->lm3g.bigrams + b;
00718 
00719     if ((n > 0) && ((i = find_bg(bg, n, lw2)) >= 0)) {
00720         tginfo->bowt = model->lm3g.bo_wt2[bg[i].bo_wt2].l;
00721 
00722         /* Find t = Absolute first trigram index for bigram lw1,lw2 */
00723         b += i;                 /* b = Absolute index of bigram lw1,lw2 on disk */
00724         t = FIRST_TG(model, b);
00725 
00726         tginfo->tg = model->lm3g.trigrams + t;
00727 
00728         /* Find #tg for bigram w1,w2 */
00729         tginfo->n_tg = FIRST_TG(model, b + 1) - t;
00730     }
00731     else {                      /* No bigram w1,w2 */
00732         tginfo->bowt = 0;
00733         tginfo->n_tg = 0;
00734     }
00735 }
00736 
00737 /* Similar to find_bg */
00738 static int32
00739 find_tg(trigram_t * tg, int32 n, int32 w)
00740 {
00741     int32 i, b, e;
00742 
00743     b = 0;
00744     e = n;
00745     while (e - b > BINARY_SEARCH_THRESH) {
00746         i = (b + e) >> 1;
00747         if ((int32)tg[i].wid < w)
00748             b = i + 1;
00749         else if ((int32)tg[i].wid > w)
00750             e = i;
00751         else
00752             return i;
00753     }
00754 
00755     for (i = b; (i < e) && (tg[i].wid != w); i++);
00756     return ((i < e) ? i : -1);
00757 }
00758 
00759 static int32
00760 lm3g_tg_score(ngram_model_arpa_t *model, int32 lw1,
00761               int32 lw2, int32 lw3, int32 *n_used)
00762 {
00763     ngram_model_t *base = &model->base;
00764     int32 i, n, score;
00765     trigram_t *tg;
00766     tginfo_t *tginfo, *prev_tginfo;
00767 
00768     if ((base->n < 3) || (lw1 < 0))
00769         return (lm3g_bg_score(model, lw2, lw3, n_used));
00770 
00771     prev_tginfo = NULL;
00772     for (tginfo = model->lm3g.tginfo[lw2]; tginfo; tginfo = tginfo->next) {
00773         if (tginfo->w1 == lw1)
00774             break;
00775         prev_tginfo = tginfo;
00776     }
00777 
00778     if (!tginfo) {
00779         load_tginfo(model, lw1, lw2);
00780         tginfo = model->lm3g.tginfo[lw2];
00781     }
00782     else if (prev_tginfo) {
00783         prev_tginfo->next = tginfo->next;
00784         tginfo->next = model->lm3g.tginfo[lw2];
00785         model->lm3g.tginfo[lw2] = tginfo;
00786     }
00787 
00788     tginfo->used = 1;
00789 
00790     /* Trigrams for w1,w2 now pointed to by tginfo */
00791     n = tginfo->n_tg;
00792     tg = tginfo->tg;
00793     if ((i = find_tg(tg, n, lw3)) >= 0) {
00794         /* Access mode = trigram */
00795         *n_used = 3;
00796         score = model->lm3g.prob3[tg[i].prob3].l;
00797     }
00798     else {
00799         score = tginfo->bowt + lm3g_bg_score(model, lw2, lw3, n_used);
00800     }
00801 
00802     return (score);
00803 }
00804 
00805 static int32
00806 ngram_model_arpa_score(ngram_model_t *base, int32 wid,
00807                        int32 *history, int32 n_hist,
00808                        int32 *n_used)
00809 {
00810     ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00811 
00812     switch (n_hist) {
00813     case 0:
00814         /* Access mode: unigram */
00815         *n_used = 1;
00816         return model->lm3g.unigrams[wid].prob1.l;
00817     case 1:
00818         return lm3g_bg_score(model, history[0], wid, n_used);
00819     case 2:
00820     default:
00821         /* Anything greater than 2 is the same as a trigram for now. */
00822         return lm3g_tg_score(model, history[1], history[0], wid, n_used);
00823     }
00824 }
00825 
00826 static int32
00827 ngram_model_arpa_raw_score(ngram_model_t *base, int32 wid,
00828                            int32 *history, int32 n_hist,
00829                            int32 *n_used)
00830 {
00831     ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00832     int32 score;
00833 
00834     switch (n_hist) {
00835     case 0:
00836         /* Access mode: unigram */
00837         *n_used = 1;
00838         /* Undo insertion penalty. */
00839         score = model->lm3g.unigrams[wid].prob1.l - base->log_wip;
00840         /* Undo language weight. */
00841         score = (int32)(score / base->lw);
00842         /* Undo unigram interpolation */
00843         if (strcmp(base->word_str[wid], "<s>") != 0) { /* FIXME: configurable start_sym */
00844             score = logmath_log(base->lmath,
00845                                 logmath_exp(base->lmath, score)
00846                                 - logmath_exp(base->lmath, 
00847                                               base->log_uniform + base->log_uniform_weight));
00848         }
00849         return score;
00850     case 1:
00851         score = lm3g_bg_score(model, history[0], wid, n_used);
00852         break;
00853     case 2:
00854     default:
00855         /* Anything greater than 2 is the same as a trigram for now. */
00856         score = lm3g_tg_score(model, history[1], history[0], wid, n_used);
00857         break;
00858     }
00859     /* FIXME (maybe): This doesn't undo unigram weighting in backoff cases. */
00860     return (int32)((score - base->log_wip) / base->lw);
00861 }
00862 
00863 static int32
00864 ngram_model_arpa_add_ug(ngram_model_t *base,
00865                         int32 wid, int32 lweight)
00866 {
00867     ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00868     return lm3g_add_ug(base, &model->lm3g, wid, lweight);
00869 }
00870 
00871 static void
00872 ngram_model_arpa_free(ngram_model_t *base)
00873 {
00874     ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00875     ckd_free(model->lm3g.unigrams);
00876     ckd_free(model->lm3g.bigrams);
00877     ckd_free(model->lm3g.trigrams);
00878     ckd_free(model->lm3g.prob2);
00879     ckd_free(model->lm3g.bo_wt2);
00880     ckd_free(model->lm3g.prob3);
00881     lm3g_tginfo_free(base, &model->lm3g);
00882     ckd_free(model->lm3g.tseg_base);
00883 }
00884 
00885 static void
00886 ngram_model_arpa_flush(ngram_model_t *base)
00887 {
00888     ngram_model_arpa_t *model = (ngram_model_arpa_t *)base;
00889     lm3g_tginfo_reset(base, &model->lm3g);
00890 }
00891 
00892 static ngram_funcs_t ngram_model_arpa_funcs = {
00893     ngram_model_arpa_free,          /* free */
00894     ngram_model_arpa_apply_weights, /* apply_weights */
00895     ngram_model_arpa_score,         /* score */
00896     ngram_model_arpa_raw_score,     /* raw_score */
00897     ngram_model_arpa_add_ug,        /* add_ug */
00898     ngram_model_arpa_flush          /* flush */
00899 };
