SphinxBase 0.6

src/libsphinxbase/util/huff_code.c

00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
00002 /* ====================================================================
00003  * Copyright (c) 2009 Carnegie Mellon University.  All rights
00004  * reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions
00008  * are met:
00009  *
00010  * 1. Redistributions of source code must retain the above copyright
00011  *    notice, this list of conditions and the following disclaimer. 
00012  *
00013  * 2. Redistributions in binary form must reproduce the above copyright
00014  *    notice, this list of conditions and the following disclaimer in
00015  *    the documentation and/or other materials provided with the
00016  *    distribution.
00017  *
00018  * This work was supported in part by funding from the Defense Advanced 
00019  * Research Projects Agency and the National Science Foundation of the 
00020  * United States of America, and the CMU Sphinx Speech Consortium.
00021  *
00022  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
00023  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
00024  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00025  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
00026  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00027  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
00028  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
00029  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
00030  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
00031  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
00032  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00033  *
00034  * ====================================================================
00035  *
00036  */
00037 
00038 #include <string.h>
00039 
00040 #include "sphinxbase/huff_code.h"
00041 #include "sphinxbase/ckd_alloc.h"
00042 #include "sphinxbase/hash_table.h"
00043 #include "sphinxbase/byteorder.h"
00044 #include "sphinxbase/heap.h"
00045 #include "sphinxbase/pio.h"
00046 #include "sphinxbase/err.h"
00047 
00048 typedef struct huff_node_s {
00049     int nbits;
00050     struct huff_node_s *l;
00051     union {
00052         int32 ival;
00053         char *sval;
00054         struct huff_node_s *r;
00055     } r;
00056 } huff_node_t;
00057 
00058 typedef struct huff_codeword_s {
00059     union {
00060         int32 ival;
00061         char *sval;
00062     } r;
00063     uint32 nbits, codeword;
00064 } huff_codeword_t;
00065 
/* Symbol type tags.  The tag is stored as a single byte in the
 * serialized code, so the numeric values are spelled out. */
enum {
    HUFF_CODE_INT = 0,  /* Symbols are 32-bit integers. */
    HUFF_CODE_STR = 1   /* Symbols are strings. */
};
00070 
00071 struct huff_code_s {
00072     int16 refcount;
00073     uint8 maxbits;
00074     uint8 type;
00075     uint32 *firstcode;
00076     uint32 *numl;
00077     huff_codeword_t **syms;
00078     hash_table_t *codewords;
00079     FILE *fh;
00080     bit_encode_t *be;
00081     int boff;
00082 };
00083 
00084 static huff_node_t *
00085 huff_node_new_int(int32 val)
00086 {
00087     huff_node_t *hn = ckd_calloc(1, sizeof(*hn));
00088     hn->r.ival = val;
00089     return hn;
00090 }
00091 
00092 static huff_node_t *
00093 huff_node_new_str(char const *val)
00094 {
00095     huff_node_t *hn = ckd_calloc(1, sizeof(*hn));
00096     hn->r.sval = ckd_salloc(val);
00097     return hn;
00098 }
00099 
00100 static huff_node_t *
00101 huff_node_new_parent(huff_node_t *l, huff_node_t *r)
00102 {
00103     huff_node_t *hn = ckd_calloc(1, sizeof(*hn));
00104     hn->l = l;
00105     hn->r.r = r;
00106     /* Propagate maximum bit length. */
00107     if (r->nbits > l->nbits)
00108         hn->nbits = r->nbits + 1;
00109     else
00110         hn->nbits = l->nbits + 1;
00111     return hn;
00112 }
00113 
00114 static void
00115 huff_node_free_int(huff_node_t *root)
00116 {
00117     if (root->l) {
00118         huff_node_free_int(root->l);
00119         huff_node_free_int(root->r.r);
00120     }
00121     ckd_free(root);
00122 }
00123 
00124 static void
00125 huff_node_free_str(huff_node_t *root, int freestr)
00126 {
00127     if (root->l) {
00128         huff_node_free_str(root->l, freestr);
00129         huff_node_free_str(root->r.r, freestr);
00130     }
00131     else {
00132         if (freestr)
00133             ckd_free(root->r.sval);
00134     }
00135     ckd_free(root);
00136 }
00137 
00138 static huff_node_t *
00139 huff_code_build_tree(heap_t *q)
00140 {
00141     huff_node_t *root = NULL;
00142     int32 rf;
00143 
00144     while (heap_size(q) > 1) {
00145         huff_node_t *l, *r, *p;
00146         int32 lf, rf;
00147 
00148         heap_pop(q, (void *)&l, &lf);
00149         heap_pop(q, (void *)&r, &rf);
00150         p = huff_node_new_parent(l, r);
00151         heap_insert(q, p, lf + rf);
00152     }
00153     heap_pop(q, (void **)&root, &rf);
00154     return root;
00155 }
00156 
00157 static void
00158 huff_code_canonicalize(huff_code_t *hc, huff_node_t *root)
00159 {
00160     glist_t agenda;
00161     uint32 *nextcode;
00162     int i, ncw;
00163 
00164     hc->firstcode = ckd_calloc(hc->maxbits+1, sizeof(*hc->firstcode));
00165     hc->syms = ckd_calloc(hc->maxbits+1, sizeof(*hc->syms));
00166     hc->numl = ckd_calloc(hc->maxbits+1, sizeof(*nextcode));
00167     nextcode = ckd_calloc(hc->maxbits+1, sizeof(*nextcode));
00168 
00169     /* Traverse the tree, annotating it with the actual bit
00170      * lengths, and histogramming them in numl. */
00171     root->nbits = 0;
00172     ncw = 0;
00173     agenda = glist_add_ptr(NULL, root);
00174     while (agenda) {
00175         huff_node_t *node = gnode_ptr(agenda);
00176         agenda = gnode_free(agenda, NULL);
00177         if (node->l) {
00178             node->l->nbits = node->nbits + 1;
00179             agenda = glist_add_ptr(agenda, node->l);
00180             node->r.r->nbits = node->nbits + 1;
00181             agenda = glist_add_ptr(agenda, node->r.r);
00182         }
00183         else {
00184             hc->numl[node->nbits]++;
00185             ncw++;
00186         }
00187     }
00188     /* Create starting codes and symbol tables for each bit length. */
00189     hc->syms[hc->maxbits] = ckd_calloc(hc->numl[hc->maxbits], sizeof(**hc->syms));
00190     for (i = hc->maxbits - 1; i > 0; --i) {
00191         hc->firstcode[i] = (hc->firstcode[i+1] + hc->numl[i+1]) / 2;
00192         hc->syms[i] = ckd_calloc(hc->numl[i], sizeof(**hc->syms));
00193     }
00194     memcpy(nextcode, hc->firstcode, (hc->maxbits + 1) * sizeof(*nextcode));
00195     /* Traverse the tree again to produce the codebook itself. */
00196     hc->codewords = hash_table_new(ncw, HASH_CASE_YES);
00197     agenda = glist_add_ptr(NULL, root);
00198     while (agenda) {
00199         huff_node_t *node = gnode_ptr(agenda);
00200         agenda = gnode_free(agenda, NULL);
00201         if (node->l) {
00202             agenda = glist_add_ptr(agenda, node->l);
00203             agenda = glist_add_ptr(agenda, node->r.r);
00204         }
00205         else {
00206             /* Initialize codebook entry, which also retains symbol pointer. */
00207             huff_codeword_t *cw;
00208             uint32 codeword = nextcode[node->nbits] & ((1 << node->nbits) - 1);
00209             cw = hc->syms[node->nbits] + (codeword - hc->firstcode[node->nbits]);
00210             cw->nbits = node->nbits;
00211             cw->r.sval = node->r.sval; /* Will copy ints too... */
00212             cw->codeword = codeword;
00213             if (hc->type == HUFF_CODE_INT) {
00214                 hash_table_enter_bkey(hc->codewords,
00215                                       (char const *)&cw->r.ival,
00216                                       sizeof(cw->r.ival),
00217                                       (void *)cw);
00218             }
00219             else {
00220                 hash_table_enter(hc->codewords, cw->r.sval, (void *)cw);
00221             }
00222             ++nextcode[node->nbits];
00223         }
00224     }
00225     ckd_free(nextcode);
00226 }
00227 
00228 huff_code_t *
00229 huff_code_build_int(int32 const *values, int32 const *frequencies, int nvals)
00230 {
00231     huff_code_t *hc;
00232     huff_node_t *root;
00233     heap_t *q;
00234     int i;
00235 
00236     hc = ckd_calloc(1, sizeof(*hc));
00237     hc->refcount = 1;
00238     hc->type = HUFF_CODE_INT;
00239 
00240     /* Initialize the heap with nodes for each symbol. */
00241     q = heap_new();
00242     for (i = 0; i < nvals; ++i) {
00243         heap_insert(q,
00244                     huff_node_new_int(values[i]),
00245                     frequencies[i]);
00246     }
00247 
00248     /* Now build the tree, which gives us codeword lengths. */
00249     root = huff_code_build_tree(q);
00250     heap_destroy(q);
00251     if (root == NULL || root->nbits > 32) {
00252         E_ERROR("Huffman trees currently limited to 32 bits\n");
00253         huff_node_free_int(root);
00254         huff_code_free(hc);
00255         return NULL;
00256     }
00257 
00258     /* Build a canonical codebook. */
00259     hc->maxbits = root->nbits;
00260     huff_code_canonicalize(hc, root);
00261 
00262     /* Tree no longer needed. */
00263     huff_node_free_int(root);
00264 
00265     return hc;
00266 }
00267 
00268 huff_code_t *
00269 huff_code_build_str(char * const *values, int32 const *frequencies, int nvals)
00270 {
00271     huff_code_t *hc;
00272     huff_node_t *root;
00273     heap_t *q;
00274     int i;
00275 
00276     hc = ckd_calloc(1, sizeof(*hc));
00277     hc->refcount = 1;
00278     hc->type = HUFF_CODE_STR;
00279 
00280     /* Initialize the heap with nodes for each symbol. */
00281     q = heap_new();
00282     for (i = 0; i < nvals; ++i) {
00283         heap_insert(q,
00284                     huff_node_new_str(values[i]),
00285                     frequencies[i]);
00286     }
00287 
00288     /* Now build the tree, which gives us codeword lengths. */
00289     root = huff_code_build_tree(q);
00290     heap_destroy(q);
00291     if (root == NULL || root->nbits > 32) {
00292         E_ERROR("Huffman trees currently limited to 32 bits\n");
00293         huff_node_free_str(root, TRUE);
00294         huff_code_free(hc);
00295         return NULL;
00296     }
00297 
00298     /* Build a canonical codebook. */
00299     hc->maxbits = root->nbits;
00300     huff_code_canonicalize(hc, root);
00301 
00302     /* Tree no longer needed (note we retain pointers to its strings). */
00303     huff_node_free_str(root, FALSE);
00304 
00305     return hc;
00306 }
00307 
00308 huff_code_t *
00309 huff_code_read(FILE *infh)
00310 {
00311     huff_code_t *hc;
00312     int i, j;
00313 
00314     hc = ckd_calloc(1, sizeof(*hc));
00315     hc->refcount = 1;
00316 
00317     hc->maxbits = fgetc(infh);
00318     hc->type = fgetc(infh);
00319 
00320     /* Two bytes of padding. */
00321     fgetc(infh);
00322     fgetc(infh);
00323 
00324     /* Allocate stuff. */
00325     hc->firstcode = ckd_calloc(hc->maxbits + 1, sizeof(*hc->firstcode));
00326     hc->numl = ckd_calloc(hc->maxbits + 1, sizeof(*hc->numl));
00327     hc->syms = ckd_calloc(hc->maxbits + 1, sizeof(*hc->syms));
00328 
00329     /* Read the symbol tables. */
00330     hc->codewords = hash_table_new(hc->maxbits, HASH_CASE_YES);
00331     for (i = 1; i <= hc->maxbits; ++i) {
00332         if (fread(&hc->firstcode[i], 4, 1, infh) != 1)
00333             goto error_out;
00334         SWAP_BE_32(&hc->firstcode[i]);
00335         if (fread(&hc->numl[i], 4, 1, infh) != 1)
00336             goto error_out;
00337         SWAP_BE_32(&hc->numl[i]);
00338         hc->syms[i] = ckd_calloc(hc->numl[i], sizeof(**hc->syms));
00339         for (j = 0; j < hc->numl[i]; ++j) {
00340             huff_codeword_t *cw = &hc->syms[i][j];
00341             cw->nbits = i;
00342             cw->codeword = hc->firstcode[i] + j;
00343             if (hc->type == HUFF_CODE_INT) {
00344                 if (fread(&cw->r.ival, 4, 1, infh) != 1)
00345                     goto error_out;
00346                 SWAP_BE_32(&cw->r.ival);
00347                 hash_table_enter_bkey(hc->codewords,
00348                                       (char const *)&cw->r.ival,
00349                                       sizeof(cw->r.ival),
00350                                       (void *)cw);
00351             }
00352             else {
00353                 size_t len;
00354                 cw->r.sval = fread_line(infh, &len);
00355                 cw->r.sval[len-1] = '\0';
00356                 hash_table_enter(hc->codewords, cw->r.sval, (void *)cw);
00357             }
00358         }
00359     }
00360 
00361     return hc;
00362 error_out:
00363     huff_code_free(hc);
00364     return NULL;
00365 }
00366 
00367 int
00368 huff_code_write(huff_code_t *hc, FILE *outfh)
00369 {
00370     int i, j;
00371 
00372     /* Maximum codeword length */
00373     fputc(hc->maxbits, outfh);
00374     /* Symbol type */
00375     fputc(hc->type, outfh);
00376     /* Two extra bytes (for future use and alignment) */
00377     fputc(0, outfh);
00378     fputc(0, outfh);
00379     /* For each codeword length: */
00380     for (i = 1; i <= hc->maxbits; ++i) {
00381         uint32 val;
00382 
00383         /* Starting code, number of codes. */
00384         val = hc->firstcode[i];
00385         /* Canonically big-endian (like the data itself) */
00386         SWAP_BE_32(&val);
00387         fwrite(&val, 4, 1, outfh);
00388         val = hc->numl[i];
00389         SWAP_BE_32(&val);
00390         fwrite(&val, 4, 1, outfh);
00391 
00392         /* Symbols for each code (FIXME: Should compress these too) */
00393         for (j = 0; j < hc->numl[i]; ++j) {
00394             if (hc->type == HUFF_CODE_INT) {
00395                 int32 val = hc->syms[i][j].r.ival;
00396                 SWAP_BE_32(&val);
00397                 fwrite(&val, 4, 1, outfh);
00398             }
00399             else {
00400                 /* Write them all separated by newlines, so that
00401                  * fgets() will read them for us. */
00402                 fprintf(outfh, "%s\n", hc->syms[i][j].r.sval);
00403             }
00404         }
00405     }
00406     return 0;
00407 }
00408 
00409 int
00410 huff_code_dump_codebits(FILE *dumpfh, uint32 nbits, uint32 codeword)
00411 {
00412     uint32 i;
00413 
00414     for (i = 0; i < nbits; ++i)
00415         fputc((codeword & (1<<(nbits-i-1))) ? '1' : '0', dumpfh);
00416     return 0;
00417 }
00418 
00419 int
00420 huff_code_dump(huff_code_t *hc, FILE *dumpfh)
00421 {
00422     int i, j;
00423 
00424     /* Print out all codewords. */
00425     fprintf(dumpfh, "Maximum codeword length: %d\n", hc->maxbits);
00426     fprintf(dumpfh, "Symbols are %s\n", (hc->type == HUFF_CODE_STR) ? "strings" : "ints");
00427     fprintf(dumpfh, "Codewords:\n");
00428     for (i = 1; i <= hc->maxbits; ++i) {
00429         for (j = 0; j < hc->numl[i]; ++j) {
00430             if (hc->type == HUFF_CODE_STR)
00431                 fprintf(dumpfh, "%-30s", hc->syms[i][j].r.sval);
00432             else
00433                 fprintf(dumpfh, "%-30d", hc->syms[i][j].r.ival);
00434             huff_code_dump_codebits(dumpfh, hc->syms[i][j].nbits,
00435                                     hc->syms[i][j].codeword);
00436             fprintf(dumpfh, "\n");
00437         }
00438     }
00439     return 0;
00440 }
00441 
00442 huff_code_t *
00443 huff_code_retain(huff_code_t *hc)
00444 {
00445     ++hc->refcount;
00446     return hc;
00447 }
00448 
00449 int
00450 huff_code_free(huff_code_t *hc)
00451 {
00452     int i;
00453 
00454     if (hc == NULL)
00455         return 0;
00456     if (--hc->refcount > 0)
00457         return hc->refcount;
00458     for (i = 0; i <= hc->maxbits; ++i) {
00459         int j;
00460         for (j = 0; j < hc->numl[i]; ++j) {
00461             if (hc->type == HUFF_CODE_STR)
00462                 ckd_free(hc->syms[i][j].r.sval);
00463         }
00464         ckd_free(hc->syms[i]);
00465     }
00466     ckd_free(hc->firstcode);
00467     ckd_free(hc->numl);
00468     ckd_free(hc->syms);
00469     hash_table_free(hc->codewords);
00470     ckd_free(hc);
00471     return 0;
00472 }
00473 
00474 FILE *
00475 huff_code_attach(huff_code_t *hc, FILE *fh, char const *mode)
00476 {
00477     FILE *oldfh = huff_code_detach(hc);
00478 
00479     hc->fh = fh;
00480     if (mode[0] == 'w')
00481         hc->be = bit_encode_attach(hc->fh);
00482     return oldfh;
00483 }
00484 
00485 FILE *
00486 huff_code_detach(huff_code_t *hc)
00487 {
00488     FILE *oldfh = hc->fh;
00489         
00490     if (hc->be) {
00491         bit_encode_flush(hc->be);
00492         bit_encode_free(hc->be);
00493         hc->be = NULL;
00494     }
00495     hc->fh = NULL;
00496     return oldfh;
00497 }
00498 
00499 int
00500 huff_code_encode_int(huff_code_t *hc, int32 sym, uint32 *outcw)
00501 {
00502     huff_codeword_t *cw;
00503 
00504     if (hash_table_lookup_bkey(hc->codewords,
00505                                (char const *)&sym,
00506                                sizeof(sym),
00507                                (void **)&cw) < 0)
00508         return 0;
00509     if (hc->be)
00510         bit_encode_write_cw(hc->be, cw->codeword, cw->nbits);
00511     if (outcw) *outcw = cw->codeword;
00512     return cw->nbits;
00513 }
00514 
00515 int
00516 huff_code_encode_str(huff_code_t *hc, char const *sym, uint32 *outcw)
00517 {
00518     huff_codeword_t *cw;
00519 
00520     if (hash_table_lookup(hc->codewords,
00521                           sym,
00522                           (void **)&cw) < 0)
00523         return 0;
00524     if (hc->be)
00525         bit_encode_write_cw(hc->be, cw->codeword, cw->nbits);
00526     if (outcw) *outcw = cw->codeword;
00527     return cw->nbits;
00528 }
00529 
00530 static huff_codeword_t *
00531 huff_code_decode_data(huff_code_t *hc, char const **inout_data,
00532                       size_t *inout_data_len, int *inout_offset)
00533 {
00534     char const *data = *inout_data;
00535     char const *end = data + *inout_data_len;
00536     int offset = *inout_offset;
00537     uint32 cw;
00538     int cwlen;
00539     int byte;
00540 
00541     if (data == end)
00542         return NULL;
00543     byte = *data++;
00544     cw = !!(byte & (1 << (7-offset++)));
00545     cwlen = 1;
00546     /* printf("%.*x ", cwlen, cw); */
00547     while (cwlen <= hc->maxbits && cw < hc->firstcode[cwlen]) {
00548         ++cwlen;
00549         cw <<= 1;
00550         if (offset > 7) {
00551             if (data == end)
00552                 return NULL;
00553             byte = *data++;
00554             offset = 0;
00555         }
00556         cw |= !!(byte & (1 << (7-offset++)));
00557         /* printf("%.*x ", cwlen, cw); */
00558     }
00559     if (cwlen > hc->maxbits) /* FAIL: invalid data */
00560         return NULL;
00561 
00562     /* Put the last byte back if there are bits left over. */
00563     if (offset < 8)
00564         --data;
00565     else
00566         offset = 0;
00567 
00568     /* printf("%.*x\n", cwlen, cw); */
00569     *inout_data_len = end - data;
00570     *inout_data = data;
00571     *inout_offset = offset;
00572     return hc->syms[cwlen] + (cw - hc->firstcode[cwlen]);
00573 }
00574 
00575 static huff_codeword_t *
00576 huff_code_decode_fh(huff_code_t *hc)
00577 {
00578     uint32 cw;
00579     int cwlen;
00580     int byte;
00581 
00582     if ((byte = fgetc(hc->fh)) == EOF)
00583         return NULL;
00584     cw = !!(byte & (1 << (7-hc->boff++)));
00585     cwlen = 1;
00586     /* printf("%.*x ", cwlen, cw); */
00587     while (cwlen <= hc->maxbits && cw < hc->firstcode[cwlen]) {
00588         ++cwlen;
00589         cw <<= 1;
00590         if (hc->boff > 7) {
00591             if ((byte = fgetc(hc->fh)) == EOF)
00592                 return NULL;
00593             hc->boff = 0;
00594         }
00595         cw |= !!(byte & (1 << (7-hc->boff++)));
00596         /* printf("%.*x ", cwlen, cw); */
00597     }
00598     if (cwlen > hc->maxbits) /* FAIL: invalid data */
00599         return NULL;
00600 
00601     /* Put the last byte back if there are bits left over. */
00602     if (hc->boff < 8)
00603         ungetc(byte, hc->fh);
00604     else
00605         hc->boff = 0;
00606 
00607     /* printf("%.*x\n", cwlen, cw); */
00608     return hc->syms[cwlen] + (cw - hc->firstcode[cwlen]);
00609 }
00610 
00611 int
00612 huff_code_decode_int(huff_code_t *hc, int *outval,
00613                      char const **inout_data,
00614                      size_t *inout_data_len, int *inout_offset)
00615 {
00616     huff_codeword_t *cw;
00617 
00618     if (inout_data)
00619         cw = huff_code_decode_data(hc, inout_data, inout_data_len, inout_offset);
00620     else if (hc->fh)
00621         cw = huff_code_decode_fh(hc);
00622     else
00623         return -1;
00624 
00625     if (cw == NULL)
00626         return -1;
00627     if (outval)
00628         *outval = cw->r.ival;
00629 
00630     return 0;
00631 }
00632 
00633 char const *
00634 huff_code_decode_str(huff_code_t *hc,
00635                      char const **inout_data,
00636                      size_t *inout_data_len, int *inout_offset)
00637 {
00638     huff_codeword_t *cw;
00639 
00640     if (inout_data)
00641         cw = huff_code_decode_data(hc, inout_data, inout_data_len, inout_offset);
00642     else if (hc->fh)
00643         cw = huff_code_decode_fh(hc);
00644     else
00645         return NULL;
00646 
00647     if (cw == NULL)
00648         return NULL;
00649 
00650     return cw->r.sval;
00651 }