• Main Page
  • Data Structures
  • Files
  • File List
  • Globals

src/libpocketsphinx/ps_lattice.c

Go to the documentation of this file.
00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
00002 /* ====================================================================
00003  * Copyright (c) 2008 Carnegie Mellon University.  All rights
00004  * reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions
00008  * are met:
00009  *
00010  * 1. Redistributions of source code must retain the above copyright
00011  *    notice, this list of conditions and the following disclaimer. 
00012  *
00013  * 2. Redistributions in binary form must reproduce the above copyright
00014  *    notice, this list of conditions and the following disclaimer in
00015  *    the documentation and/or other materials provided with the
00016  *    distribution.
00017  *
00018  * This work was supported in part by funding from the Defense Advanced 
00019  * Research Projects Agency and the National Science Foundation of the 
00020  * United States of America, and the CMU Sphinx Speech Consortium.
00021  *
00022  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
00023  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
00024  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00025  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
00026  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00027  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
00028  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
00029  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
00030  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
00031  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
00032  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00033  *
00034  * ====================================================================
00035  *
00036  */
00037 
00042 /* System headers. */
00043 #include <assert.h>
00044 #include <string.h>
00045 #include <math.h>
00046 
00047 /* SphinxBase headers. */
00048 #include <ckd_alloc.h>
00049 #include <listelem_alloc.h>
00050 #include <strfuncs.h>
00051 #include <err.h>
00052 #include <pio.h>
00053 
00054 /* Local headers. */
00055 #include "pocketsphinx_internal.h"
00056 #include "ps_lattice_internal.h"
00057 #include "ngram_search.h"
00058 #include "dict.h"
00059 
00060 /*
00061  * Create a directed link between "from" and "to" nodes, but if a link already exists,
00062  * choose one with the best ascr.
00063  */
00064 void
00065 ps_lattice_link(ps_lattice_t *dag, ps_latnode_t *from, ps_latnode_t *to, int32 score, int32 ef)
00066 {
00067     latlink_list_t *fwdlink;
00068 
00069     /* Look for an existing link between "from" and "to" nodes */
00070     for (fwdlink = from->exits; fwdlink; fwdlink = fwdlink->next)
00071         if (fwdlink->link->to == to)
00072             break;
00073 
00074     if (fwdlink == NULL) {
00075         latlink_list_t *revlink;
00076         ps_latlink_t *link;
00077 
00078         /* No link between the two nodes; create a new one */
00079         link = listelem_malloc(dag->latlink_alloc);
00080         fwdlink = listelem_malloc(dag->latlink_list_alloc);
00081         revlink = listelem_malloc(dag->latlink_list_alloc);
00082 
00083         link->from = from;
00084         link->to = to;
00085         link->ascr = score;
00086         link->ef = ef;
00087         link->best_prev = NULL;
00088 
00089         fwdlink->link = revlink->link = link;
00090         fwdlink->next = from->exits;
00091         from->exits = fwdlink;
00092         revlink->next = to->entries;
00093         to->entries = revlink;
00094     }
00095     else {
00096         /* Link already exists; just retain the best ascr */
00097         if (score BETTER_THAN fwdlink->link->ascr) {
00098             fwdlink->link->ascr = score;
00099             fwdlink->link->ef = ef;
00100         }
00101     }
00102 }
00103 
00104 void
00105 ps_lattice_bypass_fillers(ps_lattice_t *dag, int32 silpen, int32 fillpen)
00106 {
00107     ps_latnode_t *node;
00108     int32 score;
00109 
00110     /* Bypass filler nodes */
00111     for (node = dag->nodes; node; node = node->next) {
00112         latlink_list_t *revlink;
00113         if (node == dag->end || !dict_filler_word(ps_search_dict(dag->search), node->basewid))
00114             continue;
00115 
00116         /* Replace each link entering filler node with links to all its successors */
00117         for (revlink = node->entries; revlink; revlink = revlink->next) {
00118             latlink_list_t *forlink;
00119             ps_latlink_t *rlink = revlink->link;
00120 
00121             score = (node->basewid == ps_search_silence_wid(dag->search)) ? silpen : fillpen;
00122             score += rlink->ascr;
00123             /*
00124              * Make links from predecessor of filler (from) to successors of filler.
00125              * But if successor is a filler, it has already been eliminated since it
00126              * appears earlier in latnode_list (see build...).  So it can be skipped.
00127              */
00128             for (forlink = node->exits; forlink; forlink = forlink->next) {
00129                 ps_latlink_t *flink = forlink->link;
00130                 if (!dict_filler_word(ps_search_dict(dag->search), flink->to->basewid)) {
00131                     ps_lattice_link(dag, rlink->from, flink->to,
00132                                     score + flink->ascr, flink->ef);
00133                 }
00134             }
00135         }
00136     }
00137 }
00138 
00139 static void
00140 delete_node(ps_lattice_t *dag, ps_latnode_t *node)
00141 {
00142     latlink_list_t *x, *next_x;
00143 
00144     for (x = node->exits; x; x = next_x) {
00145         next_x = x->next;
00146         x->link->from = NULL;
00147         listelem_free(dag->latlink_list_alloc, x);
00148     }
00149     for (x = node->entries; x; x = next_x) {
00150         next_x = x->next;
00151         x->link->to = NULL;
00152         listelem_free(dag->latlink_list_alloc, x);
00153     }
00154     listelem_free(dag->latnode_alloc, node);
00155 }
00156 
00157 static void
00158 remove_dangling_links(ps_lattice_t *dag, ps_latnode_t *node)
00159 {
00160     latlink_list_t *x, *prev_x, *next_x;
00161 
00162     prev_x = NULL;
00163     for (x = node->exits; x; x = next_x) {
00164         next_x = x->next;
00165         if (x->link->to == NULL) {
00166             if (prev_x)
00167                 prev_x->next = next_x;
00168             else
00169                 node->exits = next_x;
00170             listelem_free(dag->latlink_alloc, x->link);
00171             listelem_free(dag->latlink_list_alloc, x);
00172         }
00173         else
00174             prev_x = x;
00175     }
00176     prev_x = NULL;
00177     for (x = node->entries; x; x = next_x) {
00178         next_x = x->next;
00179         if (x->link->from == NULL) {
00180             if (prev_x)
00181                 prev_x->next = next_x;
00182             else
00183                 node->exits = next_x;
00184             listelem_free(dag->latlink_alloc, x->link);
00185             listelem_free(dag->latlink_list_alloc, x);
00186         }
00187         else
00188             prev_x = x;
00189     }
00190 }
00191 
00192 void
00193 ps_lattice_delete_unreachable(ps_lattice_t *dag)
00194 {
00195     ps_latnode_t *node, *prev_node, *next_node;
00196     int i;
00197 
00198     /* Remove unreachable nodes from the list of nodes. */
00199     prev_node = NULL;
00200     for (node = dag->nodes; node; node = next_node) {
00201         next_node = node->next;
00202         if (!node->reachable) {
00203             if (prev_node)
00204                 prev_node->next = next_node;
00205             else
00206                 dag->nodes = next_node;
00207             /* Delete this node and NULLify links to it. */
00208             delete_node(dag, node);
00209         }
00210         else
00211             prev_node = node;
00212     }
00213 
00214     /* Remove all links to and from unreachable nodes. */
00215     i = 0;
00216     for (node = dag->nodes; node; node = node->next) {
00217         /* Assign sequence numbers. */
00218         node->id = i++;
00219 
00220         /* We should obviously not encounter unreachable nodes here! */
00221         assert(node->reachable);
00222 
00223         /* Remove all links that go nowhere. */
00224         remove_dangling_links(dag, node);
00225     }
00226 }
00227 
00228 int32
00229 ps_lattice_write(ps_lattice_t *dag, char const *filename)
00230 {
00231     FILE *fp;
00232     int32 i;
00233     ps_latnode_t *d, *initial, *final;
00234 
00235     initial = dag->start;
00236     final = dag->end;
00237 
00238     E_INFO("Writing lattice file: %s\n", filename);
00239     if ((fp = fopen(filename, "w")) == NULL) {
00240         E_ERROR("fopen(%s,w) failed\n", filename);
00241         return -1;
00242     }
00243 
00244     /* Stupid Sphinx-III lattice code expects 'getcwd:' here */
00245     fprintf(fp, "# getcwd: /this/is/bogus\n");
00246     fprintf(fp, "# -logbase %e\n", logmath_get_base(dag->lmath));
00247     fprintf(fp, "#\n");
00248 
00249     fprintf(fp, "Frames %d\n", dag->n_frames);
00250     fprintf(fp, "#\n");
00251 
00252     for (i = 0, d = dag->nodes; d; d = d->next, i++);
00253     fprintf(fp,
00254             "Nodes %d (NODEID WORD STARTFRAME FIRST-ENDFRAME LAST-ENDFRAME)\n",
00255             i);
00256     for (i = 0, d = dag->nodes; d; d = d->next, i++) {
00257         d->id = i;
00258         fprintf(fp, "%d %s %d %d %d\n",
00259                 i, dict_wordstr(ps_search_dict(dag->search), d->wid),
00260                 d->sf, d->fef, d->lef);
00261     }
00262     fprintf(fp, "#\n");
00263 
00264     fprintf(fp, "Initial %d\nFinal %d\n", initial->id, final->id);
00265     fprintf(fp, "#\n");
00266 
00267     /* Don't bother with this, it's not used by anything. */
00268     fprintf(fp, "BestSegAscr %d (NODEID ENDFRAME ASCORE)\n",
00269             0 /* #BPTable entries */ );
00270     fprintf(fp, "#\n");
00271 
00272     fprintf(fp, "Edges (FROM-NODEID TO-NODEID ASCORE)\n");
00273     for (d = dag->nodes; d; d = d->next) {
00274         latlink_list_t *l;
00275         for (l = d->exits; l; l = l->next)
00276             fprintf(fp, "%d %d %d\n",
00277                     d->id, l->link->to->id, l->link->ascr);
00278     }
00279     fprintf(fp, "End\n");
00280     fclose(fp);
00281 
00282     return 0;
00283 }
00284 
00285 /* Read parameter from a lattice file*/
00286 static int
00287 dag_param_read(lineiter_t *li, char *param)
00288 {
00289     int32 n;
00290 
00291     while ((li = lineiter_next(li)) != NULL) {
00292         char *c;
00293 
00294         /* Ignore comments. */
00295         if (li->buf[0] == '#')
00296             continue;
00297 
00298         /* Find the first space. */
00299         c = strchr(li->buf, ' ');
00300         if (c == NULL) continue;
00301 
00302         /* Check that the first field equals param and that there's a number after it. */
00303         if (strncmp(li->buf, param, strlen(param)) == 0
00304             && sscanf(c + 1, "%d", &n) == 1)
00305             return n;
00306     }
00307     return -1;
00308 }
00309 
00310 /* Mark every node that has a path to the argument dagnode as "reachable". */
00311 static void
00312 dag_mark_reachable(ps_latnode_t * d)
00313 {
00314     latlink_list_t *l;
00315 
00316     d->reachable = 1;
00317     for (l = d->entries; l; l = l->next)
00318         if (l->link->from && !l->link->from->reachable)
00319             dag_mark_reachable(l->link->from);
00320 }
00321 
00322 ps_lattice_t *
00323 ps_lattice_read(ps_decoder_t *ps,
00324                 char const *file)
00325 {
00326     FILE *fp;
00327     int32 ispipe;
00328     lineiter_t *line;
00329     float64 lb;
00330     float32 logratio;
00331     ps_latnode_t *tail;
00332     ps_latnode_t **darray;
00333     ps_lattice_t *dag;
00334     int i, k, n_nodes;
00335     int32 pip, silpen, fillpen;
00336 
00337     dag = ckd_calloc(1, sizeof(*dag));
00338     dag->search = ps->search;
00339     dag->lmath = logmath_retain(ps->lmath);
00340     dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t));
00341     dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t));
00342     dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t));
00343     dag->refcount = 1;
00344 
00345     tail = NULL;
00346     darray = NULL;
00347 
00348     E_INFO("Reading DAG file: %s\n", file);
00349     if ((fp = fopen_compchk(file, &ispipe)) == NULL) {
00350         E_ERROR("fopen_compchk(%s) failed\n", file);
00351         return NULL;
00352     }
00353     line = lineiter_start(fp);
00354 
00355     /* Read and verify logbase (ONE BIG HACK!!) */
00356     if (line == NULL) {
00357         E_ERROR("Premature EOF(%s)\n", file);
00358         goto load_error;
00359     }
00360     if (strncmp(line->buf, "# getcwd: ", 10) != 0) {
00361         E_ERROR("%s does not begin with '# getcwd: '\n%s", file, line->buf);
00362         goto load_error;
00363     }
00364     if ((line = lineiter_next(line)) == NULL) {
00365         E_ERROR("Premature EOF(%s)\n", file);
00366         goto load_error;
00367     }
00368     if ((strncmp(line->buf, "# -logbase ", 11) != 0)
00369         || (sscanf(line->buf + 11, "%lf", &lb) != 1)) {
00370         E_WARN("%s: Cannot find -logbase in header\n", file);
00371         lb = 1.0001;
00372     }
00373     logratio = 1.0f;
00374     if (dag->lmath == NULL)
00375         dag->lmath = logmath_init(lb, 0, TRUE);
00376     else {
00377         float32 pb = logmath_get_base(dag->lmath);
00378         if (fabs(lb - pb) >= 0.0001) {
00379             E_WARN("Inconsistent logbases: %f vs %f: will compensate\n", lb, pb);
00380             logratio = (float32)(log(lb) / log(pb));
00381             E_INFO("Lattice log ratio: %f\n", logratio);
00382         }
00383     }
00384     /* Read Frames parameter */
00385     dag->n_frames = dag_param_read(line, "Frames");
00386     if (dag->n_frames <= 0) {
00387         E_ERROR("Frames parameter missing or invalid\n");
00388         goto load_error;
00389     }
00390     /* Read Nodes parameter */
00391     n_nodes = dag_param_read(line, "Nodes");
00392     if (n_nodes <= 0) {
00393         E_ERROR("Nodes parameter missing or invalid\n");
00394         goto load_error;
00395     }
00396 
00397     /* Read nodes */
00398     darray = ckd_calloc(n_nodes, sizeof(*darray));
00399     for (i = 0; i < n_nodes; i++) {
00400         ps_latnode_t *d;
00401         int32 w;
00402         int seqid, sf, fef, lef;
00403         char wd[256];
00404 
00405         if ((line = lineiter_next(line)) == NULL) {
00406             E_ERROR("Premature EOF while loading Nodes(%s)\n", file);
00407             goto load_error;
00408         }
00409 
00410         if ((k =
00411              sscanf(line->buf, "%d %255s %d %d %d", &seqid, wd, &sf, &fef,
00412                     &lef)) != 5) {
00413             E_ERROR("Cannot parse line: %s, value of count %d\n", line->buf, k);
00414             goto load_error;
00415         }
00416 
00417         w = dict_wordid(ps->dict, wd);
00418         if (w < 0) {
00419             E_ERROR("Unknown word in line: %s\n", line->buf);
00420             goto load_error;
00421         }
00422 
00423         if (seqid != i) {
00424             E_ERROR("Seqno error: %s\n", line->buf);
00425             goto load_error;
00426         }
00427 
00428         d = listelem_malloc(dag->latnode_alloc);
00429         darray[i] = d;
00430         d->wid = w;
00431         d->basewid = dict_basewid(ps->dict, w);
00432         d->id = seqid;
00433         d->sf = sf;
00434         d->fef = fef;
00435         d->lef = lef;
00436         d->reachable = 0;
00437         d->exits = d->entries = NULL;
00438         d->next = NULL;
00439 
00440         if (!dag->nodes)
00441             dag->nodes = d;
00442         else
00443             tail->next = d;
00444         tail = d;
00445     }
00446 
00447     /* Read initial node ID */
00448     k = dag_param_read(line, "Initial");
00449     if ((k < 0) || (k >= n_nodes)) {
00450         E_ERROR("Initial node parameter missing or invalid\n");
00451         goto load_error;
00452     }
00453     dag->start = darray[k];
00454 
00455     /* Read final node ID */
00456     k = dag_param_read(line, "Final");
00457     if ((k < 0) || (k >= n_nodes)) {
00458         E_ERROR("Final node parameter missing or invalid\n");
00459         goto load_error;
00460     }
00461     dag->end = darray[k];
00462 
00463     /* Read bestsegscore entries and ignore them. */
00464     if ((k = dag_param_read(line, "BestSegAscr")) < 0) {
00465         E_ERROR("BestSegAscr parameter missing\n");
00466         goto load_error;
00467     }
00468     for (i = 0; i < k; i++) {
00469         if ((line = lineiter_next(line)) == NULL) {
00470             E_ERROR("Premature EOF while (%s) ignoring BestSegAscr\n",
00471                     line);
00472             goto load_error;
00473         }
00474     }
00475 
00476     /* Read in edges. */
00477     while ((line = lineiter_next(line)) != NULL) {
00478         if (line->buf[0] == '#')
00479             continue;
00480         if (0 == strncmp(line->buf, "Edges", 5))
00481             break;
00482     }
00483     if (line == NULL) {
00484         E_ERROR("Edges missing\n");
00485         goto load_error;
00486     }
00487     while ((line = lineiter_next(line)) != NULL) {
00488         int from, to, ascr;
00489         ps_latnode_t *pd, *d;
00490 
00491         if (sscanf(line->buf, "%d %d %d", &from, &to, &ascr) != 3)
00492             break;
00493         pd = darray[from];
00494         d = darray[to];
00495         if (logratio != 1.0f)
00496             ascr = (int32)(ascr * logratio);
00497         ps_lattice_link(dag, pd, d, ascr, d->sf - 1);
00498     }
00499     if (strcmp(line->buf, "End\n") != 0) {
00500         E_ERROR("Terminating 'End' missing\n");
00501         goto load_error;
00502     }
00503     lineiter_free(line);
00504     fclose_comp(fp, ispipe);
00505     ckd_free(darray);
00506 
00507     /* Minor hack: If the final node is a filler word and not </s>,
00508      * then set its base word ID to </s>, so that the language model
00509      * scores won't be screwed up. */
00510     if (dict_filler_word(ps_search_dict(dag->search), dag->end->wid))
00511         dag->end->basewid = ps_search_finish_wid(dag->search);
00512 
00513     /* Mark reachable from dag->end */
00514     dag_mark_reachable(dag->end);
00515 
00516     /* Free nodes unreachable from dag->end and their links */
00517     ps_lattice_delete_unreachable(dag);
00518 
00519     /* Build links around silence and filler words, since they do not
00520      * exist in the language model. */
00521     pip = logmath_log(dag->lmath, cmd_ln_float32_r(ps->config, "-pip"));
00522     silpen = pip + logmath_log(dag->lmath,
00523                                cmd_ln_float32_r(ps->config, "-silprob"));
00524     fillpen = pip + logmath_log(dag->lmath,
00525                                 cmd_ln_float32_r(ps->config, "-fillprob"));
00526     ps_lattice_bypass_fillers(dag, silpen, fillpen);
00527 
00528     return dag;
00529 
00530   load_error:
00531     E_ERROR("Failed to load %s\n", file);
00532     lineiter_free(line);
00533     if (fp) fclose_comp(fp, ispipe);
00534     ckd_free(darray);
00535     return NULL;
00536 }
00537 
00538 int
00539 ps_lattice_n_frames(ps_lattice_t *dag)
00540 {
00541     return dag->n_frames;
00542 }
00543 
00544 ps_lattice_t *
00545 ps_lattice_init_search(ps_search_t *search, int n_frame)
00546 {
00547     ps_lattice_t *dag;
00548 
00549     dag = ckd_calloc(1, sizeof(*dag));
00550     dag->search = search;
00551     dag->lmath = logmath_retain(search->acmod->lmath);
00552     dag->n_frames = n_frame;
00553     dag->latnode_alloc = listelem_alloc_init(sizeof(ps_latnode_t));
00554     dag->latlink_alloc = listelem_alloc_init(sizeof(ps_latlink_t));
00555     dag->latlink_list_alloc = listelem_alloc_init(sizeof(latlink_list_t));
00556     dag->refcount = 1;
00557     return dag;
00558 }
00559 
00560 ps_lattice_t *
00561 ps_lattice_retain(ps_lattice_t *dag)
00562 {
00563     ++dag->refcount;
00564     return dag;
00565 }
00566 
00567 int
00568 ps_lattice_free(ps_lattice_t *dag)
00569 {
00570     if (dag == NULL)
00571         return 0;
00572     if (--dag->refcount > 0)
00573         return dag->refcount;
00574     logmath_free(dag->lmath);
00575     listelem_alloc_free(dag->latnode_alloc);
00576     listelem_alloc_free(dag->latlink_alloc);
00577     listelem_alloc_free(dag->latlink_list_alloc);
00578     ckd_free(dag->hyp_str);
00579     ckd_free(dag);
00580     return 0;
00581 }
00582 
00583 logmath_t *
00584 ps_lattice_get_logmath(ps_lattice_t *dag)
00585 {
00586     return dag->lmath;
00587 }
00588 
00589 ps_latnode_iter_t *
00590 ps_latnode_iter(ps_lattice_t *dag)
00591 {
00592     return dag->nodes;
00593 }
00594 
00595 ps_latnode_iter_t *
00596 ps_latnode_iter_next(ps_latnode_iter_t *itor)
00597 {
00598     return itor->next;
00599 }
00600 
00601 void
00602 ps_latnode_iter_free(ps_latnode_iter_t *itor)
00603 {
00604     /* Do absolutely nothing. */
00605 }
00606 
00607 ps_latnode_t *
00608 ps_latnode_iter_node(ps_latnode_iter_t *itor)
00609 {
00610     return itor;
00611 }
00612 
00613 int
00614 ps_latnode_times(ps_latnode_t *node, int16 *out_fef, int16 *out_lef)
00615 {
00616     if (out_fef) *out_fef = (int16)node->fef;
00617     if (out_lef) *out_lef = (int16)node->lef;
00618     return node->sf;
00619 }
00620 
00621 char const *
00622 ps_latnode_word(ps_lattice_t *dag, ps_latnode_t *node)
00623 {
00624     return dict_wordstr(ps_search_dict(dag->search), node->wid);
00625 }
00626 
00627 char const *
00628 ps_latnode_baseword(ps_lattice_t *dag, ps_latnode_t *node)
00629 {
00630     return dict_wordstr(ps_search_dict(dag->search), node->basewid);
00631 }
00632 
00633 int32
00634 ps_latnode_prob(ps_lattice_t *dag, ps_latnode_t *node,
00635                 ps_latlink_t **out_link)
00636 {
00637     latlink_list_t *links;
00638     int32 bestpost = logmath_get_zero(dag->lmath);
00639 
00640     for (links = node->exits; links; links = links->next) {
00641         int32 post = links->link->alpha + links->link->beta - dag->norm;
00642         if (post > bestpost) {
00643             if (out_link) *out_link = links->link;
00644             bestpost = post;
00645         }
00646     }
00647     return bestpost;
00648 }
00649 
00650 ps_latlink_iter_t *
00651 ps_latnode_exits(ps_latnode_t *node)
00652 {
00653     return node->exits;
00654 }
00655 
00656 ps_latlink_iter_t *
00657 ps_latnode_entries(ps_latnode_t *node)
00658 {
00659     return node->entries;
00660 }
00661 
00662 ps_latlink_iter_t *
00663 ps_latlink_iter_next(ps_latlink_iter_t *itor)
00664 {
00665     return itor->next;
00666 }
00667 
00668 void
00669 ps_latlink_iter_free(ps_latlink_iter_t *itor)
00670 {
00671     /* Do absolutely nothing. */
00672 }
00673 
00674 ps_latlink_t *
00675 ps_latlink_iter_link(ps_latlink_iter_t *itor)
00676 {
00677     return itor->link;
00678 }
00679 
00680 int
00681 ps_latlink_times(ps_latlink_t *link, int16 *out_sf)
00682 {
00683     if (out_sf) {
00684         if (link->from) {
00685             *out_sf = link->from->sf;
00686         }
00687         else {
00688             *out_sf = 0;
00689         }
00690     }
00691     return link->ef;
00692 }
00693 
00694 ps_latnode_t *
00695 ps_latlink_nodes(ps_latlink_t *link, ps_latnode_t **out_src)
00696 {
00697     if (out_src) *out_src = link->from;
00698     return link->to;
00699 }
00700 
00701 char const *
00702 ps_latlink_word(ps_lattice_t *dag, ps_latlink_t *link)
00703 {
00704     if (link->from == NULL)
00705         return NULL;
00706     return dict_wordstr(ps_search_dict(dag->search), link->from->wid);
00707 }
00708 
00709 char const *
00710 ps_latlink_baseword(ps_lattice_t *dag, ps_latlink_t *link)
00711 {
00712     if (link->from == NULL)
00713         return NULL;
00714     return dict_wordstr(ps_search_dict(dag->search), link->from->basewid);
00715 }
00716 
00717 ps_latlink_t *
00718 ps_latlink_pred(ps_latlink_t *link)
00719 {
00720     return link->best_prev;
00721 }
00722 
00723 int32
00724 ps_latlink_prob(ps_lattice_t *dag, ps_latlink_t *link, int32 *out_ascr)
00725 {
00726     int32 post = link->alpha + link->beta - dag->norm;
00727     if (out_ascr) *out_ascr = link->ascr;
00728     return post;
00729 }
00730 
00731 char const *
00732 ps_lattice_hyp(ps_lattice_t *dag, ps_latlink_t *link)
00733 {
00734     ps_latlink_t *l;
00735     size_t len;
00736     char *c;
00737 
00738     /* Backtrace once to get hypothesis length. */
00739     len = 0;
00740     /* FIXME: There may not be a search, but actually there should be a dict. */
00741     if (dict_real_word(ps_search_dict(dag->search), link->to->basewid))
00742         len += strlen(dict_wordstr(ps_search_dict(dag->search), link->to->basewid)) + 1;
00743     for (l = link; l; l = l->best_prev) {
00744         if (dict_real_word(ps_search_dict(dag->search), l->from->basewid))
00745             len += strlen(dict_wordstr(ps_search_dict(dag->search), l->from->basewid)) + 1;
00746     }
00747 
00748     /* Backtrace again to construct hypothesis string. */
00749     ckd_free(dag->hyp_str);
00750     dag->hyp_str = ckd_calloc(1, len+1); /* extra one incase the hyp is empty */
00751     c = dag->hyp_str + len - 1;
00752     if (dict_real_word(ps_search_dict(dag->search), link->to->basewid)) {
00753         len = strlen(dict_wordstr(ps_search_dict(dag->search), link->to->basewid));
00754         c -= len;
00755         memcpy(c, dict_wordstr(ps_search_dict(dag->search), link->to->basewid), len);
00756         if (c > dag->hyp_str) {
00757             --c;
00758             *c = ' ';
00759         }
00760     }
00761     for (l = link; l; l = l->best_prev) {
00762         if (dict_real_word(ps_search_dict(dag->search), l->from->basewid)) {
00763             len = strlen(dict_wordstr(ps_search_dict(dag->search), l->from->basewid));
00764             c -= len;
00765             memcpy(c, dict_wordstr(ps_search_dict(dag->search), l->from->basewid), len);
00766             if (c > dag->hyp_str) {
00767                 --c;
00768                 *c = ' ';
00769             }
00770         }
00771     }
00772 
00773     return dag->hyp_str;
00774 }
00775 
00776 static void
00777 ps_lattice_compute_lscr(ps_seg_t *seg, ps_latlink_t *link, int to)
00778 {
00779     ngram_model_t *lmset;
00780 
00781     /* Language model score is included in the link score for FSG
00782      * search.  FIXME: Of course, this is sort of a hack :( */
00783     if (0 != strcmp(ps_search_name(seg->search), "ngram")) {
00784         seg->lback = 1; /* Unigram... */
00785         seg->lscr = 0;
00786         return;
00787     }
00788         
00789     lmset = ((ngram_search_t *)seg->search)->lmset;
00790 
00791     if (link->best_prev == NULL) {
00792         if (to) /* Sentence has only two words. */
00793             seg->lscr = ngram_bg_score(lmset, link->to->basewid,
00794                                        link->from->basewid, &seg->lback);
00795         else {/* This is the start symbol, its lscr is always 0. */
00796             seg->lscr = 0;
00797             seg->lback = 1;
00798         }
00799     }
00800     else {
00801         /* Find the two predecessor words. */
00802         if (to) {
00803             seg->lscr = ngram_tg_score(lmset, link->to->basewid,
00804                                        link->from->basewid,
00805                                        link->best_prev->from->basewid,
00806                                        &seg->lback);
00807         }
00808         else {
00809             if (link->best_prev->best_prev)
00810                 seg->lscr = ngram_tg_score(lmset, link->from->basewid,
00811                                            link->best_prev->from->basewid,
00812                                            link->best_prev->best_prev->from->basewid,
00813                                            &seg->lback);
00814             else
00815                 seg->lscr = ngram_bg_score(lmset, link->from->basewid,
00816                                            link->best_prev->from->basewid,
00817                                            &seg->lback);
00818         }
00819     }
00820 }
00821 
00822 static void
00823 ps_lattice_link2itor(ps_seg_t *seg, ps_latlink_t *link, int to)
00824 {
00825     dag_seg_t *itor = (dag_seg_t *)seg;
00826     ps_latnode_t *node;
00827 
00828     if (to) {
00829         node = link->to;
00830         seg->ef = node->lef;
00831         seg->prob = 0; /* norm + beta - norm */
00832     }
00833     else {
00834         latlink_list_t *x;
00835         ps_latnode_t *n;
00836         logmath_t *lmath = ps_search_acmod(seg->search)->lmath;
00837 
00838         node = link->from;
00839         seg->ef = link->ef;
00840         seg->prob = link->alpha + link->beta - itor->norm;
00841         /* Sum over all exits for this word and any alternate
00842            pronunciations at the same frame. */
00843         for (n = node; n; n = n->alt) {
00844             for (x = n->exits; x; x = x->next) {
00845                 if (x->link == link)
00846                     continue;
00847                 seg->prob = logmath_add(lmath, seg->prob,
00848                                         x->link->alpha + x->link->beta - itor->norm);
00849             }
00850         }
00851     }
00852     seg->word = dict_wordstr(ps_search_dict(seg->search), node->wid);
00853     seg->sf = node->sf;
00854     seg->ascr = link->ascr;
00855     /* Compute language model score from best predecessors. */
00856     ps_lattice_compute_lscr(seg, link, to);
00857 }
00858 
00859 static void
00860 ps_lattice_seg_free(ps_seg_t *seg)
00861 {
00862     dag_seg_t *itor = (dag_seg_t *)seg;
00863     
00864     ckd_free(itor->links);
00865     ckd_free(itor);
00866 }
00867 
00868 static ps_seg_t *
00869 ps_lattice_seg_next(ps_seg_t *seg)
00870 {
00871     dag_seg_t *itor = (dag_seg_t *)seg;
00872 
00873     ++itor->cur;
00874     if (itor->cur == itor->n_links + 1) {
00875         ps_lattice_seg_free(seg);
00876         return NULL;
00877     }
00878     else if (itor->cur == itor->n_links) {
00879         /* Re-use the last link but with the "to" node. */
00880         ps_lattice_link2itor(seg, itor->links[itor->cur - 1], TRUE);
00881     }
00882     else {
00883         ps_lattice_link2itor(seg, itor->links[itor->cur], FALSE);
00884     }
00885 
00886     return seg;
00887 }
00888 
/*
 * Callback table for lattice-based segment iterators; ps_lattice_seg_iter()
 * installs its address as the "vt" of every iterator it creates.
 */
static ps_segfuncs_t ps_lattice_segfuncs = {
    /* seg_next */ ps_lattice_seg_next,
    /* seg_free */ ps_lattice_seg_free
};
00893 
00894 ps_seg_t *
00895 ps_lattice_seg_iter(ps_lattice_t *dag, ps_latlink_t *link, float32 lwf)
00896 {
00897     dag_seg_t *itor;
00898     ps_latlink_t *l;
00899     int cur;
00900 
00901     /* Calling this an "iterator" is a bit of a misnomer since we have
00902      * to get the entire backtrace in order to produce it.
00903      */
00904     itor = ckd_calloc(1, sizeof(*itor));
00905     itor->base.vt = &ps_lattice_segfuncs;
00906     itor->base.search = dag->search;
00907     itor->base.lwf = lwf;
00908     itor->n_links = 0;
00909     itor->norm = dag->norm;
00910 
00911     for (l = link; l; l = l->best_prev) {
00912         ++itor->n_links;
00913     }
00914     if (itor->n_links == 0) {
00915         ckd_free(itor);
00916         return NULL;
00917     }
00918 
00919     itor->links = ckd_calloc(itor->n_links, sizeof(*itor->links));
00920     cur = itor->n_links - 1;
00921     for (l = link; l; l = l->best_prev) {
00922         itor->links[cur] = l;
00923         --cur;
00924     }
00925 
00926     ps_lattice_link2itor((ps_seg_t *)itor, itor->links[0], FALSE);
00927     return (ps_seg_t *)itor;
00928 }
00929 
00930 latlink_list_t *
00931 latlink_list_new(ps_lattice_t *dag, ps_latlink_t *link, latlink_list_t *next)
00932 {
00933     latlink_list_t *ll;
00934 
00935     ll = listelem_malloc(dag->latlink_list_alloc);
00936     ll->link = link;
00937     ll->next = next;
00938 
00939     return ll;
00940 }
00941 
00942 void
00943 ps_lattice_pushq(ps_lattice_t *dag, ps_latlink_t *link)
00944 {
00945     if (dag->q_head == NULL)
00946         dag->q_head = dag->q_tail = latlink_list_new(dag, link, NULL);
00947     else {
00948         dag->q_tail->next = latlink_list_new(dag, link, NULL);
00949         dag->q_tail = dag->q_tail->next;
00950     }
00951 
00952 }
00953 
00954 ps_latlink_t *
00955 ps_lattice_popq(ps_lattice_t *dag)
00956 {
00957     latlink_list_t *x;
00958     ps_latlink_t *link;
00959 
00960     if (dag->q_head == NULL)
00961         return NULL;
00962     link = dag->q_head->link;
00963     x = dag->q_head->next;
00964     listelem_free(dag->latlink_list_alloc, dag->q_head);
00965     dag->q_head = x;
00966     if (dag->q_head == NULL)
00967         dag->q_tail = NULL;
00968     return link;
00969 }
00970 
00971 void
00972 ps_lattice_delq(ps_lattice_t *dag)
00973 {
00974     while (ps_lattice_popq(dag)) {
00975         /* Do nothing. */
00976     }
00977 }
00978 
00979 ps_latlink_t *
00980 ps_lattice_traverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end)
00981 {
00982     ps_latnode_t *node;
00983     latlink_list_t *x;
00984 
00985     /* Cancel any unfinished traversal. */
00986     ps_lattice_delq(dag);
00987 
00988     /* Initialize node fanin counts and path scores. */
00989     for (node = dag->nodes; node; node = node->next)
00990         node->info.fanin = 0;
00991     for (node = dag->nodes; node; node = node->next) {
00992         for (x = node->exits; x; x = x->next)
00993             (x->link->to->info.fanin)++;
00994     }
00995 
00996     /* Initialize agenda with all exits from start. */
00997     if (start == NULL) start = dag->start;
00998     for (x = start->exits; x; x = x->next)
00999         ps_lattice_pushq(dag, x->link);
01000 
01001     /* Pull the first edge off the queue. */
01002     return ps_lattice_traverse_next(dag, end);
01003 }
01004 
01005 ps_latlink_t *
01006 ps_lattice_traverse_next(ps_lattice_t *dag, ps_latnode_t *end)
01007 {
01008     ps_latlink_t *next;
01009 
01010     next = ps_lattice_popq(dag);
01011     if (next == NULL)
01012         return NULL;
01013 
01014     /* Decrease fanin count for destination node and expand outgoing
01015      * edges if all incoming edges have been seen. */
01016     --next->to->info.fanin;
01017     if (next->to->info.fanin == 0) {
01018         latlink_list_t *x;
01019 
01020         if (end == NULL) end = dag->end;
01021         if (next->to == end) {
01022             /* If we have traversed all links entering the end node,
01023              * clear the queue, causing future calls to this function
01024              * to return NULL. */
01025             ps_lattice_delq(dag);
01026             return next;
01027         }
01028 
01029         /* Extend all outgoing edges. */
01030         for (x = next->to->exits; x; x = x->next)
01031             ps_lattice_pushq(dag, x->link);
01032     }
01033     return next;
01034 }
01035 
01036 ps_latlink_t *
01037 ps_lattice_reverse_edges(ps_lattice_t *dag, ps_latnode_t *start, ps_latnode_t *end)
01038 {
01039     ps_latnode_t *node;
01040     latlink_list_t *x;
01041 
01042     /* Cancel any unfinished traversal. */
01043     ps_lattice_delq(dag);
01044 
01045     /* Initialize node fanout counts and path scores. */
01046     for (node = dag->nodes; node; node = node->next) {
01047         node->info.fanin = 0;
01048         for (x = node->exits; x; x = x->next)
01049             ++node->info.fanin;
01050     }
01051 
01052     /* Initialize agenda with all entries from end. */
01053     if (end == NULL) end = dag->end;
01054     for (x = end->entries; x; x = x->next)
01055         ps_lattice_pushq(dag, x->link);
01056 
01057     /* Pull the first edge off the queue. */
01058     return ps_lattice_reverse_next(dag, start);
01059 }
01060 
01061 ps_latlink_t *
01062 ps_lattice_reverse_next(ps_lattice_t *dag, ps_latnode_t *start)
01063 {
01064     ps_latlink_t *next;
01065 
01066     next = ps_lattice_popq(dag);
01067     if (next == NULL)
01068         return NULL;
01069 
01070     /* Decrease fanout count for source node and expand incoming
01071      * edges if all incoming edges have been seen. */
01072     --next->from->info.fanin;
01073     if (next->from->info.fanin == 0) {
01074         latlink_list_t *x;
01075 
01076         if (start == NULL) start = dag->start;
01077         if (next->from == start) {
01078             /* If we have traversed all links entering the start node,
01079              * clear the queue, causing future calls to this function
01080              * to return NULL. */
01081             ps_lattice_delq(dag);
01082             return next;
01083         }
01084 
01085         /* Extend all outgoing edges. */
01086         for (x = next->from->entries; x; x = x->next)
01087             ps_lattice_pushq(dag, x->link);
01088     }
01089     return next;
01090 }
01091 
/*
 * Find the best score from dag->start to end point of any link and
 * use it to update links further down the path.  This is like
 * single-source shortest path search, except that it is done over
 * edges rather than nodes, which allows us to do exact trigram scoring.
 *
 * Helpfully enough, we get half of the posterior probability
 * calculation for free that way too.  (interesting research topic: is
 * there a reliable Viterbi analogue to word-level Forward-Backward
 * like there is for state-level?  Or, is it just lattice density?)
 *
 * Parameters:
 *   dag    - lattice to search; its links' path_scr, best_prev and
 *            alpha fields are overwritten, and dag->norm is set.
 *   lmset  - language model used for bigram/trigram scores, or NULL to
 *            use acoustic scores only.
 *   lwf    - factor multiplied into every language model score that
 *            contributes to path_scr.
 *   ascale - factor multiplied into acoustic scores when accumulating
 *            alpha and the normalizer.
 *
 * Returns the link entering dag->end with the best path_scr among those
 * whose source node is not a filler word, or NULL if there is no such
 * link.  The best path can then be recovered by following best_prev.
 */
ps_latlink_t *
ps_lattice_bestpath(ps_lattice_t *dag, ngram_model_t *lmset,
                    float32 lwf, float32 ascale)
{
    ps_search_t *search;
    ps_latnode_t *node;
    ps_latlink_t *link;
    ps_latlink_t *bestend;
    latlink_list_t *x;
    logmath_t *lmath;
    int32 bestescr;

    search = dag->search;
    lmath = dag->lmath;

    /* Initialize path scores for all links exiting dag->start, and
     * set all other scores to the minimum.  Also initialize alphas to
     * log-zero. */
    for (node = dag->nodes; node; node = node->next) {
        for (x = node->exits; x; x = x->next) {
            x->link->path_scr = MAX_NEG_INT32;
            x->link->alpha = logmath_get_zero(lmath);
        }
    }
    for (x = dag->start->exits; x; x = x->next) {
        int32 n_used;

        /* Ignore filler words. */
        if (dict_filler_word(ps_search_dict(search), x->link->to->basewid)
            && x->link->to != dag->end)
            continue;

        /* Best path points to dag->start, obviously. */
        if (lmset)
            x->link->path_scr = x->link->ascr +
                ngram_bg_score(lmset, x->link->to->basewid,
                               ps_search_start_wid(search), &n_used) * lwf;
        else
            x->link->path_scr = x->link->ascr;
        x->link->best_prev = NULL;
        /* No predecessors for start links. */
        x->link->alpha = 0;
    }

    /* Traverse the edges in the graph, updating path scores. */
    for (link = ps_lattice_traverse_edges(dag, NULL, NULL);
         link; link = ps_lattice_traverse_next(dag, NULL)) {
        int32 bprob, n_used;

        /* Skip filler nodes in traversal. */
        if (dict_filler_word(ps_search_dict(search), link->from->basewid) && link->from != dag->start)
            continue;
        if (dict_filler_word(ps_search_dict(search), link->to->basewid) && link->to != dag->end)
            continue;

        /* Sanity check, we should not be traversing edges that
         * weren't previously updated, otherwise nasty overflows will result. */
        assert(link->path_scr != MAX_NEG_INT32);

        /* Calculate common bigram probability for all alphas. */
        if (lmset)
            bprob = ngram_ng_prob(lmset,
                                  link->to->basewid,
                                  &link->from->basewid, 1, &n_used);
        else
            bprob = 0;
        /* Add in this link's acoustic score, which was a constant
           factor in previous computations (if any). */
        link->alpha += link->ascr * ascale;

        /* Update scores for all paths exiting link->to. */
        for (x = link->to->exits; x; x = x->next) {
            int32 tscore, score;

            /* Skip links to filler words in update. */
            if (dict_filler_word(ps_search_dict(search), x->link->to->basewid)
                && x->link->to != dag->end)
                continue;

            /* Update alpha with sum of previous alphas. */
            x->link->alpha = logmath_add(lmath, x->link->alpha, link->alpha + bprob);
            /* Calculate trigram score for bestpath. */
            if (lmset)
                tscore = ngram_tg_score(lmset, x->link->to->basewid,
                                        link->to->basewid,
                                        link->from->basewid, &n_used) * lwf;
            else
                tscore = 0;
            /* Update link score with maximum link score. */
            score = link->path_scr + tscore + x->link->ascr;
            if (score BETTER_THAN x->link->path_scr) {
                x->link->path_scr = score;
                x->link->best_prev = link;
            }
        }
    }

    /* Find best link entering final node, and calculate normalizer
     * for posterior probabilities. */
    bestend = NULL;
    bestescr = MAX_NEG_INT32;

    /* Normalizer is the alpha for the imaginary link exiting the
       final node. */
    dag->norm = logmath_get_zero(lmath);
    for (x = dag->end->entries; x; x = x->next) {
        int32 bprob, n_used;

        /* Unlike the loops above, this filler test makes no exception
         * for dag->start. */
        if (dict_filler_word(ps_search_dict(search), x->link->from->basewid))
            continue;
        if (lmset)
            bprob = ngram_ng_prob(lmset,
                                  x->link->to->basewid,
                                  &x->link->from->basewid, 1, &n_used);
        else
            bprob = 0;
        dag->norm = logmath_add(lmath, dag->norm, x->link->alpha + bprob);
        if (x->link->path_scr BETTER_THAN bestescr) {
            bestescr = x->link->path_scr;
            bestend = x->link;
        }
    }
    /* FIXME: floating point... */
    /* The product below is computed in floating point (ascale is
     * float32) and converted back to the type of dag->norm by the
     * compound assignment. */
    dag->norm += (int32)dag->final_node_ascr * ascale;

    E_INFO("Normalizer P(O) = alpha(%s:%d:%d) = %d\n",
           dict_wordstr(dag->search->dict, dag->end->wid),
           dag->end->sf, dag->end->lef,
           dag->norm);
    return bestend;
}
01234 
01235 static int32
01236 ps_lattice_joint(ps_lattice_t *dag, ps_latlink_t *link, float32 ascale)
01237 {
01238     ngram_model_t *lmset;
01239     int32 jprob;
01240 
01241     /* Sort of a hack... */
01242     if (dag->search && 0 == strcmp(ps_search_name(dag->search), "ngram"))
01243         lmset = ((ngram_search_t *)dag->search)->lmset;
01244     else
01245         lmset = NULL;
01246 
01247     jprob = dag->final_node_ascr * ascale;
01248     while (link) {
01249         if (lmset) {
01250             int lback;
01251             /* Compute unscaled language model probability.  Note that
01252                this is actually not the language model probability
01253                that corresponds to this link, but that is okay,
01254                because we are just taking the sum over all links in
01255                the best path. */
01256             jprob += ngram_ng_prob(lmset, link->to->basewid,
01257                                    &link->from->basewid, 1, &lback);
01258         }
01259         /* If there is no language model, we assume that the language
01260            model probability (such as it is) has been included in the
01261            link score. */
01262         jprob += link->ascr * ascale;
01263         link = link->best_prev;
01264     }
01265 
01266     E_INFO("Joint P(O,S) = %d P(S|O) = %d\n", jprob, jprob - dag->norm);
01267     return jprob;
01268 }
01269 
/*
 * Run the backward pass over the lattice and return the posterior
 * log-probability of the best path.
 *
 * Every link's beta is reset to log-zero and then accumulated while
 * walking the edges in reverse order.  The function reads each link's
 * path_scr and dag->norm without setting them; in this file both are
 * written by ps_lattice_bestpath().
 *
 * Parameters:
 *   dag    - lattice to process; the beta field of its links is
 *            overwritten.
 *   lmset  - language model used for bigram probabilities, or NULL.
 *   ascale - factor multiplied into acoustic scores.
 *
 * Returns ps_lattice_joint() of the best link entering dag->end, minus
 * dag->norm, i.e. P(S|O) = P(O,S)/P(O) in the log domain.
 */
int32
ps_lattice_posterior(ps_lattice_t *dag, ngram_model_t *lmset,
                     float32 ascale)
{
    ps_search_t *search;
    logmath_t *lmath;
    ps_latnode_t *node;
    ps_latlink_t *link;
    latlink_list_t *x;
    ps_latlink_t *bestend;
    int32 bestescr;

    search = dag->search;
    lmath = dag->lmath;

    /* Reset all betas to zero. */
    /* ("zero" here is the log-domain zero returned by logmath_get_zero()) */
    for (node = dag->nodes; node; node = node->next) {
        for (x = node->exits; x; x = x->next) {
            x->link->beta = logmath_get_zero(lmath);
        }
    }

    bestend = NULL;
    bestescr = MAX_NEG_INT32;
    /* Accumulate backward probabilities for all links. */
    for (link = ps_lattice_reverse_edges(dag, NULL, NULL);
         link; link = ps_lattice_reverse_next(dag, NULL)) {
        int32 bprob, n_used;

        /* Skip filler nodes in traversal. */
        if (dict_filler_word(ps_search_dict(search), link->from->basewid) && link->from != dag->start)
            continue;
        if (dict_filler_word(ps_search_dict(search), link->to->basewid) && link->to != dag->end)
            continue;

        /* Calculate LM probability. */
        if (lmset)
            bprob = ngram_ng_prob(lmset, link->to->basewid,
                                  &link->from->basewid, 1, &n_used);
        else
            bprob = 0;

        if (link->to == dag->end) {
            /* Track the best path - we will backtrace in order to
               calculate the unscaled joint probability for sentence
               posterior. */
            if (link->path_scr BETTER_THAN bestescr) {
                bestescr = link->path_scr;
                bestend = link;
            }
            /* Imaginary exit link from final node has beta = 1.0 */
            link->beta = bprob + dag->final_node_ascr * ascale;
        }
        else {
            /* Update beta from all outgoing betas. */
            /* Note that the same bprob (computed for this link) is
             * added for every successor. */
            for (x = link->to->exits; x; x = x->next) {
                if (dict_filler_word(ps_search_dict(search), x->link->to->basewid) && x->link->to != dag->end)
                    continue;
                link->beta = logmath_add(lmath, link->beta,
                                         x->link->beta + bprob + x->link->ascr * ascale);
            }
        }
    }

    /* Return P(S|O) = P(O,S)/P(O) */
    return ps_lattice_joint(dag, bestend, ascale) - dag->norm;
}
01337 
01338 
/* Parameters to prune n-best alternatives search */
#define MAX_PATHS       500     /* Max allowed active paths at any time */
/* NOTE(review): MAX_HYP_TRIES is not referenced in the visible part of
 * this file; presumably a cap on hypotheses tried - confirm or remove. */
#define MAX_HYP_TRIES   10000
01342 
01343 /*
01344  * For each node in any path between from and end of utt, find the
01345  * best score from "from".sf to end of utt.  (NOTE: Uses bigram probs;
01346  * this is an estimate of the best score from "from".)  (NOTE #2: yes,
01347  * this is the "heuristic score" used in A* search)
01348  */
01349 static int32
01350 best_rem_score(ps_astar_t *nbest, ps_latnode_t * from)
01351 {
01352     ps_lattice_t *dag;
01353     latlink_list_t *x;
01354     int32 bestscore, score;
01355 
01356     dag = nbest->dag;
01357     if (from->info.rem_score <= 0)
01358         return (from->info.rem_score);
01359 
01360     /* Best score from "from" to end of utt not known; compute from successors */
01361     bestscore = WORST_SCORE;
01362     for (x = from->exits; x; x = x->next) {
01363         int32 n_used;
01364 
01365         score = best_rem_score(nbest, x->link->to);
01366         score += x->link->ascr;
01367         if (nbest->lmset)
01368             score += ngram_bg_score(nbest->lmset, x->link->to->basewid,
01369                                     from->basewid, &n_used) * nbest->lwf;
01370         if (score BETTER_THAN bestscore)
01371             bestscore = score;
01372     }
01373     from->info.rem_score = bestscore;
01374 
01375     return bestscore;
01376 }
01377 
01378 /*
01379  * Insert newpath in sorted (by path score) list of paths.  But if newpath is
01380  * too far down the list, drop it (FIXME: necessary?)
01381  * total_score = path score (newpath) + rem_score to end of utt.
01382  */
01383 static void
01384 path_insert(ps_astar_t *nbest, ps_latpath_t *newpath, int32 total_score)
01385 {
01386     ps_lattice_t *dag;
01387     ps_latpath_t *prev, *p;
01388     int32 i;
01389 
01390     dag = nbest->dag;
01391     prev = NULL;
01392     for (i = 0, p = nbest->path_list; (i < MAX_PATHS) && p; p = p->next, i++) {
01393         if ((p->score + p->node->info.rem_score) < total_score)
01394             break;
01395         prev = p;
01396     }
01397 
01398     /* newpath should be inserted between prev and p */
01399     if (i < MAX_PATHS) {
01400         /* Insert new partial hyp */
01401         newpath->next = p;
01402         if (!prev)
01403             nbest->path_list = newpath;
01404         else
01405             prev->next = newpath;
01406         if (!p)
01407             nbest->path_tail = newpath;
01408 
01409         nbest->n_path++;
01410         nbest->n_hyp_insert++;
01411         nbest->insert_depth += i;
01412     }
01413     else {
01414         /* newpath score too low; reject it and also prune paths beyond MAX_PATHS */
01415         nbest->path_tail = prev;
01416         prev->next = NULL;
01417         nbest->n_path = MAX_PATHS;
01418         listelem_free(nbest->latpath_alloc, newpath);
01419 
01420         nbest->n_hyp_reject++;
01421         for (; p; p = newpath) {
01422             newpath = p->next;
01423             listelem_free(nbest->latpath_alloc, p);
01424             nbest->n_hyp_reject++;
01425         }
01426     }
01427 }
01428 
01429 /* Find all possible extensions to given partial path */
01430 static void
01431 path_extend(ps_astar_t *nbest, ps_latpath_t * path)
01432 {
01433     latlink_list_t *x;
01434     ps_latpath_t *newpath;
01435     int32 total_score, tail_score;
01436     ps_lattice_t *dag;
01437 
01438     dag = nbest->dag;
01439 
01440     /* Consider all successors of path->node */
01441     for (x = path->node->exits; x; x = x->next) {
01442         int32 n_used;
01443 
01444         /* Skip successor if no path from it reaches the final node */
01445         if (x->link->to->info.rem_score <= WORST_SCORE)
01446             continue;
01447 
01448         /* Create path extension and compute exact score for this extension */
01449         newpath = listelem_malloc(nbest->latpath_alloc);
01450         newpath->node = x->link->to;
01451         newpath->parent = path;
01452         newpath->score = path->score + x->link->ascr;
01453         if (nbest->lmset) {
01454             if (path->parent) {
01455                 newpath->score += nbest->lwf
01456                     * ngram_tg_score(nbest->lmset, newpath->node->basewid,
01457                                      path->node->basewid,
01458                                      path->parent->node->basewid, &n_used);
01459             }
01460             else 
01461                 newpath->score += nbest->lwf
01462                     * ngram_bg_score(nbest->lmset, newpath->node->basewid,
01463                                      path->node->basewid, &n_used);
01464         }
01465 
01466         /* Insert new partial path hypothesis into sorted path_list */
01467         nbest->n_hyp_tried++;
01468         total_score = newpath->score + newpath->node->info.rem_score;
01469 
01470         /* First see if hyp would be worse than the worst */
01471         if (nbest->n_path >= MAX_PATHS) {
01472             tail_score =
01473                 nbest->path_tail->score
01474                 + nbest->path_tail->node->info.rem_score;
01475             if (total_score < tail_score) {
01476                 listelem_free(nbest->latpath_alloc, newpath);
01477                 nbest->n_hyp_reject++;
01478                 continue;
01479             }
01480         }
01481 
01482         path_insert(nbest, newpath, total_score);
01483     }
01484 }
01485 
01486 ps_astar_t *
01487 ps_astar_start(ps_lattice_t *dag,
01488                   ngram_model_t *lmset,
01489                   float32 lwf,
01490                   int sf, int ef,
01491                   int w1, int w2)
01492 {
01493     ps_astar_t *nbest;
01494     ps_latnode_t *node;
01495 
01496     nbest = ckd_calloc(1, sizeof(*nbest));
01497     nbest->dag = dag;
01498     nbest->lmset = lmset;
01499     nbest->lwf = lwf;
01500     nbest->sf = sf;
01501     if (ef < 0)
01502         nbest->ef = dag->n_frames + 1;
01503     else
01504         nbest->ef = ef;
01505     nbest->w1 = w1;
01506     nbest->w2 = w2;
01507     nbest->latpath_alloc = listelem_alloc_init(sizeof(ps_latpath_t));
01508 
01509     /* Initialize rem_score (A* heuristic) to default values */
01510     for (node = dag->nodes; node; node = node->next) {
01511         if (node == dag->end)
01512             node->info.rem_score = 0;
01513         else if (node->exits == NULL)
01514             node->info.rem_score = WORST_SCORE;
01515         else
01516             node->info.rem_score = 1;   /* +ve => unknown value */
01517     }
01518 
01519     /* Create initial partial hypotheses list consisting of nodes starting at sf */
01520     nbest->path_list = nbest->path_tail = NULL;
01521     for (node = dag->nodes; node; node = node->next) {
01522         if (node->sf == sf) {
01523             ps_latpath_t *path;
01524             int32 n_used;
01525 
01526             best_rem_score(nbest, node);
01527             path = listelem_malloc(nbest->latpath_alloc);
01528             path->node = node;
01529             path->parent = NULL;
01530             if (nbest->lmset)
01531                 path->score = nbest->lwf *
01532                     (w1 < 0)
01533                     ? ngram_bg_score(nbest->lmset, node->basewid, w2, &n_used)
01534                     : ngram_tg_score(nbest->lmset, node->basewid, w2, w1, &n_used);
01535             else
01536                 path->score = 0;
01537             path_insert(nbest, path, path->score + node->info.rem_score);
01538         }
01539     }
01540 
01541     return nbest;
01542 }
01543 
01544 ps_latpath_t *
01545 ps_astar_next(ps_astar_t *nbest)
01546 {
01547     ps_lattice_t *dag;
01548 
01549     dag = nbest->dag;
01550 
01551     /* Pop the top (best) partial hypothesis */
01552     while ((nbest->top = nbest->path_list) != NULL) {
01553         nbest->path_list = nbest->path_list->next;
01554         if (nbest->top == nbest->path_tail)
01555             nbest->path_tail = NULL;
01556         nbest->n_path--;
01557 
01558         /* Complete hypothesis? */
01559         if ((nbest->top->node->sf >= nbest->ef)
01560             || ((nbest->top->node == dag->end) &&
01561                 (nbest->ef > dag->end->sf))) {
01562             /* FIXME: Verify that it is non-empty.  Also we may want
01563              * to verify that it is actually distinct from other
01564              * paths, since often this is not the case*/
01565             return nbest->top;
01566         }
01567         else {
01568             if (nbest->top->node->fef < nbest->ef)
01569                 path_extend(nbest, nbest->top);
01570         }
01571     }
01572 
01573     /* Did not find any more paths to extend. */
01574     return NULL;
01575 }
01576 
01577 char const *
01578 ps_astar_hyp(ps_astar_t *nbest, ps_latpath_t *path)
01579 {
01580     ps_search_t *search;
01581     ps_latpath_t *p;
01582     size_t len;
01583     char *c;
01584     char *hyp;
01585 
01586     search = nbest->dag->search;
01587 
01588     /* Backtrace once to get hypothesis length. */
01589     len = 0;
01590     for (p = path; p; p = p->parent) {
01591         if (dict_real_word(ps_search_dict(search), p->node->basewid))
01592             len += strlen(dict_wordstr(ps_search_dict(search), p->node->basewid)) + 1;
01593     }
01594 
01595     if (len == 0) {
01596         return NULL;
01597     }
01598 
01599     /* Backtrace again to construct hypothesis string. */
01600     hyp = ckd_calloc(1, len);
01601     c = hyp + len - 1;
01602     for (p = path; p; p = p->parent) {
01603         if (dict_real_word(ps_search_dict(search), p->node->basewid)) {
01604             len = strlen(dict_wordstr(ps_search_dict(search), p->node->basewid));
01605             c -= len;
01606             memcpy(c, dict_wordstr(ps_search_dict(search), p->node->basewid), len);
01607             if (c > hyp) {
01608                 --c;
01609                 *c = ' ';
01610             }
01611         }
01612     }
01613 
01614     nbest->hyps = glist_add_ptr(nbest->hyps, hyp);
01615     return hyp;
01616 }
01617 
01618 static void
01619 ps_astar_node2itor(astar_seg_t *itor)
01620 {
01621     ps_seg_t *seg = (ps_seg_t *)itor;
01622     ps_latnode_t *node;
01623 
01624     assert(itor->cur < itor->n_nodes);
01625     node = itor->nodes[itor->cur];
01626     if (itor->cur == itor->n_nodes - 1)
01627         seg->ef = node->lef;
01628     else
01629         seg->ef = itor->nodes[itor->cur + 1]->sf - 1;
01630     seg->word = dict_wordstr(ps_search_dict(seg->search), node->wid);
01631     seg->sf = node->sf;
01632     seg->prob = 0; /* FIXME: implement forward-backward */
01633 }
01634 
01635 static void
01636 ps_astar_seg_free(ps_seg_t *seg)
01637 {
01638     astar_seg_t *itor = (astar_seg_t *)seg;
01639     ckd_free(itor->nodes);
01640     ckd_free(itor);
01641 }
01642 
01643 static ps_seg_t *
01644 ps_astar_seg_next(ps_seg_t *seg)
01645 {
01646     astar_seg_t *itor = (astar_seg_t *)seg;
01647 
01648     ++itor->cur;
01649     if (itor->cur == itor->n_nodes) {
01650         ps_astar_seg_free(seg);
01651         return NULL;
01652     }
01653     else {
01654         ps_astar_node2itor(itor);
01655     }
01656 
01657     return seg;
01658 }
01659 
/* Function table installed in base.vt of every iterator created by
 * ps_astar_seg_iter(): advancing and freeing are handled by the two
 * static functions defined above. */
static ps_segfuncs_t ps_astar_segfuncs = {
    /* seg_next */ ps_astar_seg_next,
    /* seg_free */ ps_astar_seg_free
};
01664 
01665 ps_seg_t *
01666 ps_astar_seg_iter(ps_astar_t *astar, ps_latpath_t *path, float32 lwf)
01667 {
01668     astar_seg_t *itor;
01669     ps_latpath_t *p;
01670     int cur;
01671 
01672     /* Backtrace and make an iterator, this should look familiar by now. */
01673     itor = ckd_calloc(1, sizeof(*itor));
01674     itor->base.vt = &ps_astar_segfuncs;
01675     itor->base.search = astar->dag->search;
01676     itor->base.lwf = lwf;
01677     itor->n_nodes = itor->cur = 0;
01678     for (p = path; p; p = p->parent) {
01679         ++itor->n_nodes;
01680     }
01681     itor->nodes = ckd_calloc(itor->n_nodes, sizeof(*itor->nodes));
01682     cur = itor->n_nodes - 1;
01683     for (p = path; p; p = p->parent) {
01684         itor->nodes[cur] = p->node;
01685         --cur;
01686     }
01687 
01688     ps_astar_node2itor(itor);
01689     return (ps_seg_t *)itor;
01690 }
01691 
01692 void
01693 ps_astar_finish(ps_astar_t *nbest)
01694 {
01695     gnode_t *gn;
01696 
01697     /* Free all hyps. */
01698     for (gn = nbest->hyps; gn; gn = gnode_next(gn)) {
01699         ckd_free(gnode_ptr(gn));
01700     }
01701     glist_free(nbest->hyps);
01702     /* Free all paths. */
01703     listelem_alloc_free(nbest->latpath_alloc);
01704     /* Free the Henge. */
01705     ckd_free(nbest);
01706 }

Generated on Sat Jan 8 2011 for PocketSphinx by  doxygen 1.7.1