// algo_test.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Regression test for various FST algorithms. #ifndef FST_TEST_ALGO_TEST_H__ #define FST_TEST_ALGO_TEST_H__ #include <fst/fstlib.h> #include <fst/random-weight.h> DECLARE_int32(repeat); // defined in ./algo_test.cc namespace fst { // Mapper to change input and output label of every transition into // epsilons. template <class A> class EpsMapper { public: EpsMapper() {} A operator()(const A &arc) const { return A(0, 0, arc.weight, arc.nextstate); } uint64 Properties(uint64 props) const { props &= ~kNotAcceptor; props |= kAcceptor; props &= ~kNoIEpsilons & ~kNoOEpsilons & ~kNoEpsilons; props |= kIEpsilons | kOEpsilons | kEpsilons; props &= ~kNotILabelSorted & ~kNotOLabelSorted; props |= kILabelSorted | kOLabelSorted; return props; } MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS;} MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} }; // This class tests a variety of identities and properties that must // hold for various algorithms on weighted FSTs. template <class Arc, class WeightGenerator> class WeightedTester { public: typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; WeightedTester(int seed, const Fst<Arc> &zero_fst, const Fst<Arc> &one_fst, const Fst<Arc> &univ_fst, WeightGenerator *weight_generator) : seed_(seed), zero_fst_(zero_fst), one_fst_(one_fst), univ_fst_(univ_fst), weight_generator_(weight_generator) {} void Test(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) { TestRational(T1, T2, T3); TestMap(T1); TestCompose(T1, T2, T3); TestSort(T1); TestOptimize(T1); TestSearch(T1); } private: // Tests rational operations with identities void TestRational(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) { { VLOG(1) << "Check destructive and delayed union are equivalent."; VectorFst<Arc> U1(T1); Union(&U1, T2); UnionFst<Arc> U2(T1, T2); CHECK(Equiv(U1, U2)); } { VLOG(1) << "Check destructive and delayed concatenation are equivalent."; VectorFst<Arc> C1(T1); Concat(&C1, T2); ConcatFst<Arc> C2(T1, T2); CHECK(Equiv(C1, C2)); VectorFst<Arc> C3(T2); Concat(T1, &C3); CHECK(Equiv(C3, C2)); } { VLOG(1) << "Check destructive and delayed closure* are equivalent."; VectorFst<Arc> C1(T1); Closure(&C1, CLOSURE_STAR); ClosureFst<Arc> C2(T1, CLOSURE_STAR); CHECK(Equiv(C1, C2)); } { VLOG(1) << "Check destructive and delayed closure+ are equivalent."; VectorFst<Arc> C1(T1); Closure(&C1, CLOSURE_PLUS); ClosureFst<Arc> C2(T1, CLOSURE_PLUS); CHECK(Equiv(C1, C2)); } { VLOG(1) << "Check union is associative (destructive)."; VectorFst<Arc> U1(T1); Union(&U1, T2); Union(&U1, T3); VectorFst<Arc> U3(T2); Union(&U3, T3); VectorFst<Arc> U4(T1); Union(&U4, U3); CHECK(Equiv(U1, U4)); } { VLOG(1) << "Check union is associative (delayed)."; UnionFst<Arc> U1(T1, T2); UnionFst<Arc> U2(U1, T3); UnionFst<Arc> U3(T2, T3); UnionFst<Arc> U4(T1, U3); CHECK(Equiv(U2, U4)); } { VLOG(1) << "Check union is associative (destructive delayed)."; UnionFst<Arc> U1(T1, T2); Union(&U1, T3); UnionFst<Arc> U3(T2, T3); UnionFst<Arc> U4(T1, U3); CHECK(Equiv(U1, U4)); } { VLOG(1) << "Check concatenation is associative (destructive)."; VectorFst<Arc> C1(T1); Concat(&C1, T2); Concat(&C1, T3); VectorFst<Arc> C3(T2); Concat(&C3, T3); VectorFst<Arc> C4(T1); Concat(&C4, C3); CHECK(Equiv(C1, C4)); } { VLOG(1) << "Check concatenation is associative (delayed)."; ConcatFst<Arc> C1(T1, T2); ConcatFst<Arc> C2(C1, T3); ConcatFst<Arc> C3(T2, T3); ConcatFst<Arc> C4(T1, C3); CHECK(Equiv(C2, C4)); } { VLOG(1) << "Check concatenation is associative (destructive delayed)."; ConcatFst<Arc> C1(T1, T2); Concat(&C1, T3); ConcatFst<Arc> C3(T2, T3); ConcatFst<Arc> C4(T1, C3); CHECK(Equiv(C1, C4)); } if (Weight::Properties() & kLeftSemiring) { VLOG(1) << "Check concatenation left distributes" << " over union (destructive)."; VectorFst<Arc> U1(T1); Union(&U1, T2); VectorFst<Arc> C1(T3); Concat(&C1, U1); VectorFst<Arc> C2(T3); Concat(&C2, T1); VectorFst<Arc> C3(T3); Concat(&C3, T2); VectorFst<Arc> U2(C2); Union(&U2, C3); CHECK(Equiv(C1, U2)); } if (Weight::Properties() & kRightSemiring) { VLOG(1) << "Check concatenation right distributes" << " over union (destructive)."; VectorFst<Arc> U1(T1); Union(&U1, T2); VectorFst<Arc> C1(U1); Concat(&C1, T3); VectorFst<Arc> C2(T1); Concat(&C2, T3); VectorFst<Arc> C3(T2); Concat(&C3, T3); VectorFst<Arc> U2(C2); Union(&U2, C3); CHECK(Equiv(C1, U2)); } if (Weight::Properties() & kLeftSemiring) { VLOG(1) << "Check concatenation left distributes over union (delayed)."; UnionFst<Arc> U1(T1, T2); ConcatFst<Arc> C1(T3, U1); ConcatFst<Arc> C2(T3, T1); ConcatFst<Arc> C3(T3, T2); UnionFst<Arc> U2(C2, C3); CHECK(Equiv(C1, U2)); } if (Weight::Properties() & kRightSemiring) { VLOG(1) << "Check concatenation right distributes over union (delayed)."; UnionFst<Arc> U1(T1, T2); ConcatFst<Arc> C1(U1, T3); ConcatFst<Arc> C2(T1, T3); ConcatFst<Arc> C3(T2, T3); UnionFst<Arc> U2(C2, C3); CHECK(Equiv(C1, U2)); } if (Weight::Properties() & kLeftSemiring) { VLOG(1) << "Check T T* == T+ (destructive)."; VectorFst<Arc> S(T1); Closure(&S, CLOSURE_STAR); VectorFst<Arc> C(T1); Concat(&C, S); VectorFst<Arc> P(T1); Closure(&P, CLOSURE_PLUS); CHECK(Equiv(C, P)); } if (Weight::Properties() & kRightSemiring) { VLOG(1) << "Check T* T == T+ (destructive)."; VectorFst<Arc> S(T1); Closure(&S, CLOSURE_STAR); VectorFst<Arc> C(S); Concat(&C, T1); VectorFst<Arc> P(T1); Closure(&P, CLOSURE_PLUS); CHECK(Equiv(C, P)); } if (Weight::Properties() & kLeftSemiring) { VLOG(1) << "Check T T* == T+ (delayed)."; ClosureFst<Arc> S(T1, CLOSURE_STAR); ConcatFst<Arc> C(T1, S); ClosureFst<Arc> P(T1, CLOSURE_PLUS); CHECK(Equiv(C, P)); } if (Weight::Properties() & kRightSemiring) { VLOG(1) << "Check T* T == T+ (delayed)."; ClosureFst<Arc> S(T1, CLOSURE_STAR); ConcatFst<Arc> C(S, T1); ClosureFst<Arc> P(T1, CLOSURE_PLUS); CHECK(Equiv(C, P)); } } // Tests map-based operations. void TestMap(const Fst<Arc> &T) { { VLOG(1) << "Check destructive and delayed projection are equivalent."; VectorFst<Arc> P1(T); Project(&P1, PROJECT_INPUT); ProjectFst<Arc> P2(T, PROJECT_INPUT); CHECK(Equiv(P1, P2)); } { VLOG(1) << "Check destructive and delayed inversion are equivalent."; VectorFst<Arc> I1(T); Invert(&I1); InvertFst<Arc> I2(T); CHECK(Equiv(I1, I2)); } { VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (destructive)."; VectorFst<Arc> P1(T); VectorFst<Arc> I1(T); Project(&P1, PROJECT_INPUT); Invert(&I1); Project(&I1, PROJECT_OUTPUT); CHECK(Equiv(P1, I1)); } { VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (destructive)."; VectorFst<Arc> P1(T); VectorFst<Arc> I1(T); Project(&P1, PROJECT_OUTPUT); Invert(&I1); Project(&I1, PROJECT_INPUT); CHECK(Equiv(P1, I1)); } { VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (delayed)."; ProjectFst<Arc> P1(T, PROJECT_INPUT); InvertFst<Arc> I1(T); ProjectFst<Arc> P2(I1, PROJECT_OUTPUT); CHECK(Equiv(P1, P2)); } { VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (delayed)."; ProjectFst<Arc> P1(T, PROJECT_OUTPUT); InvertFst<Arc> I1(T); ProjectFst<Arc> P2(I1, PROJECT_INPUT); CHECK(Equiv(P1, P2)); } { VLOG(1) << "Check destructive relabeling"; static const int kNumLabels = 10; // set up relabeling pairs vector<Label> labelset(kNumLabels); for (size_t i = 0; i < kNumLabels; ++i) labelset[i] = i; for (size_t i = 0; i < kNumLabels; ++i) { swap(labelset[i], labelset[rand() % kNumLabels]); } vector<pair<Label, Label> > ipairs1(kNumLabels); vector<pair<Label, Label> > opairs1(kNumLabels); for (size_t i = 0; i < kNumLabels; ++i) { ipairs1[i] = make_pair(i, labelset[i]); opairs1[i] = make_pair(labelset[i], i); } VectorFst<Arc> R(T); Relabel(&R, ipairs1, opairs1); vector<pair<Label, Label> > ipairs2(kNumLabels); vector<pair<Label, Label> > opairs2(kNumLabels); for (size_t i = 0; i < kNumLabels; ++i) { ipairs2[i] = make_pair(labelset[i], i); opairs2[i] = make_pair(i, labelset[i]); } Relabel(&R, ipairs2, opairs2); CHECK(Equiv(R, T)); VLOG(1) << "Check on-the-fly relabeling"; RelabelFst<Arc> Rdelay(T, ipairs1, opairs1); RelabelFst<Arc> RRdelay(Rdelay, ipairs2, opairs2); CHECK(Equiv(RRdelay, T)); } { VLOG(1) << "Check encoding/decoding (destructive)."; VectorFst<Arc> D(T); uint32 encode_props = 0; if (rand() % 2) encode_props |= kEncodeLabels; if (rand() % 2) encode_props |= kEncodeWeights; EncodeMapper<Arc> encoder(encode_props, ENCODE); Encode(&D, &encoder); Decode(&D, encoder); CHECK(Equiv(D, T)); } { VLOG(1) << "Check encoding/decoding (delayed)."; uint32 encode_props = 0; if (rand() % 2) encode_props |= kEncodeLabels; if (rand() % 2) encode_props |= kEncodeWeights; EncodeMapper<Arc> encoder(encode_props, ENCODE); EncodeFst<Arc> E(T, &encoder); VectorFst<Arc> Encoded(E); DecodeFst<Arc> D(Encoded, encoder); CHECK(Equiv(D, T)); } { VLOG(1) << "Check gallic mappers (constructive)."; ToGallicMapper<Arc> to_mapper; FromGallicMapper<Arc> from_mapper; VectorFst< GallicArc<Arc> > G; VectorFst<Arc> F; ArcMap(T, &G, to_mapper); ArcMap(G, &F, from_mapper); CHECK(Equiv(T, F)); } { VLOG(1) << "Check gallic mappers (delayed)."; ToGallicMapper<Arc> to_mapper; FromGallicMapper<Arc> from_mapper; ArcMapFst<Arc, GallicArc<Arc>, ToGallicMapper<Arc> > G(T, to_mapper); ArcMapFst<GallicArc<Arc>, Arc, FromGallicMapper<Arc> > F(G, from_mapper); CHECK(Equiv(T, F)); } } // Tests compose-based operations. void TestCompose(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) { if (!(Weight::Properties() & kCommutative)) return; VectorFst<Arc> S1(T1); VectorFst<Arc> S2(T2); VectorFst<Arc> S3(T3); ILabelCompare<Arc> icomp; OLabelCompare<Arc> ocomp; ArcSort(&S1, ocomp); ArcSort(&S2, ocomp); ArcSort(&S3, icomp); { VLOG(1) << "Check composition is associative."; ComposeFst<Arc> C1(S1, S2); ComposeFst<Arc> C2(C1, S3); ComposeFst<Arc> C3(S2, S3); ComposeFst<Arc> C4(S1, C3); CHECK(Equiv(C2, C4)); } { VLOG(1) << "Check composition left distributes over union."; UnionFst<Arc> U1(S2, S3); ComposeFst<Arc> C1(S1, U1); ComposeFst<Arc> C2(S1, S2); ComposeFst<Arc> C3(S1, S3); UnionFst<Arc> U2(C2, C3); CHECK(Equiv(C1, U2)); } { VLOG(1) << "Check composition right distributes over union."; UnionFst<Arc> U1(S1, S2); ComposeFst<Arc> C1(U1, S3); ComposeFst<Arc> C2(S1, S3); ComposeFst<Arc> C3(S2, S3); UnionFst<Arc> U2(C2, C3); CHECK(Equiv(C1, U2)); } VectorFst<Arc> A1(S1); VectorFst<Arc> A2(S2); VectorFst<Arc> A3(S3); Project(&A1, PROJECT_OUTPUT); Project(&A2, PROJECT_INPUT); Project(&A3, PROJECT_INPUT); { VLOG(1) << "Check intersection is commutative."; IntersectFst<Arc> I1(A1, A2); IntersectFst<Arc> I2(A2, A1); CHECK(Equiv(I1, I2)); } { VLOG(1) << "Check all epsilon filters leads to equivalent results."; typedef Matcher< Fst<Arc> > M; ComposeFst<Arc> C1(S1, S2); ComposeFst<Arc> C2( S1, S2, ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M> >()); ComposeFst<Arc> C3( S1, S2, ComposeFstOptions<Arc, M, MatchComposeFilter<M> >()); CHECK(Equiv(C1, C2)); CHECK(Equiv(C1, C3)); } } // Tests sorting operations void TestSort(const Fst<Arc> &T) { ILabelCompare<Arc> icomp; OLabelCompare<Arc> ocomp; { VLOG(1) << "Check arc sorted Fst is equivalent to its input."; VectorFst<Arc> S1(T); ArcSort(&S1, icomp); CHECK(Equiv(T, S1)); } { VLOG(1) << "Check destructive and delayed arcsort are equivalent."; VectorFst<Arc> S1(T); ArcSort(&S1, icomp); ArcSortFst< Arc, ILabelCompare<Arc> > S2(T, icomp); CHECK(Equiv(S1, S2)); } { VLOG(1) << "Check ilabel sorting vs. olabel sorting with inversions."; VectorFst<Arc> S1(T); VectorFst<Arc> S2(T); ArcSort(&S1, icomp); Invert(&S2); ArcSort(&S2, ocomp); Invert(&S2); CHECK(Equiv(S1, S2)); } { VLOG(1) << "Check topologically sorted Fst is equivalent to its input."; VectorFst<Arc> S1(T); TopSort(&S1); CHECK(Equiv(T, S1)); } { VLOG(1) << "Check reverse(reverse(T)) = T"; VectorFst< ReverseArc<Arc> > R1; VectorFst<Arc> R2; Reverse(T, &R1); Reverse(R1, &R2); CHECK(Equiv(T, R2)); } } // Tests optimization operations void TestOptimize(const Fst<Arc> &T) { uint64 tprops = T.Properties(kFstProperties, true); uint64 wprops = Weight::Properties(); VectorFst<Arc> A(T); Project(&A, PROJECT_INPUT); { VLOG(1) << "Check connected FST is equivalent to its input."; VectorFst<Arc> C1(T); Connect(&C1); CHECK(Equiv(T, C1)); } if ((wprops & kSemiring) == kSemiring && (tprops & kAcyclic || wprops & kIdempotent)) { VLOG(1) << "Check epsilon-removed FST is equivalent to its input."; VectorFst<Arc> R1(T); RmEpsilon(&R1); CHECK(Equiv(T, R1)); VLOG(1) << "Check destructive and delayed epsilon removal" << "are equivalent."; RmEpsilonFst<Arc> R2(T); CHECK(Equiv(R1, R2)); VLOG(1) << "Check an FST with a large proportion" << " of epsilon transitions:"; // Maps all transitions of T to epsilon-transitions and append // a non-epsilon transition. VectorFst<Arc> U; ArcMap(T, &U, EpsMapper<Arc>()); VectorFst<Arc> V; V.SetStart(V.AddState()); Arc arc(1, 1, Weight::One(), V.AddState()); V.AddArc(V.Start(), arc); V.SetFinal(arc.nextstate, Weight::One()); Concat(&U, V); // Check that epsilon-removal preserves the shortest-distance // from the initial state to the final states. vector<Weight> d; ShortestDistance(U, &d, true); Weight w = U.Start() < d.size() ? d[U.Start()] : Weight::Zero(); VectorFst<Arc> U1(U); RmEpsilon(&U1); ShortestDistance(U1, &d, true); Weight w1 = U1.Start() < d.size() ? d[U1.Start()] : Weight::Zero(); CHECK(ApproxEqual(w, w1, kTestDelta)); RmEpsilonFst<Arc> U2(U); ShortestDistance(U2, &d, true); Weight w2 = U2.Start() < d.size() ? d[U2.Start()] : Weight::Zero(); CHECK(ApproxEqual(w, w2, kTestDelta)); } if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) { VLOG(1) << "Check determinized FSA is equivalent to its input."; DeterminizeFst<Arc> D(A); CHECK(Equiv(A, D)); int n; { VLOG(1) << "Check size(min(det(A))) <= size(det(A))" << " and min(det(A)) equiv det(A)"; VectorFst<Arc> M(D); n = M.NumStates(); Minimize(&M); CHECK(Equiv(D, M)); CHECK(M.NumStates() <= n); n = M.NumStates(); } if (n && (wprops & kIdempotent) == kIdempotent && A.Properties(kNoEpsilons, true)) { VLOG(1) << "Check that Revuz's algorithm leads to the" << " same number of states as Brozozowski's algorithm"; // Skip test if A is the empty machine or contains epsilons or // if the semiring is not idempotent (to avoid floating point // errors) VectorFst<Arc> R; Reverse(A, &R); RmEpsilon(&R); DeterminizeFst<Arc> DR(R); VectorFst<Arc> RD; Reverse(DR, &RD); DeterminizeFst<Arc> DRD(RD); VectorFst<Arc> M(DRD); CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition // to the initial state } } if (Arc::Type() == LogArc::Type() || Arc::Type() == StdArc::Type()) { VLOG(1) << "Check reweight(T) equiv T"; vector<Weight> potential; VectorFst<Arc> RI(T); VectorFst<Arc> RF(T); while (potential.size() < RI.NumStates()) potential.push_back((*weight_generator_)()); Reweight(&RI, potential, REWEIGHT_TO_INITIAL); CHECK(Equiv(T, RI)); Reweight(&RF, potential, REWEIGHT_TO_FINAL); CHECK(Equiv(T, RF)); } if ((wprops & kIdempotent) || (tprops & kAcyclic)) { VLOG(1) << "Check pushed FST is equivalent to input FST."; // Pushing towards the final state. if (wprops & kRightSemiring) { VectorFst<Arc> P1; Push<Arc, REWEIGHT_TO_FINAL>(T, &P1, kPushLabels); CHECK(Equiv(T, P1)); VectorFst<Arc> P2; Push<Arc, REWEIGHT_TO_FINAL>(T, &P2, kPushWeights); CHECK(Equiv(T, P2)); VectorFst<Arc> P3; Push<Arc, REWEIGHT_TO_FINAL>(T, &P3, kPushLabels | kPushWeights); CHECK(Equiv(T, P3)); } // Pushing towards the initial state. if (wprops & kLeftSemiring) { VectorFst<Arc> P1; Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1, kPushLabels); CHECK(Equiv(T, P1)); VectorFst<Arc> P2; Push<Arc, REWEIGHT_TO_INITIAL>(T, &P2, kPushWeights); CHECK(Equiv(T, P2)); VectorFst<Arc> P3; Push<Arc, REWEIGHT_TO_INITIAL>(T, &P3, kPushLabels | kPushWeights); CHECK(Equiv(T, P3)); } } if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) { VLOG(1) << "Check pruning algorithm"; { VLOG(1) << "Check equiv. of constructive and destructive algorithms"; Weight thresold = (*weight_generator_)(); VectorFst<Arc> P1(T); Prune(&P1, thresold); VectorFst<Arc> P2; Prune(T, &P2, thresold); CHECK(Equiv(P1,P2)); } { VLOG(1) << "Check prune(reverse) equiv reverse(prune)"; Weight thresold = (*weight_generator_)(); VectorFst< ReverseArc<Arc> > R; VectorFst<Arc> P1(T); VectorFst<Arc> P2; Prune(&P1, thresold); Reverse(T, &R); Prune(&R, thresold.Reverse()); Reverse(R, &P2); CHECK(Equiv(P1, P2)); } { VLOG(1) << "Check: ShortestDistance(T- prune(T))" << " > ShortestDistance(T) times Thresold"; Weight thresold = (*weight_generator_)(); VectorFst<Arc> P; Prune(A, &P, thresold); DifferenceFst<Arc> C(A, DeterminizeFst<Arc> (RmEpsilonFst<Arc> (ArcMapFst<Arc, Arc, RmWeightMapper<Arc> > (P, RmWeightMapper<Arc>())))); Weight sum1 = Times(ShortestDistance(A), thresold); Weight sum2 = ShortestDistance(C); CHECK(Plus(sum1, sum2) == sum1); } } if (tprops & kAcyclic) { VLOG(1) << "Check synchronize(T) equiv T"; SynchronizeFst<Arc> S(T); CHECK(Equiv(T, S)); } } // Tests search operations void TestSearch(const Fst<Arc> &T) { uint64 wprops = Weight::Properties(); VectorFst<Arc> A(T); Project(&A, PROJECT_INPUT); if ((wprops & (kPath | kRightSemiring)) == (kPath | kRightSemiring)) { VLOG(1) << "Check 1-best weight."; VectorFst<Arc> path; ShortestPath(T, &path); Weight tsum = ShortestDistance(T); Weight psum = ShortestDistance(path); CHECK(ApproxEqual(tsum, psum, kTestDelta)); } if ((wprops & (kPath | kSemiring)) == (kPath | kSemiring)) { VLOG(1) << "Check n-best weights"; VectorFst<Arc> R(A); RmEpsilon(&R); int nshortest = rand() % kNumRandomShortestPaths + 2; VectorFst<Arc> paths; ShortestPath(R, &paths, nshortest, true, false, Weight::Zero(), kNumShortestStates); vector<Weight> distance; ShortestDistance(paths, &distance, true); StateId pstart = paths.Start(); if (pstart != kNoStateId) { ArcIterator< Fst<Arc> > piter(paths, pstart); for (; !piter.Done(); piter.Next()) { StateId s = piter.Value().nextstate; Weight nsum = s < distance.size() ? Times(piter.Value().weight, distance[s]) : Weight::Zero(); VectorFst<Arc> path; ShortestPath(R, &path); Weight dsum = ShortestDistance(path); CHECK(ApproxEqual(nsum, dsum, kTestDelta)); ArcMap(&path, RmWeightMapper<Arc>()); VectorFst<Arc> S; Difference(R, path, &S); R = S; } } } } // Tests if two FSTS are equivalent by checking if random // strings from one FST are transduced the same by both FSTs. bool Equiv(const Fst<Arc> &fst1, const Fst<Arc> &fst2) { VLOG(1) << "Check FSTs for sanity (including property bits)."; CHECK(Verify(fst1)); CHECK(Verify(fst2)); UniformArcSelector<Arc> uniform_selector(seed_); RandGenOptions< UniformArcSelector<Arc> > opts(uniform_selector, kRandomPathLength); return RandEquivalent(fst1, fst2, kNumRandomPaths, kTestDelta, opts); } // Random seed int seed_; // FST with no states VectorFst<Arc> zero_fst_; // FST with one state that accepts epsilon. VectorFst<Arc> one_fst_; // FST with one state that accepts all strings. VectorFst<Arc> univ_fst_; // Generates weights used in testing. WeightGenerator *weight_generator_; // Maximum random path length. static const int kRandomPathLength; // Number of random paths to explore. static const int kNumRandomPaths; // Maximum number of nshortest paths. static const int kNumRandomShortestPaths; // Maximum number of nshortest states. static const int kNumShortestStates; // Delta for equivalence tests. static const float kTestDelta; DISALLOW_COPY_AND_ASSIGN(WeightedTester); }; template <class A, class WG> const int WeightedTester<A, WG>::kRandomPathLength = 25; template <class A, class WG> const int WeightedTester<A, WG>::kNumRandomPaths = 100; template <class A, class WG> const int WeightedTester<A, WG>::kNumRandomShortestPaths = 100; template <class A, class WG> const int WeightedTester<A, WG>::kNumShortestStates = 10000; template <class A, class WG> const float WeightedTester<A, WG>::kTestDelta = .05; // This class tests a variety of identities and properties that must // hold for various algorithms on unweighted FSAs and that are not tested // by WeightedTester. Only the specialization does anything interesting. template <class Arc> class UnweightedTester { public: UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa, const Fst<Arc> &univ_fsa) {} void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {} }; // Specialization for StdArc. This should work for any commutative, // idempotent semiring when restricted to the unweighted case // (being isomorphic to the boolean semiring). template <> class UnweightedTester<StdArc> { public: typedef StdArc Arc; typedef Arc::Label Label; typedef Arc::StateId StateId; typedef Arc::Weight Weight; UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa, const Fst<Arc> &univ_fsa) : zero_fsa_(zero_fsa), one_fsa_(one_fsa), univ_fsa_(univ_fsa) {} void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) { TestRational(A1, A2, A3); TestIntersect(A1, A2, A3); TestOptimize(A1); } private: // Tests rational operations with identities void TestRational(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) { { VLOG(1) << "Check the union contains its arguments (destructive)."; VectorFst<Arc> U(A1); Union(&U, A2); CHECK(Subset(A1, U)); CHECK(Subset(A2, U)); } { VLOG(1) << "Check the union contains its arguments (delayed)."; UnionFst<Arc> U(A1, A2); CHECK(Subset(A1, U)); CHECK(Subset(A2, U)); } { VLOG(1) << "Check if A^n c A* (destructive)."; VectorFst<Arc> C(one_fsa_); int n = rand() % 5; for (int i = 0; i < n; ++i) Concat(&C, A1); VectorFst<Arc> S(A1); Closure(&S, CLOSURE_STAR); CHECK(Subset(C, S)); } { VLOG(1) << "Check if A^n c A* (delayed)."; int n = rand() % 5; Fst<Arc> *C = new VectorFst<Arc>(one_fsa_); for (int i = 0; i < n; ++i) { ConcatFst<Arc> *F = new ConcatFst<Arc>(*C, A1); delete C; C = F; } ClosureFst<Arc> S(A1, CLOSURE_STAR); CHECK(Subset(*C, S)); delete C; } } // Tests intersect-based operations. void TestIntersect(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) { VectorFst<Arc> S1(A1); VectorFst<Arc> S2(A2); VectorFst<Arc> S3(A3); ILabelCompare<Arc> comp; ArcSort(&S1, comp); ArcSort(&S2, comp); ArcSort(&S3, comp); { VLOG(1) << "Check the intersection is contained in its arguments."; IntersectFst<Arc> I1(S1, S2); CHECK(Subset(I1, S1)); CHECK(Subset(I1, S2)); } { VLOG(1) << "Check union distributes over intersection."; IntersectFst<Arc> I1(S1, S2); UnionFst<Arc> U1(I1, S3); UnionFst<Arc> U2(S1, S3); UnionFst<Arc> U3(S2, S3); ArcSortFst< Arc, ILabelCompare<Arc> > S4(U3, comp); IntersectFst<Arc> I2(U2, S4); CHECK(Equiv(U1, I2)); } VectorFst<Arc> C1; VectorFst<Arc> C2; Complement(S1, &C1); Complement(S2, &C2); ArcSort(&C1, comp); ArcSort(&C2, comp); { VLOG(1) << "Check S U S' = Sigma*"; UnionFst<Arc> U(S1, C1); CHECK(Equiv(U, univ_fsa_)); } { VLOG(1) << "Check S n S' = {}"; IntersectFst<Arc> I(S1, C1); CHECK(Equiv(I, zero_fsa_)); } { VLOG(1) << "Check (S1' U S2') == (S1 n S2)'"; UnionFst<Arc> U(C1, C2); IntersectFst<Arc> I(S1, S2); VectorFst<Arc> C3; Complement(I, &C3); CHECK(Equiv(U, C3)); } { VLOG(1) << "Check (S1' n S2') == (S1 U S2)'"; IntersectFst<Arc> I(C1, C2); UnionFst<Arc> U(S1, S2); VectorFst<Arc> C3; Complement(U, &C3); CHECK(Equiv(I, C3)); } } // Tests optimization operations void TestOptimize(const Fst<Arc> &A) { { VLOG(1) << "Check determinized FSA is equivalent to its input."; DeterminizeFst<Arc> D(A); CHECK(Equiv(A, D)); } { VLOG(1) << "Check minimized FSA is equivalent to its input."; int n; { RmEpsilonFst<Arc> R(A); DeterminizeFst<Arc> D(R); VectorFst<Arc> M(D); Minimize(&M); CHECK(Equiv(A, M)); n = M.NumStates(); } if (n) { // Skip test if A is the empty machine VLOG(1) << "Check that Hopcroft's and Revuz's algorithms lead to the" << " same number of states as Brozozowski's algorithm"; VectorFst<Arc> R; Reverse(A, &R); RmEpsilon(&R); DeterminizeFst<Arc> DR(R); VectorFst<Arc> RD; Reverse(DR, &RD); DeterminizeFst<Arc> DRD(RD); VectorFst<Arc> M(DRD); CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition // to the initial state } } } // Tests if two FSAS are equivalent. bool Equiv(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) { VLOG(1) << "Check FSAs for sanity (including property bits)."; CHECK(Verify(fsa1)); CHECK(Verify(fsa2)); VectorFst<Arc> vfsa1(fsa1); VectorFst<Arc> vfsa2(fsa2); RmEpsilon(&vfsa1); RmEpsilon(&vfsa2); DeterminizeFst<Arc> dfa1(vfsa1); DeterminizeFst<Arc> dfa2(vfsa2); // Test equivalence using union-find algorithm bool equiv1 = Equivalent(dfa1, dfa2); // Test equivalence by checking if (S1 - S2) U (S2 - S1) is empty ILabelCompare<Arc> comp; VectorFst<Arc> sdfa1(dfa1); ArcSort(&sdfa1, comp); VectorFst<Arc> sdfa2(dfa2); ArcSort(&sdfa2, comp); DifferenceFst<Arc> dfsa1(sdfa1, sdfa2); DifferenceFst<Arc> dfsa2(sdfa2, sdfa1); VectorFst<Arc> ufsa(dfsa1); Union(&ufsa, dfsa2); Connect(&ufsa); bool equiv2 = ufsa.NumStates() == 0; // Check two equivalence tests match CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2)); return equiv1; } // Tests if FSA1 is a subset of FSA2 (disregarding weights). bool Subset(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) { VLOG(1) << "Check FSAs (incl. property bits) for sanity"; CHECK(Verify(fsa1)); CHECK(Verify(fsa2)); VectorFst<StdArc> vfsa1; VectorFst<StdArc> vfsa2; RmEpsilon(&vfsa1); RmEpsilon(&vfsa2); ILabelCompare<StdArc> comp; ArcSort(&vfsa1, comp); ArcSort(&vfsa2, comp); IntersectFst<StdArc> ifsa(vfsa1, vfsa2); DeterminizeFst<StdArc> dfa1(vfsa1); DeterminizeFst<StdArc> dfa2(ifsa); return Equivalent(dfa1, dfa2); } // Returns complement Fsa void Complement(const Fst<Arc> &ifsa, MutableFst<Arc> *ofsa) { RmEpsilonFst<Arc> rfsa(ifsa); DeterminizeFst<Arc> dfa(rfsa); DifferenceFst<Arc> cfsa(univ_fsa_, dfa); *ofsa = cfsa; } // FSA with no states VectorFst<Arc> zero_fsa_; // FSA with one state that accepts epsilon. VectorFst<Arc> one_fsa_; // FSA with one state that accepts all strings. VectorFst<Arc> univ_fsa_; DISALLOW_COPY_AND_ASSIGN(UnweightedTester); }; // This class tests a variety of identities and properties that must // hold for various FST algorithms. It randomly generates FSTs, using // function object 'weight_generator' to select weights. 'WeightTester' // and 'UnweightedTester' are then called. template <class Arc, class WeightGenerator> class AlgoTester { public: typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; AlgoTester(WeightGenerator generator, int seed) : weight_generator_(generator), seed_(seed) { one_fst_.AddState(); one_fst_.SetStart(0); one_fst_.SetFinal(0, Weight::One()); univ_fst_.AddState(); univ_fst_.SetStart(0); univ_fst_.SetFinal(0, Weight::One()); for (int i = 0; i < kNumRandomLabels; ++i) univ_fst_.AddArc(0, Arc(i, i, Weight::One(), 0)); } void Test() { VLOG(1) << "weight type = " << Weight::Type(); for (int i = 0; i < FLAGS_repeat; ++i) { // Random transducers VectorFst<Arc> T1; VectorFst<Arc> T2; VectorFst<Arc> T3; RandFst(&T1); RandFst(&T2); RandFst(&T3); WeightedTester<Arc, WeightGenerator> weighted_tester(seed_, zero_fst_, one_fst_, univ_fst_, &weight_generator_); weighted_tester.Test(T1, T2, T3); VectorFst<Arc> A1(T1); VectorFst<Arc> A2(T2); VectorFst<Arc> A3(T3); Project(&A1, PROJECT_OUTPUT); Project(&A2, PROJECT_INPUT); Project(&A3, PROJECT_INPUT); ArcMap(&A1, rm_weight_mapper); ArcMap(&A2, rm_weight_mapper); ArcMap(&A3, rm_weight_mapper); UnweightedTester<Arc> unweighted_tester(zero_fst_, one_fst_, univ_fst_); unweighted_tester.Test(A1, A2, A3); } } private: // Generates a random FST. void RandFst(MutableFst<Arc> *fst) { // Determines direction of the arcs wrt state numbering. This way we // can force acyclicity when desired. enum ArcDirection { ANY_DIRECTION = 0, FORWARD_DIRECTION = 1, REVERSE_DIRECTION = 2, NUM_DIRECTIONS = 3 }; ArcDirection arc_direction = ANY_DIRECTION; if (rand()/(RAND_MAX + 1.0) < kAcyclicProb) arc_direction = rand() % 2 ? FORWARD_DIRECTION : REVERSE_DIRECTION; fst->DeleteStates(); StateId ns = rand() % kNumRandomStates; if (ns == 0) return; for (StateId s = 0; s < ns; ++s) fst->AddState(); StateId start = rand() % ns; fst->SetStart(start); size_t na = rand() % kNumRandomArcs; for (size_t n = 0; n < na; ++n) { StateId s = rand() % ns; Arc arc; arc.ilabel = rand() % kNumRandomLabels; arc.olabel = rand() % kNumRandomLabels; arc.weight = weight_generator_(); arc.nextstate = rand() % ns; if (arc_direction == ANY_DIRECTION || (arc_direction == FORWARD_DIRECTION && arc.ilabel > arc.olabel) || (arc_direction == REVERSE_DIRECTION && arc.ilabel < arc.olabel)) fst->AddArc(s, arc); } StateId nf = rand() % (ns + 1); for (StateId n = 0; n < nf; ++n) { StateId s = rand() % ns; Weight final = weight_generator_(); fst->SetFinal(s, final); } VLOG(1) << "Check FST for sanity (including property bits)."; CHECK(Verify(*fst)); // Get/compute all properties. uint64 props = fst->Properties(kFstProperties, true); // Select random set of properties to be unknown. uint64 mask = 0; for (int n = 0; n < 8; ++n) { mask |= rand() & 0xff; mask <<= 8; } mask &= ~kTrinaryProperties; fst->SetProperties(props & ~mask, mask); } // Generates weights used in testing. WeightGenerator weight_generator_; // Random seed int seed_; // FST with no states VectorFst<Arc> zero_fst_; // FST with one state that accepts epsilon. VectorFst<Arc> one_fst_; // FST with one state that accepts all strings. VectorFst<Arc> univ_fst_; // Mapper to remove weights from an Fst RmWeightMapper<Arc> rm_weight_mapper; // Maximum number of states in random test Fst. static const int kNumRandomStates; // Maximum number of arcs in random test Fst. static const int kNumRandomArcs; // Number of alternative random labels. static const int kNumRandomLabels; // Probability to force an acyclic Fst static const float kAcyclicProb; // Maximum random path length. static const int kRandomPathLength; // Number of random paths to explore. static const int kNumRandomPaths; DISALLOW_COPY_AND_ASSIGN(AlgoTester); }; template <class A, class G> const int AlgoTester<A, G>::kNumRandomStates = 10; template <class A, class G> const int AlgoTester<A, G>::kNumRandomArcs = 25; template <class A, class G> const int AlgoTester<A, G>::kNumRandomLabels = 5; template <class A, class G> const float AlgoTester<A, G>::kAcyclicProb = .25; template <class A, class G> const int AlgoTester<A, G>::kRandomPathLength = 25; template <class A, class G> const int AlgoTester<A, G>::kNumRandomPaths = 100; } // namespace fst #endif // FST_TEST_ALGO_TEST_H__