| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // All Rights Reserved. |
| // |
| // Author : Johan Schalkwyk |
| // |
| // \file |
| // Classes to provide symbol-to-integer and integer-to-symbol mappings. |
| |
| #include <fst/symbol-table.h> |
| |
| #include <fst/util.h> |
| |
| DEFINE_bool(fst_compat_symbols, true, |
| "Require symbol tables to match when appropriate"); |
| DEFINE_string(fst_field_separator, "\t ", |
| "Set of characters used as a separator between printed fields"); |
| |
| namespace fst { |
| |
| // Maximum line length in textual symbols file. |
| const int kLineLen = 8096; |
| |
| // Identifies stream data as a symbol table (and its endianity) |
| static const int32 kSymbolTableMagicNumber = 2125658996; |
| |
| SymbolTableTextOptions::SymbolTableTextOptions() |
| : allow_negative(false), fst_field_separator(FLAGS_fst_field_separator) { } |
| |
| SymbolTableImpl* SymbolTableImpl::ReadText(istream &strm, |
| const string &filename, |
| const SymbolTableTextOptions &opts) { |
| SymbolTableImpl* impl = new SymbolTableImpl(filename); |
| |
| int64 nline = 0; |
| char line[kLineLen]; |
| while (strm.getline(line, kLineLen)) { |
| ++nline; |
| vector<char *> col; |
| string separator = opts.fst_field_separator + "\n"; |
| SplitToVector(line, separator.c_str(), &col, true); |
| if (col.size() == 0) // empty line |
| continue; |
| if (col.size() != 2) { |
| LOG(ERROR) << "SymbolTable::ReadText: Bad number of columns (" |
| << col.size() << "), " |
| << "file = " << filename << ", line = " << nline |
| << ":<" << line << ">"; |
| delete impl; |
| return 0; |
| } |
| const char *symbol = col[0]; |
| const char *value = col[1]; |
| char *p; |
| int64 key = strtoll(value, &p, 10); |
| if (p < value + strlen(value) || |
| (!opts.allow_negative && key < 0) || key == -1) { |
| LOG(ERROR) << "SymbolTable::ReadText: Bad non-negative integer \"" |
| << value << "\", " |
| << "file = " << filename << ", line = " << nline; |
| delete impl; |
| return 0; |
| } |
| impl->AddSymbol(symbol, key); |
| } |
| |
| return impl; |
| } |
| |
| void SymbolTableImpl::MaybeRecomputeCheckSum() const { |
| { |
| ReaderMutexLock check_sum_lock(&check_sum_mutex_); |
| if (check_sum_finalized_) |
| return; |
| } |
| |
| // We'll aquire an exclusive lock to recompute the checksums. |
| MutexLock check_sum_lock(&check_sum_mutex_); |
| if (check_sum_finalized_) // Another thread (coming in around the same time |
| return; // might have done it already). So we recheck. |
| |
| // Calculate the original label-agnostic check sum. |
| CheckSummer check_sum; |
| for (int64 i = 0; i < symbols_.size(); ++i) |
| check_sum.Update(symbols_[i], strlen(symbols_[i]) + 1); |
| check_sum_string_ = check_sum.Digest(); |
| |
| // Calculate the safer, label-dependent check sum. |
| CheckSummer labeled_check_sum; |
| for (int64 key = 0; key < dense_key_limit_; ++key) { |
| ostringstream line; |
| line << symbols_[key] << '\t' << key; |
| labeled_check_sum.Update(line.str().data(), line.str().size()); |
| } |
| for (map<int64, const char*>::const_iterator it = |
| key_map_.begin(); |
| it != key_map_.end(); |
| ++it) { |
| if (it->first >= dense_key_limit_) { |
| ostringstream line; |
| line << it->second << '\t' << it->first; |
| labeled_check_sum.Update(line.str().data(), line.str().size()); |
| } |
| } |
| labeled_check_sum_string_ = labeled_check_sum.Digest(); |
| |
| check_sum_finalized_ = true; |
| } |
| |
| int64 SymbolTableImpl::AddSymbol(const string& symbol, int64 key) { |
| map<const char *, int64, StrCmp>::const_iterator it = |
| symbol_map_.find(symbol.c_str()); |
| if (it == symbol_map_.end()) { // only add if not in table |
| check_sum_finalized_ = false; |
| |
| char *csymbol = new char[symbol.size() + 1]; |
| strcpy(csymbol, symbol.c_str()); |
| symbols_.push_back(csymbol); |
| key_map_[key] = csymbol; |
| symbol_map_[csymbol] = key; |
| |
| if (key >= available_key_) { |
| available_key_ = key + 1; |
| } |
| } else { |
| // Log if symbol already in table with different key |
| if (it->second != key) { |
| VLOG(1) << "SymbolTable::AddSymbol: symbol = " << symbol |
| << " already in symbol_map_ with key = " |
| << it->second |
| << " but supplied new key = " << key |
| << " (ignoring new key)"; |
| } |
| } |
| return key; |
| } |
| |
| static bool IsInRange(const vector<pair<int64, int64> >& ranges, |
| int64 key) { |
| if (ranges.size() == 0) return true; |
| for (size_t i = 0; i < ranges.size(); ++i) { |
| if (key >= ranges[i].first && key <= ranges[i].second) |
| return true; |
| } |
| return false; |
| } |
| |
| SymbolTableImpl* SymbolTableImpl::Read(istream &strm, |
| const SymbolTableReadOptions& opts) { |
| int32 magic_number = 0; |
| ReadType(strm, &magic_number); |
| if (!strm) { |
| LOG(ERROR) << "SymbolTable::Read: read failed"; |
| return 0; |
| } |
| string name; |
| ReadType(strm, &name); |
| SymbolTableImpl* impl = new SymbolTableImpl(name); |
| ReadType(strm, &impl->available_key_); |
| int64 size; |
| ReadType(strm, &size); |
| if (!strm) { |
| LOG(ERROR) << "SymbolTable::Read: read failed"; |
| delete impl; |
| return 0; |
| } |
| |
| string symbol; |
| int64 key; |
| impl->check_sum_finalized_ = false; |
| for (size_t i = 0; i < size; ++i) { |
| ReadType(strm, &symbol); |
| ReadType(strm, &key); |
| if (!strm) { |
| LOG(ERROR) << "SymbolTable::Read: read failed"; |
| delete impl; |
| return 0; |
| } |
| |
| char *csymbol = new char[symbol.size() + 1]; |
| strcpy(csymbol, symbol.c_str()); |
| impl->symbols_.push_back(csymbol); |
| if (key == impl->dense_key_limit_ && |
| key == impl->symbols_.size() - 1) |
| impl->dense_key_limit_ = impl->symbols_.size(); |
| else |
| impl->key_map_[key] = csymbol; |
| |
| if (IsInRange(opts.string_hash_ranges, key)) { |
| impl->symbol_map_[csymbol] = key; |
| } |
| } |
| return impl; |
| } |
| |
| bool SymbolTableImpl::Write(ostream &strm) const { |
| WriteType(strm, kSymbolTableMagicNumber); |
| WriteType(strm, name_); |
| WriteType(strm, available_key_); |
| int64 size = symbols_.size(); |
| WriteType(strm, size); |
| // first write out dense keys |
| int64 i = 0; |
| for (; i < dense_key_limit_; ++i) { |
| WriteType(strm, string(symbols_[i])); |
| WriteType(strm, i); |
| } |
| // next write out the remaining non densely packed keys |
| for (map<const char *, int64, StrCmp>::const_iterator it = |
| symbol_map_.begin(); it != symbol_map_.end(); ++it) { |
| if ((it->second >= 0) && (it->second < dense_key_limit_)) |
| continue; |
| WriteType(strm, string(it->first)); |
| WriteType(strm, it->second); |
| ++i; |
| } |
| if (i != size) { |
| LOG(ERROR) << "SymbolTable::Write: write failed"; |
| return false; |
| } |
| strm.flush(); |
| if (!strm) { |
| LOG(ERROR) << "SymbolTable::Write: write failed"; |
| return false; |
| } |
| return true; |
| } |
| |
| const int64 SymbolTable::kNoSymbol; |
| |
| |
| void SymbolTable::AddTable(const SymbolTable& table) { |
| for (SymbolTableIterator iter(table); !iter.Done(); iter.Next()) |
| impl_->AddSymbol(iter.Symbol()); |
| } |
| |
| bool SymbolTable::WriteText(ostream &strm, |
| const SymbolTableTextOptions &opts) const { |
| if (opts.fst_field_separator.empty()) { |
| LOG(ERROR) << "Missing required field separator"; |
| return false; |
| } |
| bool once_only = false; |
| for (SymbolTableIterator iter(*this); !iter.Done(); iter.Next()) { |
| ostringstream line; |
| if (iter.Value() < 0 && !opts.allow_negative && !once_only) { |
| LOG(WARNING) << "Negative symbol table entry when not allowed"; |
| once_only = true; |
| } |
| line << iter.Symbol() << opts.fst_field_separator[0] << iter.Value() |
| << '\n'; |
| strm.write(line.str().data(), line.str().length()); |
| } |
| return true; |
| } |
| } // namespace fst |