Diffstat (limited to 'chromium/third_party/cld_3/src/src/nnet_lang_id_test.cc')
-rw-r--r--  chromium/third_party/cld_3/src/src/nnet_lang_id_test.cc  226
1 file changed, 226 insertions(+), 0 deletions(-)
diff --git a/chromium/third_party/cld_3/src/src/nnet_lang_id_test.cc b/chromium/third_party/cld_3/src/src/nnet_lang_id_test.cc
new file mode 100644
index 00000000000..358fe1b8ff5
--- /dev/null
+++ b/chromium/third_party/cld_3/src/src/nnet_lang_id_test.cc
@@ -0,0 +1,226 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "base.h"
+#include "nnet_lang_id_test_data.h"
+#include "nnet_language_identifier.h"
+
+namespace chrome_lang_id {
+namespace nnet_lang_id_test {
+
+// Tests the model on all supported languages. Returns "true" if the test is
+// successful and "false" otherwise.
+// TODO(abakalov): Add a test for random input that should be labeled as
+// "unknown" due to low confidence.
+bool TestPredictions() {
+ std::cout << "Running " << __FUNCTION__ << std::endl;
+
+ // (gold language, sample text) pairs used for testing.
+ const std::vector<std::pair<std::string, std::string>> gold_lang_text = {
+ {"af", NNetLangIdTestData::kTestStrAF},
+ {"ar", NNetLangIdTestData::kTestStrAR},
+ {"az", NNetLangIdTestData::kTestStrAZ},
+ {"be", NNetLangIdTestData::kTestStrBE},
+ {"bg", NNetLangIdTestData::kTestStrBG},
+ {"bn", NNetLangIdTestData::kTestStrBN},
+ {"bs", NNetLangIdTestData::kTestStrBS},
+ {"ca", NNetLangIdTestData::kTestStrCA},
+ {"ceb", NNetLangIdTestData::kTestStrCEB},
+ {"cs", NNetLangIdTestData::kTestStrCS},
+ {"cy", NNetLangIdTestData::kTestStrCY},
+ {"da", NNetLangIdTestData::kTestStrDA},
+ {"de", NNetLangIdTestData::kTestStrDE},
+ {"el", NNetLangIdTestData::kTestStrEL},
+ {"en", NNetLangIdTestData::kTestStrEN},
+ {"eo", NNetLangIdTestData::kTestStrEO},
+ {"es", NNetLangIdTestData::kTestStrES},
+ {"et", NNetLangIdTestData::kTestStrET},
+ {"eu", NNetLangIdTestData::kTestStrEU},
+ {"fa", NNetLangIdTestData::kTestStrFA},
+ {"fi", NNetLangIdTestData::kTestStrFI},
+ {"fil", NNetLangIdTestData::kTestStrFIL},
+ {"fr", NNetLangIdTestData::kTestStrFR},
+ {"ga", NNetLangIdTestData::kTestStrGA},
+ {"gl", NNetLangIdTestData::kTestStrGL},
+ {"gu", NNetLangIdTestData::kTestStrGU},
+ {"ha", NNetLangIdTestData::kTestStrHA},
+ {"hi", NNetLangIdTestData::kTestStrHI},
+ {"hmn", NNetLangIdTestData::kTestStrHMN},
+ {"hr", NNetLangIdTestData::kTestStrHR},
+ {"ht", NNetLangIdTestData::kTestStrHT},
+ {"hu", NNetLangIdTestData::kTestStrHU},
+ {"hy", NNetLangIdTestData::kTestStrHY},
+ {"id", NNetLangIdTestData::kTestStrID},
+ {"ig", NNetLangIdTestData::kTestStrIG},
+ {"is", NNetLangIdTestData::kTestStrIS},
+ {"it", NNetLangIdTestData::kTestStrIT},
+ {"iw", NNetLangIdTestData::kTestStrIW},
+ {"ja", NNetLangIdTestData::kTestStrJA},
+ {"jv", NNetLangIdTestData::kTestStrJV},
+ {"ka", NNetLangIdTestData::kTestStrKA},
+ {"kk", NNetLangIdTestData::kTestStrKK},
+ {"km", NNetLangIdTestData::kTestStrKM},
+ {"kn", NNetLangIdTestData::kTestStrKN},
+ {"ko", NNetLangIdTestData::kTestStrKO},
+ {"la", NNetLangIdTestData::kTestStrLA},
+ {"lo", NNetLangIdTestData::kTestStrLO},
+ {"lt", NNetLangIdTestData::kTestStrLT},
+ {"lv", NNetLangIdTestData::kTestStrLV},
+ {"mg", NNetLangIdTestData::kTestStrMG},
+ {"mi", NNetLangIdTestData::kTestStrMI},
+ {"mk", NNetLangIdTestData::kTestStrMK},
+ {"ml", NNetLangIdTestData::kTestStrML},
+ {"mn", NNetLangIdTestData::kTestStrMN},
+ {"mr", NNetLangIdTestData::kTestStrMR},
+ {"ms", NNetLangIdTestData::kTestStrMS},
+ {"mt", NNetLangIdTestData::kTestStrMT},
+ {"my", NNetLangIdTestData::kTestStrMY},
+ {"ne", NNetLangIdTestData::kTestStrNE},
+ {"nl", NNetLangIdTestData::kTestStrNL},
+ {"no", NNetLangIdTestData::kTestStrNO},
+ {"ny", NNetLangIdTestData::kTestStrNY},
+ {"pa", NNetLangIdTestData::kTestStrPA},
+ {"pl", NNetLangIdTestData::kTestStrPL},
+ {"pt", NNetLangIdTestData::kTestStrPT},
+ {"ro", NNetLangIdTestData::kTestStrRO},
+ {"ru", NNetLangIdTestData::kTestStrRU},
+ {"si", NNetLangIdTestData::kTestStrSI},
+ {"sk", NNetLangIdTestData::kTestStrSK},
+ {"sl", NNetLangIdTestData::kTestStrSL},
+ {"so", NNetLangIdTestData::kTestStrSO},
+ {"sq", NNetLangIdTestData::kTestStrSQ},
+ {"sr", NNetLangIdTestData::kTestStrSR},
+ {"st", NNetLangIdTestData::kTestStrST},
+ {"su", NNetLangIdTestData::kTestStrSU},
+ {"sv", NNetLangIdTestData::kTestStrSV},
+ {"sw", NNetLangIdTestData::kTestStrSW},
+ {"ta", NNetLangIdTestData::kTestStrTA},
+ {"te", NNetLangIdTestData::kTestStrTE},
+ {"tg", NNetLangIdTestData::kTestStrTG},
+ {"th", NNetLangIdTestData::kTestStrTH},
+ {"tr", NNetLangIdTestData::kTestStrTR},
+ {"uk", NNetLangIdTestData::kTestStrUK},
+ {"ur", NNetLangIdTestData::kTestStrUR},
+ {"uz", NNetLangIdTestData::kTestStrUZ},
+ {"vi", NNetLangIdTestData::kTestStrVI},
+ {"yi", NNetLangIdTestData::kTestStrYI},
+ {"yo", NNetLangIdTestData::kTestStrYO},
+ {"zh", NNetLangIdTestData::kTestStrZH},
+ {"zu", NNetLangIdTestData::kTestStrZU}};
+
+ NNetLanguageIdentifier lang_id(/*min_num_bytes=*/0,
+ /*max_num_bytes=*/1000);
+
+ // Iterate over all the test instances, make predictions and check that they
+ // are correct.
+ int num_wrong = 0;
+ for (const auto &test_instance : gold_lang_text) {
+ const std::string &expected_lang = test_instance.first;
+ const std::string &text = test_instance.second;
+
+ const NNetLanguageIdentifier::Result result = lang_id.FindLanguage(text);
+ if (result.language != expected_lang) {
+ ++num_wrong;
+ std::cout << " Misclassification: " << std::endl;
+ std::cout << " Text: " << text << std::endl;
+ std::cout << " Expected language: " << expected_lang << std::endl;
+ std::cout << " Predicted language: " << result.language << std::endl;
+ }
+ }
+
+ if (num_wrong == 0) {
+ std::cout << " Success!" << std::endl;
+ return true;
+ } else {
+ std::cout << " Failure: " << num_wrong << " wrong predictions"
+ << std::endl;
+ return false;
+ }
+}
+
+// Tests the model on input containing multiple languages of different scripts.
+// Returns "true" if the test is successful and "false" otherwise.
+bool TestMultipleLanguagesInInput() {
+ std::cout << "Running " << __FUNCTION__ << std::endl;
+
+ // Text containing snippets in English and Bulgarian.
+ const std::string text =
+ "This piece of text is in English. Този текст е на Български.";
+
+ // Expected language spans in the input text, corresponding respectively to
+ // Bulgarian and English.
+ const std::string expected_bg_span = " Този текст е на Български ";
+ const std::string expected_en_span = " This piece of text is in English ";
+ const float expected_byte_sum =
+ static_cast<float>(expected_bg_span.size() + expected_en_span.size());
+
+ // Number of languages to query for and the expected byte proportions.
+ const int num_queried_langs = 3;
+  const std::unordered_map<std::string, float> expected_lang_proportions{
+ {"bg", expected_bg_span.size() / expected_byte_sum},
+ {"en", expected_en_span.size() / expected_byte_sum},
+ {NNetLanguageIdentifier::kUnknown, 0.0}};
+
+ NNetLanguageIdentifier lang_id(/*min_num_bytes=*/0,
+ /*max_num_bytes=*/1000);
+ const std::vector<NNetLanguageIdentifier::Result> results =
+ lang_id.FindTopNMostFreqLangs(text, num_queried_langs);
+
+ if (results.size() != expected_lang_proportions.size()) {
+ std::cout << " Failure" << std::endl;
+ std::cout << " Wrong number of languages: expected "
+ << expected_lang_proportions.size() << ", obtained "
+ << results.size() << std::endl;
+ return false;
+ }
+
+ // Iterate over the results and check that the correct proportions are
+ // returned for the expected languages.
+ const float epsilon = 0.00001f;
+ for (const NNetLanguageIdentifier::Result &result : results) {
+ if (expected_lang_proportions.count(result.language) == 0) {
+ std::cout << " Failure" << std::endl;
+ std::cout << " Incorrect language: " << result.language << std::endl;
+ return false;
+ }
+ if (std::abs(result.proportion -
+ expected_lang_proportions.at(result.language)) > epsilon) {
+ std::cout << " Failure" << std::endl;
+ std::cout << " Language " << result.language << ": expected proportion "
+ << expected_lang_proportions.at(result.language) << ", got "
+ << result.proportion << std::endl;
+ return false;
+ }
+ }
+ std::cout << " Success!" << std::endl;
+ return true;
+}
+
+} // namespace nnet_lang_id_test
+} // namespace chrome_lang_id
+
+// Runs tests for the language identification model.
+int main(int argc, char **argv) {
+ const bool tests_successful =
+ chrome_lang_id::nnet_lang_id_test::TestPredictions() &&
+ chrome_lang_id::nnet_lang_id_test::TestMultipleLanguagesInInput();
+ return tests_successful ? 0 : 1;
+}