00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00042
00043 #include <string.h>
00044 #include <assert.h>
00045
00046
00047 #include <ckd_alloc.h>
00048 #include <listelem_alloc.h>
00049 #include <err.h>
00050
00051
00052 #include "pocketsphinx_internal.h"
00053 #include "ps_lattice_internal.h"
00054 #include "ngram_search.h"
00055 #include "ngram_search_fwdtree.h"
00056 #include "ngram_search_fwdflat.h"
00057
00058 static int ngram_search_start(ps_search_t *search);
00059 static int ngram_search_step(ps_search_t *search, int frame_idx);
00060 static int ngram_search_finish(ps_search_t *search);
00061 static int ngram_search_reinit(ps_search_t *search, dict_t *dict, dict2pid_t *d2p);
00062 static char const *ngram_search_hyp(ps_search_t *search, int32 *out_score);
00063 static int32 ngram_search_prob(ps_search_t *search);
00064 static ps_seg_t *ngram_search_seg_iter(ps_search_t *search, int32 *out_score);
00065
00066 static ps_searchfuncs_t ngram_funcs = {
00067 "ngram",
00068 ngram_search_start,
00069 ngram_search_step,
00070 ngram_search_finish,
00071 ngram_search_reinit,
00072 ngram_search_free,
00073 ngram_search_lattice,
00074 ngram_search_hyp,
00075 ngram_search_prob,
00076 ngram_search_seg_iter,
00077 };
00078
00079 static void
00080 ngram_search_update_widmap(ngram_search_t *ngs)
00081 {
00082 const char **words;
00083 int32 i, n_words;
00084
00085
00086 n_words = ps_search_n_words(ngs);
00087 words = ckd_calloc(n_words, sizeof(*words));
00088
00089 for (i = 0; i < n_words; ++i)
00090 words[i] = (const char *)dict_wordstr(ps_search_dict(ngs), i);
00091 ngram_model_set_map_words(ngs->lmset, words, n_words);
00092 ckd_free(words);
00093 }
00094
00095 static void
00096 ngram_search_calc_beams(ngram_search_t *ngs)
00097 {
00098 cmd_ln_t *config;
00099 acmod_t *acmod;
00100
00101 config = ps_search_config(ngs);
00102 acmod = ps_search_acmod(ngs);
00103
00104
00105 ngs->beam = logmath_log(acmod->lmath, cmd_ln_float64_r(config, "-beam"));
00106 ngs->wbeam = logmath_log(acmod->lmath, cmd_ln_float64_r(config, "-wbeam"));
00107 ngs->pbeam = logmath_log(acmod->lmath, cmd_ln_float64_r(config, "-pbeam"));
00108 ngs->lpbeam = logmath_log(acmod->lmath, cmd_ln_float64_r(config, "-lpbeam"));
00109 ngs->lponlybeam = logmath_log(acmod->lmath, cmd_ln_float64_r(config, "-lponlybeam"));
00110 ngs->fwdflatbeam = logmath_log(acmod->lmath, cmd_ln_float64_r(config, "-fwdflatbeam"));
00111 ngs->fwdflatwbeam = logmath_log(acmod->lmath, cmd_ln_float64_r(config, "-fwdflatwbeam"));
00112
00113
00114 ngs->maxwpf = cmd_ln_int32_r(config, "-maxwpf");
00115 ngs->maxhmmpf = cmd_ln_int32_r(config, "-maxhmmpf");
00116
00117
00118 ngs->wip = logmath_log(acmod->lmath, cmd_ln_float32_r(config, "-wip"));
00119 ngs->nwpen = logmath_log(acmod->lmath, cmd_ln_float32_r(config, "-nwpen"));
00120 ngs->pip = logmath_log(acmod->lmath, cmd_ln_float32_r(config, "-pip"));
00121 ngs->silpen = ngs->pip
00122 + logmath_log(acmod->lmath, cmd_ln_float32_r(config, "-silprob"));
00123 ngs->fillpen = ngs->pip
00124 + logmath_log(acmod->lmath, cmd_ln_float32_r(config, "-fillprob"));
00125
00126
00127 ngs->fwdflat_fwdtree_lw_ratio =
00128 cmd_ln_float32_r(config, "-fwdflatlw")
00129 / cmd_ln_float32_r(config, "-lw");
00130 ngs->bestpath_fwdtree_lw_ratio =
00131 cmd_ln_float32_r(config, "-bestpathlw")
00132 / cmd_ln_float32_r(config, "-lw");
00133
00134
00135 ngs->ascale = 1.0 / cmd_ln_float32_r(config, "-ascale");
00136 }
00137
00138 ps_search_t *
00139 ngram_search_init(cmd_ln_t *config,
00140 acmod_t *acmod,
00141 dict_t *dict,
00142 dict2pid_t *d2p)
00143 {
00144 ngram_search_t *ngs;
00145 const char *path;
00146
00147 ngs = ckd_calloc(1, sizeof(*ngs));
00148 ps_search_init(&ngs->base, &ngram_funcs, config, acmod, dict, d2p);
00149 ngs->hmmctx = hmm_context_init(bin_mdef_n_emit_state(acmod->mdef),
00150 acmod->tmat->tp, NULL, acmod->mdef->sseq);
00151 if (ngs->hmmctx == NULL) {
00152 ps_search_free(ps_search_base(ngs));
00153 return NULL;
00154 }
00155 ngs->chan_alloc = listelem_alloc_init(sizeof(chan_t));
00156 ngs->root_chan_alloc = listelem_alloc_init(sizeof(root_chan_t));
00157 ngs->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t));
00158
00159
00160 ngram_search_calc_beams(ngs);
00161
00162
00163 ngs->word_chan = ckd_calloc(dict_size(dict),
00164 sizeof(*ngs->word_chan));
00165 ngs->word_lat_idx = ckd_calloc(dict_size(dict),
00166 sizeof(*ngs->word_lat_idx));
00167 ngs->zeroPermTab = ckd_calloc(bin_mdef_n_ciphone(acmod->mdef),
00168 sizeof(*ngs->zeroPermTab));
00169 ngs->word_active = bitvec_alloc(dict_size(dict));
00170 ngs->last_ltrans = ckd_calloc(dict_size(dict),
00171 sizeof(*ngs->last_ltrans));
00172
00173
00174
00175 ngs->bp_table_size = cmd_ln_int32_r(config, "-latsize");
00176 ngs->bp_table = ckd_calloc(ngs->bp_table_size,
00177 sizeof(*ngs->bp_table));
00178
00179 ngs->bscore_stack_size = ngs->bp_table_size * 20;
00180 ngs->bscore_stack = ckd_calloc(ngs->bscore_stack_size,
00181 sizeof(*ngs->bscore_stack));
00182 ngs->n_frame_alloc = 256;
00183 ngs->bp_table_idx = ckd_calloc(ngs->n_frame_alloc + 1,
00184 sizeof(*ngs->bp_table_idx));
00185 ++ngs->bp_table_idx;
00186
00187
00188 ngs->active_word_list = ckd_calloc_2d(2, dict_size(dict),
00189 sizeof(**ngs->active_word_list));
00190
00191
00192 if ((path = cmd_ln_str_r(config, "-lmctl"))) {
00193 ngs->lmset = ngram_model_set_read(config, path, acmod->lmath);
00194 if (ngs->lmset == NULL) {
00195 E_ERROR("Failed to read language model control file: %s\n",
00196 path);
00197 goto error_out;
00198 }
00199
00200 if ((path = cmd_ln_str_r(config, "-lmname"))) {
00201 ngram_model_set_select(ngs->lmset, path);
00202 }
00203 }
00204 else if ((path = cmd_ln_str_r(config, "-lm"))) {
00205 static const char *name = "default";
00206 ngram_model_t *lm;
00207
00208 lm = ngram_model_read(config, path, NGRAM_AUTO, acmod->lmath);
00209 if (lm == NULL) {
00210 E_ERROR("Failed to read language model file: %s\n", path);
00211 goto error_out;
00212 }
00213 ngs->lmset = ngram_model_set_init(config,
00214 &lm, (char **)&name,
00215 NULL, 1);
00216 if (ngs->lmset == NULL) {
00217 E_ERROR("Failed to initialize language model set\n");
00218 goto error_out;
00219 }
00220 }
00221
00222
00223 ngram_search_update_widmap(ngs);
00224
00225
00226 if (cmd_ln_boolean_r(config, "-fwdtree")) {
00227 ngram_fwdtree_init(ngs);
00228 ngs->fwdtree = TRUE;
00229 }
00230 if (cmd_ln_boolean_r(config, "-fwdflat")) {
00231 ngram_fwdflat_init(ngs);
00232 ngs->fwdflat = TRUE;
00233 }
00234 if (cmd_ln_boolean_r(config, "-bestpath")) {
00235 ngs->bestpath = TRUE;
00236 }
00237 return (ps_search_t *)ngs;
00238
00239 error_out:
00240 ngram_search_free((ps_search_t *)ngs);
00241 return NULL;
00242 }
00243
00244 static int
00245 ngram_search_reinit(ps_search_t *search, dict_t *dict, dict2pid_t *d2p)
00246 {
00247 ngram_search_t *ngs = (ngram_search_t *)search;
00248 int old_n_words;
00249 int rv = 0;
00250
00251
00252 old_n_words = search->n_words;
00253 if (old_n_words != dict_size(dict)) {
00254 search->n_words = dict_size(dict);
00255
00256 ckd_free(ngs->word_lat_idx);
00257 ckd_free(ngs->word_active);
00258 ckd_free(ngs->last_ltrans);
00259 ckd_free_2d(ngs->active_word_list);
00260 ngs->word_lat_idx = ckd_calloc(search->n_words, sizeof(*ngs->word_lat_idx));
00261 ngs->word_active = bitvec_alloc(search->n_words);
00262 ngs->last_ltrans = ckd_calloc(search->n_words, sizeof(*ngs->last_ltrans));
00263 ngs->active_word_list
00264 = ckd_calloc_2d(2, search->n_words,
00265 sizeof(**ngs->active_word_list));
00266 }
00267
00268
00269 ps_search_base_reinit(search, dict, d2p);
00270
00271
00272 ngram_search_calc_beams(ngs);
00273
00274
00275 ngram_search_update_widmap(ngs);
00276
00277
00278 if (ngs->fwdtree) {
00279 if ((rv = ngram_fwdtree_reinit(ngs)) < 0)
00280 return rv;
00281 }
00282 if (ngs->fwdflat) {
00283 if ((rv = ngram_fwdflat_reinit(ngs)) < 0)
00284 return rv;
00285 }
00286
00287 return rv;
00288 }
00289
00290 void
00291 ngram_search_free(ps_search_t *search)
00292 {
00293 ngram_search_t *ngs = (ngram_search_t *)search;
00294
00295 ps_search_deinit(search);
00296 if (ngs->fwdtree)
00297 ngram_fwdtree_deinit(ngs);
00298 if (ngs->fwdflat)
00299 ngram_fwdflat_deinit(ngs);
00300
00301 hmm_context_free(ngs->hmmctx);
00302 listelem_alloc_free(ngs->chan_alloc);
00303 listelem_alloc_free(ngs->root_chan_alloc);
00304 listelem_alloc_free(ngs->latnode_alloc);
00305 ngram_model_free(ngs->lmset);
00306
00307 ckd_free(ngs->word_chan);
00308 ckd_free(ngs->word_lat_idx);
00309 ckd_free(ngs->zeroPermTab);
00310 bitvec_free(ngs->word_active);
00311 ckd_free(ngs->bp_table);
00312 ckd_free(ngs->bscore_stack);
00313 if (ngs->bp_table_idx != NULL)
00314 ckd_free(ngs->bp_table_idx - 1);
00315 ckd_free_2d(ngs->active_word_list);
00316 ckd_free(ngs->last_ltrans);
00317 ckd_free(ngs);
00318 }
00319
00320 int
00321 ngram_search_mark_bptable(ngram_search_t *ngs, int frame_idx)
00322 {
00323 if (frame_idx >= ngs->n_frame_alloc) {
00324 ngs->n_frame_alloc *= 2;
00325 ngs->bp_table_idx = ckd_realloc(ngs->bp_table_idx - 1,
00326 (ngs->n_frame_alloc + 1)
00327 * sizeof(*ngs->bp_table_idx));
00328 if (ngs->frm_wordlist) {
00329 ngs->frm_wordlist = ckd_realloc(ngs->frm_wordlist,
00330 ngs->n_frame_alloc
00331 * sizeof(*ngs->frm_wordlist));
00332 }
00333 ++ngs->bp_table_idx;
00334 }
00335 ngs->bp_table_idx[frame_idx] = ngs->bpidx;
00336 return ngs->bpidx;
00337 }
00338
00342 static void
00343 cache_bptable_paths(ngram_search_t *ngs, int32 bp)
00344 {
00345 int32 w, prev_bp;
00346 bptbl_t *be;
00347
00348 assert(bp != NO_BP);
00349
00350 be = &(ngs->bp_table[bp]);
00351 prev_bp = bp;
00352 w = be->wid;
00353
00354 while (dict_filler_word(ps_search_dict(ngs), w)) {
00355 prev_bp = ngs->bp_table[prev_bp].bp;
00356 if (prev_bp == NO_BP)
00357 return;
00358 w = ngs->bp_table[prev_bp].wid;
00359 }
00360
00361 be->real_wid = dict_basewid(ps_search_dict(ngs), w);
00362
00363 prev_bp = ngs->bp_table[prev_bp].bp;
00364 be->prev_real_wid =
00365 (prev_bp != NO_BP) ? ngs->bp_table[prev_bp].real_wid : -1;
00366 }
00367
00368 void
00369 ngram_search_save_bp(ngram_search_t *ngs, int frame_idx,
00370 int32 w, int32 score, int32 path, int32 rc)
00371 {
00372 int32 _bp_;
00373
00374
00375 _bp_ = ngs->word_lat_idx[w];
00376 if (_bp_ != NO_BP) {
00377
00378
00379 if (ngs->bp_table[_bp_].score WORSE_THAN score) {
00380 if (ngs->bp_table[_bp_].bp != path) {
00381 ngs->bp_table[_bp_].bp = path;
00382 cache_bptable_paths(ngs, _bp_);
00383 }
00384 ngs->bp_table[_bp_].score = score;
00385 }
00386
00387
00388
00389 ngs->bscore_stack[ngs->bp_table[_bp_].s_idx + rc] = score;
00390 }
00391 else {
00392 int32 i, rcsize, *bss;
00393 bptbl_t *be;
00394
00395
00396 if (ngs->bpidx == NO_BP) {
00397 E_ERROR("No entries in backpointer table!");
00398 return;
00399 }
00400
00401
00402 if (ngs->bpidx >= ngs->bp_table_size) {
00403 ngs->bp_table_size *= 2;
00404 ngs->bp_table = ckd_realloc(ngs->bp_table,
00405 ngs->bp_table_size
00406 * sizeof(*ngs->bp_table));
00407 E_INFO("Resized backpointer table to %d entries\n", ngs->bp_table_size);
00408 }
00409 if (ngs->bss_head >= ngs->bscore_stack_size
00410 - bin_mdef_n_ciphone(ps_search_acmod(ngs)->mdef)) {
00411 ngs->bscore_stack_size *= 2;
00412 ngs->bscore_stack = ckd_realloc(ngs->bscore_stack,
00413 ngs->bscore_stack_size
00414 * sizeof(*ngs->bscore_stack));
00415 E_INFO("Resized score stack to %d entries\n", ngs->bscore_stack_size);
00416 }
00417
00418 ngs->word_lat_idx[w] = ngs->bpidx;
00419 be = &(ngs->bp_table[ngs->bpidx]);
00420 be->wid = w;
00421 be->frame = frame_idx;
00422 be->bp = path;
00423 be->score = score;
00424 be->s_idx = ngs->bss_head;
00425 be->valid = TRUE;
00426
00427
00428
00429 be->last_phone = dict_last_phone(ps_search_dict(ngs),w);
00430 if (dict_is_single_phone(ps_search_dict(ngs), w)) {
00431 be->last2_phone = -1;
00432 rcsize = 1;
00433 }
00434 else {
00435 be->last2_phone = dict_second_last_phone(ps_search_dict(ngs),w);
00436 rcsize = dict2pid_rssid(ps_search_dict2pid(ngs),
00437 be->last_phone, be->last2_phone)->n_ssid;
00438 }
00439
00440 for (i = rcsize, bss = ngs->bscore_stack + ngs->bss_head; i > 0; --i, bss++)
00441 *bss = WORST_SCORE;
00442 ngs->bscore_stack[ngs->bss_head + rc] = score;
00443 cache_bptable_paths(ngs, ngs->bpidx);
00444
00445 ngs->bpidx++;
00446 ngs->bss_head += rcsize;
00447 }
00448 }
00449
00450 int
00451 ngram_search_find_exit(ngram_search_t *ngs, int frame_idx, int32 *out_best_score)
00452 {
00453
00454 int end_bpidx;
00455 int best_exit, bp;
00456 int32 best_score;
00457
00458
00459 if (ngs->n_frame == 0)
00460 return NO_BP;
00461
00462 if (frame_idx == -1 || frame_idx >= ngs->n_frame)
00463 frame_idx = ngs->n_frame - 1;
00464 end_bpidx = ngs->bp_table_idx[frame_idx];
00465
00466 best_score = WORST_SCORE;
00467 best_exit = NO_BP;
00468
00469
00470 while (frame_idx >= 0 && ngs->bp_table_idx[frame_idx] == end_bpidx)
00471 --frame_idx;
00472
00473 if (frame_idx < 0)
00474 return NO_BP;
00475
00476
00477 assert(end_bpidx < ngs->bp_table_size);
00478 for (bp = ngs->bp_table_idx[frame_idx]; bp < end_bpidx; ++bp) {
00479 if (ngs->bp_table[bp].wid == ps_search_finish_wid(ngs)
00480 || ngs->bp_table[bp].score BETTER_THAN best_score) {
00481 best_score = ngs->bp_table[bp].score;
00482 best_exit = bp;
00483 }
00484 if (ngs->bp_table[bp].wid == ps_search_finish_wid(ngs))
00485 break;
00486 }
00487
00488 if (out_best_score) *out_best_score = best_score;
00489 return best_exit;
00490 }
00491
00492 char const *
00493 ngram_search_bp_hyp(ngram_search_t *ngs, int bpidx)
00494 {
00495 ps_search_t *base = ps_search_base(ngs);
00496 char *c;
00497 size_t len;
00498 int bp;
00499
00500 if (bpidx == NO_BP)
00501 return NULL;
00502
00503 bp = bpidx;
00504 len = 0;
00505 while (bp != NO_BP) {
00506 bptbl_t *be = &ngs->bp_table[bp];
00507 bp = be->bp;
00508 if (dict_real_word(ps_search_dict(ngs), be->wid))
00509 len += strlen(dict_basestr(ps_search_dict(ngs), be->wid)) + 1;
00510 }
00511
00512 ckd_free(base->hyp_str);
00513 if (len == 0) {
00514 base->hyp_str = NULL;
00515 return base->hyp_str;
00516 }
00517 base->hyp_str = ckd_calloc(1, len);
00518
00519 bp = bpidx;
00520 c = base->hyp_str + len - 1;
00521 while (bp != NO_BP) {
00522 bptbl_t *be = &ngs->bp_table[bp];
00523 size_t len;
00524
00525 bp = be->bp;
00526 if (dict_real_word(ps_search_dict(ngs), be->wid)) {
00527 len = strlen(dict_basestr(ps_search_dict(ngs), be->wid));
00528 c -= len;
00529 memcpy(c, dict_basestr(ps_search_dict(ngs), be->wid), len);
00530 if (c > base->hyp_str) {
00531 --c;
00532 *c = ' ';
00533 }
00534 }
00535 }
00536
00537 return base->hyp_str;
00538 }
00539
00540 void
00541 ngram_search_alloc_all_rc(ngram_search_t *ngs, int32 w)
00542 {
00543 chan_t *hmm, *thmm;
00544 xwdssid_t *rssid;
00545 int32 i;
00546
00547
00548
00549 assert(!dict_is_single_phone(ps_search_dict(ngs), w));
00550 rssid = dict2pid_rssid(ps_search_dict2pid(ngs),
00551 dict_last_phone(ps_search_dict(ngs),w),
00552 dict_second_last_phone(ps_search_dict(ngs),w));
00553 hmm = ngs->word_chan[w];
00554 if ((hmm == NULL) || (hmm_nonmpx_ssid(&hmm->hmm) != rssid->ssid[0])) {
00555 hmm = listelem_malloc(ngs->chan_alloc);
00556 hmm->next = ngs->word_chan[w];
00557 ngs->word_chan[w] = hmm;
00558
00559 hmm->info.rc_id = 0;
00560 hmm->ciphone = dict_last_phone(ps_search_dict(ngs),w);
00561 hmm_init(ngs->hmmctx, &hmm->hmm, FALSE, rssid->ssid[0], hmm->ciphone);
00562 E_DEBUG(3,("allocated rc_id 0 ssid %d ciphone %d lc %d word %s\n",
00563 rssid->ssid[0], hmm->ciphone,
00564 dict_second_last_phone(ps_search_dict(ngs),w),
00565 dict_wordstr(ps_search_dict(ngs),w)));
00566 }
00567 for (i = 1; i < rssid->n_ssid; ++i) {
00568 if ((hmm->next == NULL) || (hmm_nonmpx_ssid(&hmm->next->hmm) != rssid->ssid[i])) {
00569 thmm = listelem_malloc(ngs->chan_alloc);
00570 thmm->next = hmm->next;
00571 hmm->next = thmm;
00572 hmm = thmm;
00573
00574 hmm->info.rc_id = i;
00575 hmm->ciphone = dict_last_phone(ps_search_dict(ngs),w);
00576 hmm_init(ngs->hmmctx, &hmm->hmm, FALSE, rssid->ssid[i], hmm->ciphone);
00577 E_DEBUG(3,("allocated rc_id %d ssid %d ciphone %d lc %d word %s\n",
00578 i, rssid->ssid[i], hmm->ciphone,
00579 dict_second_last_phone(ps_search_dict(ngs),w),
00580 dict_wordstr(ps_search_dict(ngs),w)));
00581 }
00582 else
00583 hmm = hmm->next;
00584 }
00585 }
00586
00587 void
00588 ngram_search_free_all_rc(ngram_search_t *ngs, int32 w)
00589 {
00590 chan_t *hmm, *thmm;
00591
00592 for (hmm = ngs->word_chan[w]; hmm; hmm = thmm) {
00593 thmm = hmm->next;
00594 hmm_deinit(&hmm->hmm);
00595 listelem_free(ngs->chan_alloc, hmm);
00596 }
00597 ngs->word_chan[w] = NULL;
00598 }
00599
00600 int32
00601 ngram_search_exit_score(ngram_search_t *ngs, bptbl_t *pbe, int rcphone)
00602 {
00603
00604
00605
00606 if (pbe->last2_phone == -1) {
00607
00608 return ngs->bscore_stack[pbe->s_idx];
00609 }
00610 else {
00611 xwdssid_t *rssid;
00612
00613
00614 rssid = dict2pid_rssid(ps_search_dict2pid(ngs),
00615 pbe->last_phone, pbe->last2_phone);
00616 return ngs->bscore_stack[pbe->s_idx + rssid->cimap[rcphone]];
00617 }
00618 }
00619
00620
00621
00622
00623 void
00624 ngram_compute_seg_score(ngram_search_t *ngs, bptbl_t *be, float32 lwf,
00625 int32 *out_ascr, int32 *out_lscr)
00626 {
00627 bptbl_t *pbe;
00628 int32 start_score;
00629
00630
00631 if (be->bp == NO_BP) {
00632 *out_ascr = be->score;
00633 *out_lscr = 0;
00634 return;
00635 }
00636
00637
00638 pbe = ngs->bp_table + be->bp;
00639 start_score = ngram_search_exit_score(ngs, pbe,
00640 dict_first_phone(ps_search_dict(ngs),be->wid));
00641
00642
00643
00644
00645
00646 if (start_score == WORST_SCORE)
00647 start_score = 0;
00648
00649
00650
00651
00652
00653 if (be->wid == ps_search_silence_wid(ngs)) {
00654 *out_lscr = ngs->silpen;
00655 }
00656 else if (dict_filler_word(ps_search_dict(ngs), be->wid)) {
00657 *out_lscr = ngs->fillpen;
00658 }
00659 else {
00660 int32 n_used;
00661 *out_lscr = ngram_tg_score(ngs->lmset,
00662 be->real_wid,
00663 pbe->real_wid,
00664 pbe->prev_real_wid, &n_used);
00665 *out_lscr = *out_lscr * lwf;
00666 }
00667 *out_ascr = be->score - start_score - *out_lscr;
00668 }
00669
00670 static int
00671 ngram_search_start(ps_search_t *search)
00672 {
00673 ngram_search_t *ngs = (ngram_search_t *)search;
00674
00675 ngs->done = FALSE;
00676 ngram_model_flush(ngs->lmset);
00677 if (ngs->fwdtree)
00678 ngram_fwdtree_start(ngs);
00679 else if (ngs->fwdflat)
00680 ngram_fwdflat_start(ngs);
00681 else
00682 return -1;
00683 return 0;
00684 }
00685
00686 static int
00687 ngram_search_step(ps_search_t *search, int frame_idx)
00688 {
00689 ngram_search_t *ngs = (ngram_search_t *)search;
00690
00691 if (ngs->fwdtree)
00692 return ngram_fwdtree_search(ngs, frame_idx);
00693 else if (ngs->fwdflat)
00694 return ngram_fwdflat_search(ngs, frame_idx);
00695 else
00696 return -1;
00697 }
00698
00699 static void
00700 dump_bptable(ngram_search_t *ngs)
00701 {
00702 int i;
00703 E_INFO("Backpointer table (%d entries):\n", ngs->bpidx);
00704 for (i = 0; i < ngs->bpidx; ++i) {
00705 E_INFO_NOFN("%-5d %-10s start %-3d end %-3d score %-8d bp\n",
00706 i, dict_wordstr(ps_search_dict(ngs), ngs->bp_table[i].wid),
00707 ngs->bp_table[i].bp == -1 ? 0 :
00708 ngs->bp_table[ngs->bp_table[i].bp].frame + 1,
00709 ngs->bp_table[i].frame,
00710 ngs->bp_table[i].score,
00711 ngs->bp_table[i].bp);
00712
00713 }
00714 }
00715
00716 static int
00717 ngram_search_finish(ps_search_t *search)
00718 {
00719 ngram_search_t *ngs = (ngram_search_t *)search;
00720
00721 if (ngs->fwdtree) {
00722 ngram_fwdtree_finish(ngs);
00723
00724
00725
00726 if (ngs->fwdflat) {
00727 int i;
00728
00729 if (acmod_rewind(ps_search_acmod(ngs)) < 0)
00730 return -1;
00731
00732 ngram_fwdflat_start(ngs);
00733 i = 0;
00734 while (ps_search_acmod(ngs)->n_feat_frame > 0) {
00735 int nfr;
00736 if ((nfr = ngram_fwdflat_search(ngs, i)) < 0)
00737 return nfr;
00738 acmod_advance(ps_search_acmod(ngs));
00739 ++i;
00740 }
00741 ngram_fwdflat_finish(ngs);
00742
00743
00744 }
00745 }
00746 else if (ngs->fwdflat) {
00747 ngram_fwdflat_finish(ngs);
00748 }
00749
00750
00751 ngs->done = TRUE;
00752 return 0;
00753 }
00754
00755 static ps_latlink_t *
00756 ngram_search_bestpath(ps_search_t *search, int32 *out_score, int backward)
00757 {
00758 ngram_search_t *ngs = (ngram_search_t *)search;
00759
00760 if (search->last_link == NULL) {
00761 search->last_link = ps_lattice_bestpath(search->dag, ngs->lmset,
00762 ngs->bestpath_fwdtree_lw_ratio,
00763 ngs->ascale);
00764 if (search->last_link == NULL)
00765 return NULL;
00766
00767
00768 if (search->post == 0)
00769 search->post = ps_lattice_posterior(search->dag, ngs->lmset,
00770 ngs->ascale);
00771 }
00772 if (out_score)
00773 *out_score = search->last_link->path_scr + search->dag->final_node_ascr;
00774 return search->last_link;
00775 }
00776
00777 static char const *
00778 ngram_search_hyp(ps_search_t *search, int32 *out_score)
00779 {
00780 ngram_search_t *ngs = (ngram_search_t *)search;
00781
00782
00783 if (ngs->bestpath && ngs->done) {
00784 ps_lattice_t *dag;
00785 ps_latlink_t *link;
00786
00787 if ((dag = ngram_search_lattice(search)) == NULL)
00788 return NULL;
00789 if ((link = ngram_search_bestpath(search, out_score, FALSE)) == NULL)
00790 return NULL;
00791 return ps_lattice_hyp(dag, link);
00792 }
00793 else {
00794 int32 bpidx;
00795
00796
00797 bpidx = ngram_search_find_exit(ngs, -1, out_score);
00798 if (bpidx != NO_BP)
00799 return ngram_search_bp_hyp(ngs, bpidx);
00800 }
00801
00802 return NULL;
00803 }
00804
00805 static void
00806 ngram_search_bp2itor(ps_seg_t *seg, int bp)
00807 {
00808 ngram_search_t *ngs = (ngram_search_t *)seg->search;
00809 bptbl_t *be, *pbe;
00810
00811 be = &ngs->bp_table[bp];
00812 pbe = be->bp == -1 ? NULL : &ngs->bp_table[be->bp];
00813 seg->word = dict_wordstr(ps_search_dict(ngs), be->wid);
00814 seg->ef = be->frame;
00815 seg->sf = pbe ? pbe->frame + 1 : 0;
00816 seg->prob = 0;
00817
00818 if (pbe == NULL) {
00819 seg->ascr = be->score;
00820 seg->lscr = 0;
00821 seg->lback = 0;
00822 }
00823 else {
00824 int32 start_score;
00825
00826
00827 start_score = ngram_search_exit_score(ngs, pbe,
00828 dict_first_phone(ps_search_dict(ngs), be->wid));
00829 if (be->wid == ps_search_silence_wid(ngs)) {
00830 seg->lscr = ngs->silpen;
00831 }
00832 else if (dict_filler_word(ps_search_dict(ngs), be->wid)) {
00833 seg->lscr = ngs->fillpen;
00834 }
00835 else {
00836 seg->lscr = ngram_tg_score(ngs->lmset,
00837 be->real_wid,
00838 pbe->real_wid,
00839 pbe->prev_real_wid, &seg->lback);
00840 seg->lscr = (int32)(seg->lscr * seg->lwf);
00841 }
00842 seg->ascr = be->score - start_score - seg->lscr;
00843 }
00844 }
00845
00846 static void
00847 ngram_bp_seg_free(ps_seg_t *seg)
00848 {
00849 bptbl_seg_t *itor = (bptbl_seg_t *)seg;
00850
00851 ckd_free(itor->bpidx);
00852 ckd_free(itor);
00853 }
00854
00855 static ps_seg_t *
00856 ngram_bp_seg_next(ps_seg_t *seg)
00857 {
00858 bptbl_seg_t *itor = (bptbl_seg_t *)seg;
00859
00860 if (++itor->cur == itor->n_bpidx) {
00861 ngram_bp_seg_free(seg);
00862 return NULL;
00863 }
00864
00865 ngram_search_bp2itor(seg, itor->bpidx[itor->cur]);
00866 return seg;
00867 }
00868
00869 static ps_segfuncs_t ngram_bp_segfuncs = {
00870 ngram_bp_seg_next,
00871 ngram_bp_seg_free
00872 };
00873
00874 static ps_seg_t *
00875 ngram_search_bp_iter(ngram_search_t *ngs, int bpidx, float32 lwf)
00876 {
00877 bptbl_seg_t *itor;
00878 int bp, cur;
00879
00880
00881
00882
00883
00884 itor = ckd_calloc(1, sizeof(*itor));
00885 itor->base.vt = &ngram_bp_segfuncs;
00886 itor->base.search = ps_search_base(ngs);
00887 itor->base.lwf = lwf;
00888 itor->n_bpidx = 0;
00889 bp = bpidx;
00890 while (bp != NO_BP) {
00891 bptbl_t *be = &ngs->bp_table[bp];
00892 bp = be->bp;
00893 ++itor->n_bpidx;
00894 }
00895 if (itor->n_bpidx == 0) {
00896 ckd_free(itor);
00897 return NULL;
00898 }
00899 itor->bpidx = ckd_calloc(itor->n_bpidx, sizeof(*itor->bpidx));
00900 cur = itor->n_bpidx - 1;
00901 bp = bpidx;
00902 while (bp != NO_BP) {
00903 bptbl_t *be = &ngs->bp_table[bp];
00904 itor->bpidx[cur] = bp;
00905 bp = be->bp;
00906 --cur;
00907 }
00908
00909
00910 ngram_search_bp2itor((ps_seg_t *)itor, itor->bpidx[0]);
00911
00912 return (ps_seg_t *)itor;
00913 }
00914
00915 static ps_seg_t *
00916 ngram_search_seg_iter(ps_search_t *search, int32 *out_score)
00917 {
00918 ngram_search_t *ngs = (ngram_search_t *)search;
00919
00920
00921 if (ngs->bestpath && ngs->done) {
00922 ps_lattice_t *dag;
00923 ps_latlink_t *link;
00924
00925 if ((dag = ngram_search_lattice(search)) == NULL)
00926 return NULL;
00927 if ((link = ngram_search_bestpath(search, out_score, TRUE)) == NULL)
00928 return NULL;
00929 return ps_lattice_seg_iter(dag, link,
00930 ngs->bestpath_fwdtree_lw_ratio);
00931 }
00932 else {
00933 int32 bpidx;
00934
00935
00936 bpidx = ngram_search_find_exit(ngs, -1, out_score);
00937 return ngram_search_bp_iter(ngs, bpidx,
00938
00939 (ngs->done && ngs->fwdflat)
00940 ? ngs->fwdflat_fwdtree_lw_ratio : 1.0);
00941 }
00942
00943 return NULL;
00944 }
00945
00946 static int32
00947 ngram_search_prob(ps_search_t *search)
00948 {
00949 ngram_search_t *ngs = (ngram_search_t *)search;
00950
00951
00952 if (ngs->bestpath && ngs->done) {
00953 ps_lattice_t *dag;
00954 ps_latlink_t *link;
00955
00956 if ((dag = ngram_search_lattice(search)) == NULL)
00957 return 0;
00958 if ((link = ngram_search_bestpath(search, NULL, TRUE)) == NULL)
00959 return 0;
00960 return search->post;
00961 }
00962 else {
00963
00964 return 0;
00965 }
00966 }
00967
00968 static void
00969 create_dag_nodes(ngram_search_t *ngs, ps_lattice_t *dag)
00970 {
00971 bptbl_t *bp_ptr;
00972 int32 i;
00973
00974 for (i = 0, bp_ptr = ngs->bp_table; i < ngs->bpidx; ++i, ++bp_ptr) {
00975 int32 sf, ef, wid;
00976 ps_latnode_t *node;
00977
00978
00979 if (!bp_ptr->valid)
00980 continue;
00981
00982 sf = (bp_ptr->bp < 0) ? 0 : ngs->bp_table[bp_ptr->bp].frame + 1;
00983 ef = bp_ptr->frame;
00984 wid = bp_ptr->wid;
00985
00986 assert(ef < dag->n_frames);
00987
00988 if ((wid == ps_search_finish_wid(ngs)) && (ef < dag->n_frames - 1))
00989 continue;
00990
00991
00992 if ((!dict_filler_word(ps_search_dict(ngs), wid))
00993 && (!ngram_model_set_known_wid(ngs->lmset,
00994 dict_basewid(ps_search_dict(ngs), wid))))
00995 continue;
00996
00997
00998 for (node = dag->nodes; node; node = node->next) {
00999 if ((node->wid == wid) && (node->sf == sf))
01000 break;
01001 }
01002
01003
01004 if (node)
01005 node->lef = i;
01006 else {
01007
01008 node = listelem_malloc(dag->latnode_alloc);
01009 node->wid = wid;
01010 node->sf = sf;
01011 node->fef = node->lef = i;
01012 node->reachable = FALSE;
01013 node->entries = NULL;
01014 node->exits = NULL;
01015
01016 node->next = dag->nodes;
01017 dag->nodes = node;
01018 }
01019 }
01020 }
01021
01022 static ps_latnode_t *
01023 find_start_node(ngram_search_t *ngs, ps_lattice_t *dag)
01024 {
01025 ps_latnode_t *node;
01026
01027
01028 for (node = dag->nodes; node; node = node->next) {
01029 if ((node->wid == ps_search_start_wid(ngs)) && (node->sf == 0))
01030 break;
01031 }
01032 if (!node) {
01033
01034 E_ERROR("Couldn't find <s> in first frame\n");
01035 return NULL;
01036 }
01037 return node;
01038 }
01039
01040 static ps_latnode_t *
01041 find_end_node(ngram_search_t *ngs, ps_lattice_t *dag, float32 lwf)
01042 {
01043 ps_latnode_t *node;
01044 int32 ef, bestbp, bp, bestscore;
01045
01046
01047 for (node = dag->nodes; node; node = node->next) {
01048 int32 lef = ngs->bp_table[node->lef].frame;
01049 if ((node->wid == ps_search_finish_wid(ngs))
01050 && (lef == dag->n_frames - 1))
01051 break;
01052 }
01053 if (node != NULL)
01054 return node;
01055
01056
01057
01058
01059 for (ef = dag->n_frames - 1;
01060 ef >= 0 && ngs->bp_table_idx[ef] == ngs->bpidx;
01061 --ef);
01062 if (ef < 0) {
01063 E_ERROR("Empty backpointer table: can not build DAG.\n");
01064 return NULL;
01065 }
01066
01067
01068 bestscore = WORST_SCORE;
01069 bestbp = NO_BP;
01070 for (bp = ngs->bp_table_idx[ef]; bp < ngs->bp_table_idx[ef + 1]; ++bp) {
01071 int32 n_used, l_scr;
01072 l_scr = ngram_tg_score(ngs->lmset, ps_search_finish_wid(ngs),
01073 ngs->bp_table[bp].real_wid,
01074 ngs->bp_table[bp].prev_real_wid,
01075 &n_used);
01076 l_scr = l_scr * lwf;
01077
01078 if (ngs->bp_table[bp].score + l_scr BETTER_THAN bestscore) {
01079 bestscore = ngs->bp_table[bp].score + l_scr;
01080 bestbp = bp;
01081 }
01082 }
01083 if (bestbp == NO_BP) {
01084 E_ERROR("No word exits found in last frame, assuming no recognition\n");
01085 return NULL;
01086 }
01087 E_WARN("</s> not found in last frame, using %s instead\n",
01088 dict_basestr(ps_search_dict(ngs), ngs->bp_table[bestbp].wid));
01089
01090
01091 for (node = dag->nodes; node; node = node->next) {
01092 if (node->lef == bestbp)
01093 return node;
01094 }
01095
01096
01097 E_ERROR("Failed to find DAG node corresponding to %s\n",
01098 dict_basestr(ps_search_dict(ngs), ngs->bp_table[bestbp].wid));
01099 return NULL;
01100 }
01101
01102
01103
01104
01105 ps_lattice_t *
01106 ngram_search_lattice(ps_search_t *search)
01107 {
01108 int32 i, ef, lef, score, ascr, lscr;
01109 ps_latnode_t *node, *from, *to;
01110 ngram_search_t *ngs;
01111 ps_lattice_t *dag;
01112 float lwf;
01113
01114 ngs = (ngram_search_t *)search;
01115
01116
01117
01118 if (ngs->best_score == WORST_SCORE || ngs->best_score WORSE_THAN WORST_SCORE)
01119 return NULL;
01120
01121
01122
01123 if (search->dag && search->dag->n_frames == ngs->n_frame)
01124 return search->dag;
01125
01126
01127 ps_lattice_free(search->dag);
01128 search->dag = NULL;
01129 dag = ps_lattice_init_search(search, ngs->n_frame);
01130
01131 lwf = ngs->fwdflat ? ngs->fwdflat_fwdtree_lw_ratio : 1.0;
01132 create_dag_nodes(ngs, dag);
01133 if ((dag->start = find_start_node(ngs, dag)) == NULL)
01134 goto error_out;
01135 if ((dag->end = find_end_node(ngs, dag, ngs->bestpath_fwdtree_lw_ratio)) == NULL)
01136 goto error_out;
01137 E_INFO("lattice start node %s.%d end node %s.%d\n",
01138 dict_wordstr(search->dict, dag->start->wid), dag->start->sf,
01139 dict_wordstr(search->dict, dag->end->wid), dag->end->sf);
01140
01141 ngram_compute_seg_score(ngs, ngs->bp_table + dag->end->lef, lwf,
01142 &dag->final_node_ascr, &lscr);
01143
01144
01145
01146
01147
01148
01149
01150
01151 dag->end->reachable = TRUE;
01152 for (to = dag->end; to; to = to->next) {
01153
01154 if (!to->reachable)
01155 continue;
01156
01157
01158 for (from = to->next; from; from = from->next) {
01159 bptbl_t *bp_ptr;
01160
01161 ef = ngs->bp_table[from->fef].frame;
01162 lef = ngs->bp_table[from->lef].frame;
01163
01164 if ((to->sf <= ef) || (to->sf > lef + 1))
01165 continue;
01166
01167
01168 i = from->fef;
01169 bp_ptr = ngs->bp_table + i;
01170 for (; i <= from->lef; i++, bp_ptr++) {
01171 if (bp_ptr->wid != from->wid)
01172 continue;
01173 if (bp_ptr->frame >= to->sf - 1)
01174 break;
01175 }
01176
01177 if ((i > from->lef) || (bp_ptr->frame != to->sf - 1))
01178 continue;
01179
01180
01181 ngram_compute_seg_score(ngs, bp_ptr, lwf,
01182 &ascr, &lscr);
01183
01184 score = (ngram_search_exit_score(ngs, bp_ptr,
01185 dict_first_phone(ps_search_dict(ngs), to->wid))
01186 - bp_ptr->score + ascr);
01187 if (score BETTER_THAN 0) {
01188
01189
01190
01191
01192
01193 ps_lattice_link(dag, from, to, -424242, bp_ptr->frame);
01194 from->reachable = TRUE;
01195 }
01196 else if (score BETTER_THAN WORST_SCORE) {
01197 ps_lattice_link(dag, from, to, score, bp_ptr->frame);
01198 from->reachable = TRUE;
01199 }
01200 }
01201 }
01202
01203
01204 if (!dag->start->reachable) {
01205 E_ERROR("End node of lattice isolated; unreachable\n");
01206 goto error_out;
01207 }
01208
01209 for (node = dag->nodes; node; node = node->next) {
01210
01211 node->fef = ngs->bp_table[node->fef].frame;
01212 node->lef = ngs->bp_table[node->lef].frame;
01213
01214 node->basewid = dict_basewid(search->dict, node->wid);
01215 }
01216
01217
01218 for (node = dag->nodes; node; node = node->next) {
01219 ps_latnode_t *alt;
01220
01221 for (alt = node->next; alt && alt->sf == node->sf; alt = alt->next) {
01222 if (alt->basewid == node->basewid) {
01223 alt->alt = node->alt;
01224 node->alt = alt;
01225 break;
01226 }
01227 }
01228 }
01229
01230
01231
01232
01233 if (dict_filler_word(ps_search_dict(ngs), dag->end->wid))
01234 dag->end->basewid = ps_search_finish_wid(ngs);
01235
01236
01237 ps_lattice_delete_unreachable(dag);
01238
01239
01240
01241 ps_lattice_bypass_fillers(dag, ngs->silpen, ngs->fillpen);
01242
01243 search->dag = dag;
01244 return dag;
01245
01246 error_out:
01247 ps_lattice_free(dag);
01248 return NULL;
01249 }