SphinxBase 0.6

src/libsphinxbase/util/logmath.c

00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
00002 /* ====================================================================
00003  * Copyright (c) 1999-2007 Carnegie Mellon University.  All rights
00004  * reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions
00008  * are met:
00009  *
00010  * 1. Redistributions of source code must retain the above copyright
00011  *    notice, this list of conditions and the following disclaimer. 
00012  *
00013  * 2. Redistributions in binary form must reproduce the above copyright
00014  *    notice, this list of conditions and the following disclaimer in
00015  *    the documentation and/or other materials provided with the
00016  *    distribution.
00017  *
00018  * This work was supported in part by funding from the Defense Advanced 
00019  * Research Projects Agency and the National Science Foundation of the 
00020  * United States of America, and the CMU Sphinx Speech Consortium.
00021  *
00022  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
00023  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
00024  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00025  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
00026  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00027  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
00028  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
00029  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
00030  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
00031  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
00032  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00033  *
00034  * ====================================================================
00035  *
00036  */
00037 
00038 #include <math.h>
00039 #include <string.h>
00040 #include <assert.h>
00041 
00042 #include "sphinxbase/logmath.h"
00043 #include "sphinxbase/err.h"
00044 #include "sphinxbase/ckd_alloc.h"
00045 #include "sphinxbase/mmio.h"
00046 #include "sphinxbase/bio.h"
00047 #include "sphinxbase/strfuncs.h"
00048 
00049 struct logmath_s {
00050     logadd_t t;
00051     int refcount;
00052     mmio_file_t *filemap;
00053     float64 base;
00054     float64 log_of_base;
00055     float64 log10_of_base;
00056     float64 inv_log_of_base;
00057     float64 inv_log10_of_base;
00058     int32 zero;
00059 };
00060 
00061 logmath_t *
00062 logmath_init(float64 base, int shift, int use_table)
00063 {
00064     logmath_t *lmath;
00065     uint32 maxyx, i;
00066     float64 byx;
00067     int width;
00068 
00069     /* Check that the base is correct. */
00070     if (base <= 1.0) {
00071         E_ERROR("Base must be greater than 1.0\n");
00072         return NULL;
00073     }
00074     
00075     /* Set up various necessary constants. */
00076     lmath = ckd_calloc(1, sizeof(*lmath));
00077     lmath->refcount = 1;
00078     lmath->base = base;
00079     lmath->log_of_base = log(base);
00080     lmath->log10_of_base = log10(base);
00081     lmath->inv_log_of_base = 1.0/lmath->log_of_base;
00082     lmath->inv_log10_of_base = 1.0/lmath->log10_of_base;
00083     lmath->t.shift = shift;
00084     /* Shift this sufficiently that overflows can be avoided. */
00085     lmath->zero = MAX_NEG_INT32 >> (shift + 2);
00086 
00087     if (!use_table)
00088         return lmath;
00089 
00090     /* Create a logadd table with the appropriate width */
00091     maxyx = (uint32) (log(2.0) / log(base) + 0.5) >> shift;
00092     /* Poor man's log2 */
00093     if (maxyx < 256) width = 1;
00094     else if (maxyx < 65536) width = 2;
00095     else width = 4;
00096 
00097     lmath->t.width = width;
00098     /* Figure out size of add table required. */
00099     byx = 1.0; /* Maximum possible base^{y-x} value - note that this implies that y-x == 0 */
00100     for (i = 0;; ++i) {
00101         float64 lobyx = log(1.0 + byx) * lmath->inv_log_of_base; /* log_{base}(1 + base^{y-x}); */
00102         int32 k = (int32) (lobyx + 0.5 * (1<<shift)) >> shift; /* Round to shift */
00103 
00104         /* base^{y-x} has reached the smallest representable value. */
00105         if (k <= 0)
00106             break;
00107 
00108         /* This table is indexed by -(y-x), so we multiply byx by
00109          * base^{-1} here which is equivalent to subtracting one from
00110          * (y-x). */
00111         byx /= base;
00112     }
00113     i >>= shift;
00114 
00115     /* Never produce a table smaller than 256 entries. */
00116     if (i < 255) i = 255;
00117 
00118     lmath->t.table = ckd_calloc(i+1, width);
00119     lmath->t.table_size = i + 1;
00120     /* Create the add table (see above). */
00121     byx = 1.0;
00122     for (i = 0;; ++i) {
00123         float64 lobyx = log(1.0 + byx) * lmath->inv_log_of_base;
00124         int32 k = (int32) (lobyx + 0.5 * (1<<shift)) >> shift; /* Round to shift */
00125         uint32 prev = 0;
00126 
00127         /* Check any previous value - if there is a shift, we want to
00128          * only store the highest one. */
00129         switch (width) {
00130         case 1:
00131             prev = ((uint8 *)lmath->t.table)[i >> shift];
00132             break;
00133         case 2:
00134             prev = ((uint16 *)lmath->t.table)[i >> shift];
00135             break;
00136         case 4:
00137             prev = ((uint32 *)lmath->t.table)[i >> shift];
00138             break;
00139         }
00140         if (prev == 0) {
00141             switch (width) {
00142             case 1:
00143                 ((uint8 *)lmath->t.table)[i >> shift] = (uint8) k;
00144                 break;
00145             case 2:
00146                 ((uint16 *)lmath->t.table)[i >> shift] = (uint16) k;
00147                 break;
00148             case 4:
00149                 ((uint32 *)lmath->t.table)[i >> shift] = (uint32) k;
00150                 break;
00151             }
00152         }
00153         if (k <= 0)
00154             break;
00155 
00156         /* Decay base^{y-x} exponentially according to base. */
00157         byx /= base;
00158     }
00159 
00160     return lmath;
00161 }
00162 
00163 logmath_t *
00164 logmath_read(const char *file_name)
00165 {
00166     logmath_t *lmath;
00167     char **argname, **argval;
00168     int32 byteswap, i;
00169     int chksum_present, do_mmap;
00170     uint32 chksum;
00171     long pos;
00172     FILE *fp;
00173 
00174     E_INFO("Reading log table file '%s'\n", file_name);
00175     if ((fp = fopen(file_name, "rb")) == NULL) {
00176         E_ERROR("Failed to open log table file '%s' for reading: %s\n", file_name, strerror(errno));
00177         return NULL;
00178     }
00179 
00180     /* Read header, including argument-value info and 32-bit byteorder magic */
00181     if (bio_readhdr(fp, &argname, &argval, &byteswap) < 0) {
00182         E_ERROR("bio_readhdr(%s) failed\n", file_name);
00183         fclose(fp);
00184         return NULL;
00185     }
00186 
00187     lmath = ckd_calloc(1, sizeof(*lmath));
00188     /* Default values. */
00189     lmath->t.shift = 0;
00190     lmath->t.width = 2;
00191     lmath->base = 1.0001;
00192 
00193     /* Parse argument-value list */
00194     chksum_present = 0;
00195     for (i = 0; argname[i]; i++) {
00196         if (strcmp(argname[i], "version") == 0) {
00197         }
00198         else if (strcmp(argname[i], "chksum0") == 0) {
00199             if (strcmp(argval[i], "yes") == 0)
00200                 chksum_present = 1;
00201         }
00202         else if (strcmp(argname[i], "width") == 0) {
00203             lmath->t.width = atoi(argval[i]);
00204         }
00205         else if (strcmp(argname[i], "shift") == 0) {
00206             lmath->t.shift = atoi(argval[i]);
00207         }
00208         else if (strcmp(argname[i], "logbase") == 0) {
00209             lmath->base = atof_c(argval[i]);
00210         }
00211     }
00212     bio_hdrarg_free(argname, argval);
00213     chksum = 0;
00214 
00215     /* Set up various necessary constants. */
00216     lmath->log_of_base = log(lmath->base);
00217     lmath->log10_of_base = log10(lmath->base);
00218     lmath->inv_log_of_base = 1.0/lmath->log_of_base;
00219     lmath->inv_log10_of_base = 1.0/lmath->log10_of_base;
00220     /* Shift this sufficiently that overflows can be avoided. */
00221     lmath->zero = MAX_NEG_INT32 >> (lmath->t.shift + 2);
00222 
00223     /* #Values to follow */
00224     if (bio_fread(&lmath->t.table_size, sizeof(int32), 1, fp, byteswap, &chksum) != 1) {
00225         E_ERROR("fread(%s) (total #values) failed\n", file_name);
00226         goto error_out;
00227     }
00228 
00229     /* Check alignment constraints for memory mapping */
00230     do_mmap = 1;
00231     pos = ftell(fp);
00232     if (pos & ((long)lmath->t.width - 1)) {
00233         E_WARN("%s: Data start %ld is not aligned on %d-byte boundary, will not memory map\n",
00234                   file_name, pos, lmath->t.width);
00235         do_mmap = 0;
00236     }
00237     /* Check byte order for memory mapping */
00238     if (byteswap) {
00239         E_WARN("%s: Data is wrong-endian, will not memory map\n", file_name);
00240         do_mmap = 0;
00241     }
00242 
00243     if (do_mmap) {
00244         lmath->filemap = mmio_file_read(file_name);
00245         lmath->t.table = (char *)mmio_file_ptr(lmath->filemap) + pos;
00246     }
00247     else {
00248         lmath->t.table = ckd_calloc(lmath->t.table_size, lmath->t.width);
00249         if (bio_fread(lmath->t.table, lmath->t.width, lmath->t.table_size,
00250                       fp, byteswap, &chksum) != lmath->t.table_size) {
00251             E_ERROR("fread(%s) (%d x %d bytes) failed\n",
00252                     file_name, lmath->t.table_size, lmath->t.width);
00253             goto error_out;
00254         }
00255         if (chksum_present)
00256             bio_verify_chksum(fp, byteswap, chksum);
00257 
00258         if (fread(&i, 1, 1, fp) == 1) {
00259             E_ERROR("%s: More data than expected\n", file_name);
00260             goto error_out;
00261         }
00262     }
00263     fclose(fp);
00264 
00265     return lmath;
00266 error_out:
00267     logmath_free(lmath);
00268     return NULL;
00269 }
00270 
00271 int32
00272 logmath_write(logmath_t *lmath, const char *file_name)
00273 {
00274     FILE *fp;
00275     long pos;
00276     uint32 chksum;
00277 
00278     if (lmath->t.table == NULL) {
00279         E_ERROR("No log table to write!\n");
00280         return -1;
00281     }
00282 
00283     E_INFO("Writing log table file '%s'\n", file_name);
00284     if ((fp = fopen(file_name, "wb")) == NULL) {
00285         E_ERROR("Failed to open logtable file '%s' for writing: %s\n", file_name, strerror(errno));
00286         return -1;
00287     }
00288 
00289     /* For whatever reason, we have to do this manually at the
00290      * moment. */
00291     fprintf(fp, "s3\nversion 1.0\nchksum0 yes\n");
00292     fprintf(fp, "width %d\n", lmath->t.width);
00293     fprintf(fp, "shift %d\n", lmath->t.shift);
00294     fprintf(fp, "logbase %f\n", lmath->base);
00295     /* Pad it out to ensure alignment. */
00296     pos = ftell(fp) + strlen("endhdr\n");
00297     if (pos & ((long)lmath->t.width - 1)) {
00298         size_t align = lmath->t.width - (pos & ((long)lmath->t.width - 1));
00299         assert(lmath->t.width <= 8);
00300         fwrite("        " /* 8 spaces */, 1, align, fp);
00301     }
00302     fprintf(fp, "endhdr\n");
00303 
00304     /* Now write the binary data. */
00305     chksum = (uint32)BYTE_ORDER_MAGIC;
00306     fwrite(&chksum, sizeof(uint32), 1, fp);
00307     chksum = 0;
00308     /* #Values to follow */
00309     if (bio_fwrite(&lmath->t.table_size, sizeof(uint32),
00310                    1, fp, 0, &chksum) != 1) {
00311         E_ERROR("fwrite(%s) (total #values) failed\n", file_name);
00312         goto error_out;
00313     }
00314 
00315     if (bio_fwrite(lmath->t.table, lmath->t.width, lmath->t.table_size,
00316                    fp, 0, &chksum) != lmath->t.table_size) {
00317         E_ERROR("fwrite(%s) (%d x %d bytes) failed\n",
00318                 file_name, lmath->t.table_size, lmath->t.width);
00319         goto error_out;
00320     }
00321     if (bio_fwrite(&chksum, sizeof(uint32), 1, fp, 0, NULL) != 1) {
00322         E_ERROR("fwrite(%s) checksum failed\n", file_name);
00323         goto error_out;
00324     }
00325 
00326     fclose(fp);
00327     return 0;
00328 
00329 error_out:
00330     fclose(fp);
00331     return -1;
00332 }
00333 
00334 logmath_t *
00335 logmath_retain(logmath_t *lmath)
00336 {
00337     ++lmath->refcount;
00338     return lmath;
00339 }
00340 
00341 int
00342 logmath_free(logmath_t *lmath)
00343 {
00344     if (lmath == NULL)
00345         return 0;
00346     if (--lmath->refcount > 0)
00347         return lmath->refcount;
00348     if (lmath->filemap)
00349         mmio_file_unmap(lmath->filemap);
00350     else
00351         ckd_free(lmath->t.table);
00352     ckd_free(lmath);
00353     return 0;
00354 }
00355 
00356 int32
00357 logmath_get_table_shape(logmath_t *lmath, uint32 *out_size,
00358                         uint32 *out_width, uint32 *out_shift)
00359 {
00360     if (out_size) *out_size = lmath->t.table_size;
00361     if (out_width) *out_width = lmath->t.width;
00362     if (out_shift) *out_shift = lmath->t.shift;
00363 
00364     return lmath->t.table_size * lmath->t.width;
00365 }
00366 
00367 float64
00368 logmath_get_base(logmath_t *lmath)
00369 {
00370     return lmath->base;
00371 }
00372 
00373 int
00374 logmath_get_zero(logmath_t *lmath)
00375 {
00376     return lmath->zero;
00377 }
00378 
00379 int
00380 logmath_get_width(logmath_t *lmath)
00381 {
00382     return lmath->t.width;
00383 }
00384 
00385 int
00386 logmath_get_shift(logmath_t *lmath)
00387 {
00388     return lmath->t.shift;
00389 }
00390 
00391 int
00392 logmath_add(logmath_t *lmath, int logb_x, int logb_y)
00393 {
00394     logadd_t *t = LOGMATH_TABLE(lmath);
00395     int d, r;
00396 
00397     /* handle 0 + x = x case. */
00398     if (logb_x <= lmath->zero)
00399         return logb_y;
00400     if (logb_y <= lmath->zero)
00401         return logb_x;
00402 
00403     if (t->table == NULL)
00404         return logmath_add_exact(lmath, logb_x, logb_y);
00405 
00406     /* d must be positive, obviously. */
00407     if (logb_x > logb_y) {
00408         d = (logb_x - logb_y);
00409         r = logb_x;
00410     }
00411     else {
00412         d = (logb_y - logb_x);
00413         r = logb_y;
00414     }
00415 
00416     if (d < 0) {
00417         /* Some kind of overflow has occurred, fail gracefully. */
00418         return r;
00419     }
00420     if ((size_t)d >= t->table_size) {
00421         /* If this happens, it's not actually an error, because the
00422          * last entry in the logadd table is guaranteed to be zero.
00423          * Therefore we just return the larger of the two values. */
00424         return r;
00425     }
00426 
00427     switch (t->width) {
00428     case 1:
00429         return r + (((uint8 *)t->table)[d]);
00430     case 2:
00431         return r + (((uint16 *)t->table)[d]);
00432     case 4:
00433         return r + (((uint32 *)t->table)[d]);
00434     }
00435     return r;
00436 }
00437 
00438 int
00439 logmath_add_exact(logmath_t *lmath, int logb_p, int logb_q)
00440 {
00441     return logmath_log(lmath,
00442                        logmath_exp(lmath, logb_p)
00443                        + logmath_exp(lmath, logb_q));
00444 }
00445 
00446 int
00447 logmath_log(logmath_t *lmath, float64 p)
00448 {
00449     if (p <= 0) {
00450         return lmath->zero;
00451     }
00452     return (int)(log(p) * lmath->inv_log_of_base) >> lmath->t.shift;
00453 }
00454 
00455 float64
00456 logmath_exp(logmath_t *lmath, int logb_p)
00457 {
00458     return pow(lmath->base, (float64)(logb_p << lmath->t.shift));
00459 }
00460 
00461 int
00462 logmath_ln_to_log(logmath_t *lmath, float64 log_p)
00463 {
00464     return (int)(log_p * lmath->inv_log_of_base) >> lmath->t.shift;
00465 }
00466 
00467 float64
00468 logmath_log_to_ln(logmath_t *lmath, int logb_p)
00469 {
00470     return (float64)(logb_p << lmath->t.shift) * lmath->log_of_base;
00471 }
00472 
00473 int
00474 logmath_log10_to_log(logmath_t *lmath, float64 log_p)
00475 {
00476     return (int)(log_p * lmath->inv_log10_of_base) >> lmath->t.shift;
00477 }
00478 
00479 float64
00480 logmath_log_to_log10(logmath_t *lmath, int logb_p)
00481 {
00482     return (float64)(logb_p << lmath->t.shift) * lmath->log10_of_base;
00483 }