diff --git a/components/web_discovery/browser/BUILD.gn b/components/web_discovery/browser/BUILD.gn index 88d0f6b3d2b0..cd204702f1b1 100644 --- a/components/web_discovery/browser/BUILD.gn +++ b/components/web_discovery/browser/BUILD.gn @@ -15,6 +15,10 @@ static_library("browser") { "credential_manager.h", "credential_signer.cc", "credential_signer.h", + "double_fetcher.cc", + "double_fetcher.h", + "ecdh_aes.cc", + "ecdh_aes.h", "hash_detection.cc", "hash_detection.h", "patterns.cc", @@ -26,10 +30,16 @@ static_library("browser") { "privacy_guard.h", "regex_util.cc", "regex_util.h", + "reporter.cc", + "reporter.h", + "request_queue.cc", + "request_queue.h", "rsa.cc", "rsa.h", "server_config_loader.cc", "server_config_loader.h", + "signature_basename.cc", + "signature_basename.h", "util.cc", "util.h", "web_discovery_service.cc", @@ -63,11 +73,14 @@ source_set("unit_tests") { testonly = true sources = [ "credential_manager_unittest.cc", + "double_fetcher_unittest.cc", "hash_detection_unittest.cc", "patterns_unittest.cc", "payload_generator_unittest.cc", "privacy_guard_unittest.cc", + "reporter_unittest.cc", "server_config_loader_unittest.cc", + "signature_basename_unittest.cc", ] deps = [ ":browser", diff --git a/components/web_discovery/browser/double_fetcher.cc b/components/web_discovery/browser/double_fetcher.cc new file mode 100644 index 000000000000..fbf8697b0010 --- /dev/null +++ b/components/web_discovery/browser/double_fetcher.cc @@ -0,0 +1,132 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/web_discovery/browser/double_fetcher.h" + +#include + +#include "brave/components/web_discovery/browser/pref_names.h" +#include "brave/components/web_discovery/browser/request_queue.h" +#include "brave/components/web_discovery/browser/util.h" +#include "components/prefs/pref_service.h" +#include "services/network/public/cpp/shared_url_loader_factory.h" +#include "services/network/public/cpp/simple_url_loader.h" +#include "services/network/public/mojom/url_response_head.mojom.h" + +namespace web_discovery { + +namespace { +constexpr char kUrlKey[] = "url"; +constexpr char kAssociatedDataKey[] = "assoc_data"; + +constexpr base::TimeDelta kRequestMaxAge = base::Hours(1); +constexpr base::TimeDelta kMinRequestInterval = + base::Minutes(1) - base::Seconds(5); +constexpr base::TimeDelta kMaxRequestInterval = + base::Minutes(1) + base::Seconds(5); +constexpr size_t kMaxRetries = 3; +constexpr size_t kMaxDoubleFetchResponseSize = 2 * 1024 * 1024; + +constexpr net::NetworkTrafficAnnotationTag kFetchNetworkTrafficAnnotation = + net::DefineNetworkTrafficAnnotation("wdp_doublefetch", R"( + semantics { + sender: "Brave Web Discovery Double Fetch" + description: + "Retrieves a page of interest without cookies for + scraping and reporting via Web Discovery." + trigger: + "Requests are sent minutes after the original + page request is made by the user." + data: "Page data" + destination: WEBSITE + } + policy { + cookies_allowed: NO + setting: + "Users can opt-in or out via brave://settings/search" + })"); + +} // namespace + +DoubleFetcher::DoubleFetcher( + PrefService* profile_prefs, + network::SharedURLLoaderFactory* shared_url_loader_factory, + FetchedCallback callback) + : profile_prefs_(profile_prefs), + shared_url_loader_factory_(shared_url_loader_factory), + request_queue_(profile_prefs, + kScheduledDoubleFetches, + kRequestMaxAge, + kMinRequestInterval, + kMaxRequestInterval, + kMaxRetries, + base::BindRepeating(&DoubleFetcher::OnFetchTimer, + base::Unretained(this))), + callback_(callback) {} + +DoubleFetcher::~DoubleFetcher() = default; + +void DoubleFetcher::ScheduleDoubleFetch(const GURL& url, + base::Value associated_data) { + base::Value::Dict fetch_dict; + fetch_dict.Set(kUrlKey, url.spec()); + fetch_dict.Set(kAssociatedDataKey, std::move(associated_data)); + + request_queue_.ScheduleRequest(base::Value(std::move(fetch_dict))); +} + +void DoubleFetcher::OnFetchTimer(const base::Value& request_data) { + const auto* fetch_dict = request_data.GetIfDict(); + const auto* url_str = fetch_dict ? fetch_dict->FindString(kUrlKey) : nullptr; + if (!url_str) { + request_queue_.NotifyRequestComplete(true); + return; + } + + GURL url(*url_str); + auto resource_request = CreateResourceRequest(url); + url_loader_ = network::SimpleURLLoader::Create( + std::move(resource_request), kFetchNetworkTrafficAnnotation); + url_loader_->DownloadToString( + shared_url_loader_factory_.get(), + base::BindOnce(&DoubleFetcher::OnRequestComplete, base::Unretained(this), + url), + kMaxDoubleFetchResponseSize); +} + +void DoubleFetcher::OnRequestComplete( + GURL url, + std::optional response_body) { + auto result = ProcessCompletedRequest(&response_body); + + auto request_data = request_queue_.NotifyRequestComplete(result); + + if (request_data) { + const auto& request_dict = request_data->GetDict(); + const auto* assoc_data = request_dict.Find(kAssociatedDataKey); + if (assoc_data) { + callback_.Run(url, *assoc_data, response_body); + } + } +} + +bool DoubleFetcher::ProcessCompletedRequest( + std::optional* response_body) { + auto* response_info = url_loader_->ResponseInfo(); + if (!response_body || !response_info) { + return false; + } + auto response_code = response_info->headers->response_code(); + if (response_code < 200 || response_code >= 300) { + if (response_code >= 500) { + // Only retry failures due to server error + return false; + } + *response_body = std::nullopt; + } + return true; +} + +} // namespace web_discovery diff --git a/components/web_discovery/browser/double_fetcher.h b/components/web_discovery/browser/double_fetcher.h new file mode 100644 index 000000000000..e2b8c979a78d --- /dev/null +++ b/components/web_discovery/browser/double_fetcher.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_DOUBLE_FETCHER_H_ +#define BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_DOUBLE_FETCHER_H_ + +#include +#include +#include + +#include "base/memory/raw_ptr.h" +#include "base/values.h" +#include "brave/components/web_discovery/browser/request_queue.h" +#include "url/gurl.h" + +class PrefService; + +namespace network { +class SharedURLLoaderFactory; +class SimpleURLLoader; +} // namespace network + +namespace web_discovery { + +// Makes anonymous requests to relevant page URLs, without involvement of the +// user's session. In the case of search engine result pages, the result of the +// double fetch will scraped for search engine results for a future submission. +// Uses `RequestQueue` to persist and schedule double fetches. Requests +// will be sent on somewhat random intervals averaging to a minute. +class DoubleFetcher { + public: + using FetchedCallback = + base::RepeatingCallback response_body)>; + DoubleFetcher(PrefService* profile_prefs, + network::SharedURLLoaderFactory* shared_url_loader_factory, + FetchedCallback callback); + ~DoubleFetcher(); + + DoubleFetcher(const DoubleFetcher&) = delete; + DoubleFetcher& operator=(const DoubleFetcher&) = delete; + + // Queues a double fetch for a given URL. The associated data will be stored + // beside the queue request, and will be passed to the `FetchedCallback` + // upon completion. + void ScheduleDoubleFetch(const GURL& url, base::Value associated_data); + + private: + void OnFetchTimer(const base::Value& request_data); + void OnRequestComplete(GURL url, std::optional response_body); + bool ProcessCompletedRequest(std::optional* response_body); + + raw_ptr profile_prefs_; + raw_ptr shared_url_loader_factory_; + std::unique_ptr url_loader_; + + RequestQueue request_queue_; + + FetchedCallback callback_; +}; + +} // namespace web_discovery + +#endif // BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_DOUBLE_FETCHER_H_ diff --git a/components/web_discovery/browser/double_fetcher_unittest.cc b/components/web_discovery/browser/double_fetcher_unittest.cc new file mode 100644 index 000000000000..10af940221a1 --- /dev/null +++ b/components/web_discovery/browser/double_fetcher_unittest.cc @@ -0,0 +1,196 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/web_discovery/browser/double_fetcher.h" + +#include +#include +#include +#include + +#include "base/functional/bind.h" +#include "base/memory/scoped_refptr.h" +#include "base/test/task_environment.h" +#include "brave/components/web_discovery/browser/web_discovery_service.h" +#include "components/prefs/testing_pref_service.h" +#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" +#include "services/network/test/test_url_loader_factory.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace web_discovery { + +namespace { +constexpr char kTestUrl[] = "https://example.com/test"; +constexpr char kTestResponseText[] = "test"; +} // namespace + +class WebDiscoveryDoubleFetcherTest : public testing::Test { + public: + WebDiscoveryDoubleFetcherTest() + : task_environment_(base::test::TaskEnvironment::TimeSource::MOCK_TIME), + shared_url_loader_factory_( + base::MakeRefCounted( + &url_loader_factory_)) {} + ~WebDiscoveryDoubleFetcherTest() override = default; + + // testing::Test: + void SetUp() override { + WebDiscoveryService::RegisterProfilePrefs(profile_prefs_.registry()); + + InitDoubleFetcher(); + SetUpResponse(net::HTTP_OK); + } + + protected: + void InitDoubleFetcher() { + double_fetcher_ = std::make_unique( + &profile_prefs_, shared_url_loader_factory_.get(), + base::BindRepeating(&WebDiscoveryDoubleFetcherTest::HandleDoubleFetch, + base::Unretained(this))); + } + + void SetUpResponse(net::HttpStatusCode status) { + url_loader_factory_.ClearResponses(); + url_loader_factory_.AddResponse(kTestUrl, kTestResponseText, status); + } + + base::test::TaskEnvironment task_environment_; + struct CompletedFetch { + GURL url; + base::Value associated_data; + std::optional response_body; + }; + std::unique_ptr double_fetcher_; + std::vector completed_fetches_; + network::TestURLLoaderFactory url_loader_factory_; + + private: + void HandleDoubleFetch(const GURL& url, + const base::Value& associated_data, + std::optional response_body) { + completed_fetches_.push_back( + CompletedFetch{.url = url, + .associated_data = associated_data.Clone(), + .response_body = response_body}); + } + + TestingPrefServiceSimple profile_prefs_; + scoped_refptr shared_url_loader_factory_; +}; + +TEST_F(WebDiscoveryDoubleFetcherTest, ScheduleAndFetch) { + GURL url(kTestUrl); + double_fetcher_->ScheduleDoubleFetch(url, base::Value("foo1 data")); + double_fetcher_->ScheduleDoubleFetch(url, base::Value("foo2 data")); + + task_environment_.FastForwardBy(base::Seconds(45)); + EXPECT_TRUE(completed_fetches_.empty()); + + task_environment_.FastForwardBy(base::Seconds(30)); + ASSERT_EQ(completed_fetches_.size(), 1u); + + EXPECT_EQ(completed_fetches_[0].url, url); + EXPECT_EQ(completed_fetches_[0].associated_data, base::Value("foo1 data")); + ASSERT_TRUE(completed_fetches_[0].response_body); + EXPECT_EQ(*completed_fetches_[0].response_body, kTestResponseText); + + completed_fetches_.clear(); + + task_environment_.FastForwardBy(base::Seconds(25)); + EXPECT_TRUE(completed_fetches_.empty()); + + task_environment_.FastForwardBy(base::Seconds(45)); + ASSERT_EQ(completed_fetches_.size(), 1u); + + EXPECT_EQ(completed_fetches_[0].url, url); + EXPECT_EQ(completed_fetches_[0].associated_data, base::Value("foo2 data")); + ASSERT_TRUE(completed_fetches_[0].response_body); + EXPECT_EQ(*completed_fetches_[0].response_body, kTestResponseText); + + completed_fetches_.clear(); + + task_environment_.FastForwardBy(base::Seconds(180)); + EXPECT_TRUE(completed_fetches_.empty()); +} + +TEST_F(WebDiscoveryDoubleFetcherTest, LoadScheduleFromStorageAndFetch) { + GURL url(kTestUrl); + double_fetcher_->ScheduleDoubleFetch(url, base::Value(1)); + double_fetcher_->ScheduleDoubleFetch(url, base::Value(2)); + + EXPECT_TRUE(completed_fetches_.empty()); + + InitDoubleFetcher(); + + task_environment_.FastForwardBy(base::Seconds(240)); + EXPECT_EQ(completed_fetches_.size(), 2u); +} + +TEST_F(WebDiscoveryDoubleFetcherTest, ScheduleRetry) { + GURL url(kTestUrl); + SetUpResponse(net::HTTP_INTERNAL_SERVER_ERROR); + double_fetcher_->ScheduleDoubleFetch(url, base::Value(true)); + + task_environment_.FastForwardBy(base::Seconds(75)); + EXPECT_TRUE(completed_fetches_.empty()); + + SetUpResponse(net::HTTP_OK); + task_environment_.FastForwardBy(base::Seconds(30)); + + ASSERT_EQ(completed_fetches_.size(), 1u); + + EXPECT_EQ(completed_fetches_[0].url, url); + EXPECT_EQ(completed_fetches_[0].associated_data, base::Value(true)); + ASSERT_TRUE(completed_fetches_[0].response_body); + EXPECT_EQ(*completed_fetches_[0].response_body, kTestResponseText); + + completed_fetches_.clear(); + + task_environment_.FastForwardBy(base::Seconds(180)); + EXPECT_TRUE(completed_fetches_.empty()); +} + +TEST_F(WebDiscoveryDoubleFetcherTest, ScheduleMaxRetries) { + GURL url(kTestUrl); + SetUpResponse(net::HTTP_INTERNAL_SERVER_ERROR); + double_fetcher_->ScheduleDoubleFetch(url, base::Value(true)); + + task_environment_.FastForwardBy(base::Seconds(70)); + EXPECT_TRUE(completed_fetches_.empty()); + + task_environment_.FastForwardBy(base::Seconds(120)); + ASSERT_EQ(completed_fetches_.size(), 1u); + + EXPECT_EQ(completed_fetches_[0].url, url); + EXPECT_EQ(completed_fetches_[0].associated_data, base::Value(true)); + ASSERT_FALSE(completed_fetches_[0].response_body); + + completed_fetches_.clear(); + + SetUpResponse(net::HTTP_OK); + task_environment_.FastForwardBy(base::Minutes(10)); + EXPECT_TRUE(completed_fetches_.empty()); +} + +TEST_F(WebDiscoveryDoubleFetcherTest, ScheduleNoRetry) { + GURL url(kTestUrl); + SetUpResponse(net::HTTP_NOT_FOUND); + double_fetcher_->ScheduleDoubleFetch(url, base::Value(123)); + + task_environment_.FastForwardBy(base::Seconds(70)); + ASSERT_EQ(completed_fetches_.size(), 1u); + + EXPECT_EQ(completed_fetches_[0].url, url); + EXPECT_EQ(completed_fetches_[0].associated_data, base::Value(123)); + ASSERT_FALSE(completed_fetches_[0].response_body); + + completed_fetches_.clear(); + + SetUpResponse(net::HTTP_OK); + task_environment_.FastForwardBy(base::Minutes(10)); + EXPECT_TRUE(completed_fetches_.empty()); +} + +} // namespace web_discovery diff --git a/components/web_discovery/browser/ecdh_aes.cc b/components/web_discovery/browser/ecdh_aes.cc new file mode 100644 index 000000000000..4d0c59b38fec --- /dev/null +++ b/components/web_discovery/browser/ecdh_aes.cc @@ -0,0 +1,139 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/web_discovery/browser/ecdh_aes.h" + +#include + +#include "base/base64.h" +#include "base/logging.h" +#include "base/ranges/algorithm.h" +#include "base/strings/string_number_conversions.h" +#include "crypto/random.h" +#include "crypto/sha2.h" +#include "third_party/boringssl/src/include/openssl/aead.h" +#include "third_party/boringssl/src/include/openssl/ec.h" +#include "third_party/boringssl/src/include/openssl/ec_key.h" +#include "third_party/boringssl/src/include/openssl/ecdh.h" +#include "third_party/boringssl/src/include/openssl/nid.h" + +namespace web_discovery { + +namespace { + +constexpr size_t kAesKeySize = 16; +constexpr size_t kAesTagLength = 16; +constexpr size_t kIvSize = 12; +constexpr size_t kKeyMaterialSize = 32; +// P-256 field size * 2 + type byte +constexpr size_t kComponentOctSize = 32 * 2 + 1; +// type byte + public component + initialization vector +constexpr size_t kEncodedPubKeyAndIv = 1 + kComponentOctSize + kIvSize; +constexpr uint8_t kP256TypeByte = 0xea; + +bssl::UniquePtr CreateECKey() { + return bssl::UniquePtr( + EC_KEY_new_by_curve_name(NID_X9_62_prime256v1)); +} + +} // namespace + +AESEncryptResult::AESEncryptResult(std::vector data, + std::string encoded_public_component_and_iv) + : data(data), + encoded_public_component_and_iv(encoded_public_component_and_iv) {} + +AESEncryptResult::~AESEncryptResult() = default; +AESEncryptResult::AESEncryptResult(const AESEncryptResult&) = default; + +std::optional DeriveAESKeyAndEncrypt( + const std::string& server_pub_key_b64, + const std::vector& data) { + auto server_pub_key_data = base::Base64Decode(server_pub_key_b64); + if (!server_pub_key_data) { + VLOG(1) << "ec p-256 public component not available or incorrect size"; + return std::nullopt; + } + + auto client_private_key = CreateECKey(); + + if (!client_private_key) { + VLOG(1) << "Failed to init P-256 curve"; + return std::nullopt; + } + + bssl::UniquePtr server_public_point(EC_POINT_new(EC_group_p256())); + if (!server_public_point) { + VLOG(1) << "Failed to init EC public point"; + return std::nullopt; + } + + if (!EC_POINT_oct2point(EC_group_p256(), server_public_point.get(), + server_pub_key_data->data(), + server_pub_key_data->size(), nullptr)) { + VLOG(1) << "Failed to load server public key data into EC point"; + return std::nullopt; + } + + if (!EC_KEY_generate_key(client_private_key.get())) { + VLOG(1) << "Failed to generate client EC key"; + return std::nullopt; + } + + uint8_t shared_key_material[kKeyMaterialSize]; + if (!ECDH_compute_key(shared_key_material, kKeyMaterialSize, + server_public_point.get(), client_private_key.get(), + nullptr)) { + VLOG(1) << "Failed to set derive key via ECDH"; + return std::nullopt; + } + + auto key_material_hash = crypto::SHA256Hash(shared_key_material); + + auto aes_key = std::vector(key_material_hash.begin(), + key_material_hash.begin() + kAesKeySize); + auto* algo = EVP_aead_aes_128_gcm(); + + bssl::ScopedEVP_AEAD_CTX ctx; + if (!EVP_AEAD_CTX_init(ctx.get(), algo, aes_key.data(), aes_key.size(), + kAesTagLength, nullptr)) { + VLOG(1) << "Failed to init AEAD context"; + return std::nullopt; + } + + size_t len; + std::array iv; + + crypto::RandBytes(iv); + + std::vector output(data.size() + EVP_AEAD_max_overhead(algo)); + if (!EVP_AEAD_CTX_seal(ctx.get(), output.data(), &len, output.size(), + iv.data(), iv.size(), data.data(), data.size(), + nullptr, 0)) { + VLOG(1) << "Failed to encrypt via AES"; + return std::nullopt; + } + + output.resize(len); + + std::array public_component_and_iv; + public_component_and_iv[0] = kP256TypeByte; + + if (!EC_POINT_point2oct( + EC_group_p256(), EC_KEY_get0_public_key(client_private_key.get()), + POINT_CONVERSION_UNCOMPRESSED, public_component_and_iv.data() + 1, + kComponentOctSize, nullptr)) { + VLOG(1) << "Failed to export EC public point/key"; + return std::nullopt; + } + + base::ranges::copy(iv.begin(), iv.end(), + public_component_and_iv.begin() + kComponentOctSize + 1); + + return std::make_optional( + output, base::Base64Encode(public_component_and_iv)); +} + +} // namespace web_discovery diff --git a/components/web_discovery/browser/ecdh_aes.h b/components/web_discovery/browser/ecdh_aes.h new file mode 100644 index 000000000000..d9ce7604888f --- /dev/null +++ b/components/web_discovery/browser/ecdh_aes.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_ECDH_AES_H_ +#define BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_ECDH_AES_H_ + +#include +#include +#include + +namespace web_discovery { + +struct AESEncryptResult { + AESEncryptResult(std::vector data, + std::string encoded_public_component_and_iv); + ~AESEncryptResult(); + + AESEncryptResult(const AESEncryptResult&); + + std::vector data; + std::string encoded_public_component_and_iv; +}; + +std::optional DeriveAESKeyAndEncrypt( + const std::string& server_pub_key_b64, + const std::vector& data); + +} // namespace web_discovery + +#endif // BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_ECDH_AES_H_ diff --git a/components/web_discovery/browser/pref_names.h b/components/web_discovery/browser/pref_names.h index b2f39a2cb7f6..5c957a40b6f4 100644 --- a/components/web_discovery/browser/pref_names.h +++ b/components/web_discovery/browser/pref_names.h @@ -23,6 +23,14 @@ inline constexpr char kCredentialRSAPublicKey[] = inline constexpr char kAnonymousCredentialsDict[] = "brave.web_discovery.anon_creds"; +inline constexpr char kScheduledDoubleFetches[] = + "brave.web_discovery.scheduled_double_fetches"; +inline constexpr char kScheduledReports[] = + "brave.web_discovery.scheduled_reports"; +inline constexpr char kUsedBasenameCounts[] = + "brave.web_discovery.used_basename_counts"; +inline constexpr char kPageCounts[] = "brave.web_discovery.page_counts"; + // Local state inline constexpr char kPatternsRetrievalTime[] = "brave.web_discovery.patterns_retrieval_time"; diff --git a/components/web_discovery/browser/reporter.cc b/components/web_discovery/browser/reporter.cc new file mode 100644 index 000000000000..31aa011397f7 --- /dev/null +++ b/components/web_discovery/browser/reporter.cc @@ -0,0 +1,284 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/web_discovery/browser/reporter.h" + +#include + +#include "base/containers/span_writer.h" +#include "base/json/json_writer.h" +#include "base/numerics/byte_conversions.h" +#include "base/rand_util.h" +#include "base/task/thread_pool.h" +#include "base/threading/thread_restrictions.h" +#include "brave/components/web_discovery/browser/pref_names.h" +#include "brave/components/web_discovery/browser/signature_basename.h" +#include "brave/components/web_discovery/browser/util.h" +#include "crypto/sha2.h" +#include "services/network/public/cpp/resource_request_body.h" +#include "services/network/public/cpp/shared_url_loader_factory.h" +#include "services/network/public/cpp/simple_url_loader.h" +#include "third_party/zlib/google/compression_utils_portable.h" +#include "third_party/zlib/zlib.h" + +namespace web_discovery { + +namespace { + +constexpr net::NetworkTrafficAnnotationTag kSubmitNetworkTrafficAnnotation = + net::DefineNetworkTrafficAnnotation("wdp_submit", R"( + semantics { + sender: "Brave Web Discovery Submission" + description: + "Sends search engine results & page interaction metrics + that are deemed private by risk assessment heuristics." + trigger: + "Requests are automatically sent every minute " + "while Brave is running, and as content is collected." + data: "Search engine results & page interaction metrics" + destination: WEBSITE + } + policy { + cookies_allowed: NO + setting: + "Users can opt-in or out via brave://settings/search" + })"); + +constexpr base::TimeDelta kRequestMaxAge = base::Hours(36); +constexpr base::TimeDelta kMinRequestInterval = + base::Minutes(1) - base::Seconds(5); +constexpr base::TimeDelta kMaxRequestInterval = + base::Minutes(1) + base::Seconds(5); +constexpr size_t kMaxRetries = 10; + +constexpr char kTypeField[] = "type"; +constexpr char kWdpType[] = "wdp"; +constexpr char kChannelField[] = "channel"; +constexpr char kBraveChannel[] = "brave"; +constexpr char kReporterVersionField[] = "ver"; +constexpr char kCurrentReporterVersion[] = "1.0"; +constexpr char kAntiDuplicatesField[] = "anti-duplicates"; +constexpr char kTimestampField[] = "ts"; +constexpr int kMaxAntiDuplicatesNonce = 10000000; +constexpr char kSenderField[] = "sender"; +constexpr char kHpnSenderValue[] = "hpnv2"; + +constexpr uint8_t kSignedMessageId = 0x03; +constexpr uint8_t kCompressedMessageId = 0x80; +// id byte + basename count + signature +constexpr size_t kSignedMessageMetadataSize = 1 + 8 + 389; +constexpr size_t kMaxCompressedMessageSize = 32767; + +constexpr char kSubmitPath[] = "/"; +constexpr char kMessageContentType[] = "application/octet-stream"; +constexpr char kKeyDateHeader[] = "Key-Date"; +constexpr char kEncryptionHeader[] = "Encryption"; + +base::Value GenerateFinalPayload(const base::Value::Dict& pre_payload) { + base::Value::Dict result = pre_payload.Clone(); + + result.Set(kTypeField, kWdpType); + result.Set(kReporterVersionField, kCurrentReporterVersion); + result.Set(kSenderField, kHpnSenderValue); + result.Set(kTimestampField, FormatServerDate(base::Time::Now())); + result.Set(kAntiDuplicatesField, base::RandInt(0, kMaxAntiDuplicatesNonce)); + result.Set(kChannelField, kBraveChannel); + + return base::Value(std::move(result)); +} + +std::optional CompressAndEncrypt( + std::vector full_signed_message, + std::string server_pub_key) { + base::AssertLongCPUWorkAllowed(); + uLongf compressed_data_size = compressBound(full_signed_message.size()); + std::vector compressed_data(compressed_data_size + 2); + if (zlib_internal::CompressHelper( + zlib_internal::ZLIB, compressed_data.data() + 2, + &compressed_data_size, full_signed_message.data(), + full_signed_message.size(), Z_DEFAULT_COMPRESSION, nullptr, + nullptr) != Z_OK) { + VLOG(1) << "Failed to compress payload"; + return std::nullopt; + } + compressed_data.resize(compressed_data_size + 2); + if (compressed_data_size > kMaxCompressedMessageSize) { + VLOG(1) << "Compressed payload exceeds limit of " + << kMaxCompressedMessageSize << " bytes"; + return std::nullopt; + } + base::ranges::copy(base::U16ToBigEndian(compressed_data_size), + compressed_data.begin()); + compressed_data[0] |= kCompressedMessageId; + return DeriveAESKeyAndEncrypt(server_pub_key, compressed_data); +} + +} // namespace + +Reporter::Reporter(PrefService* profile_prefs, + network::SharedURLLoaderFactory* shared_url_loader_factory, + CredentialSigner* credential_signer, + RegexUtil* regex_util, + const ServerConfigLoader* server_config_loader) + : profile_prefs_(profile_prefs), + shared_url_loader_factory_(shared_url_loader_factory), + credential_signer_(credential_signer), + regex_util_(regex_util), + server_config_loader_(server_config_loader), + sequenced_task_runner_(base::ThreadPool::CreateSequencedTaskRunner({})), + request_queue_(profile_prefs, + kScheduledReports, + kRequestMaxAge, + kMinRequestInterval, + kMaxRequestInterval, + kMaxRetries, + base::BindRepeating(&Reporter::PrepareRequest, + base::Unretained(this))) { + submit_url_ = GURL(GetAnonymousHPNHost() + kSubmitPath); +} + +Reporter::~Reporter() = default; + +void Reporter::ScheduleSend(base::Value::Dict payload) { + request_queue_.ScheduleRequest(base::Value(std::move(payload))); +} + +void Reporter::PrepareRequest(const base::Value& request_data) { + VLOG(1) << "Preparing request"; + if (!credential_signer_->CredentialExistsForToday()) { + // Backoff until credential is available to today + VLOG(1) << "Credential does not exist for today"; + request_queue_.NotifyRequestComplete(false); + return; + } + const auto* payload_dict = request_data.GetIfDict(); + if (!payload_dict) { + // Drop request due to bad data + VLOG(1) << "Payload is not a dictionary"; + request_queue_.NotifyRequestComplete(true); + return; + } + auto basename_result = GenerateBasename( + profile_prefs_, server_config_loader_->GetLastServerConfig(), + *regex_util_, *payload_dict); + if (!basename_result) { + // Drop request due to exceeded basename quota + VLOG(1) << "Failed to generate basename"; + request_queue_.NotifyRequestComplete(true); + return; + } + auto final_payload = GenerateFinalPayload(*payload_dict); + + std::string final_payload_json; + if (!base::JSONWriter::Write(final_payload, &final_payload_json)) { + request_queue_.NotifyRequestComplete(true); + return; + } + + auto payload_hash = crypto::SHA256HashString(final_payload_json); + credential_signer_->Sign( + std::vector(payload_hash.begin(), payload_hash.end()), + basename_result->basename, + base::BindOnce(&Reporter::OnRequestSigned, base::Unretained(this), + final_payload_json, basename_result->count_tag_hash, + basename_result->count)); +} + +void Reporter::OnRequestSigned( + std::string final_payload_json, + uint32_t count_tag_hash, + size_t basename_count, + std::optional> signature) { + if (!signature) { + request_queue_.NotifyRequestComplete(false); + return; + } + const auto& server_config = server_config_loader_->GetLastServerConfig(); + auto pub_key = + server_config.pub_keys.find(FormatServerDate(base::Time::Now())); + if (pub_key == server_config.pub_keys.end()) { + VLOG(1) << "No ECDH server public key available"; + request_queue_.NotifyRequestComplete(false); + return; + } + std::vector full_signed_message(kSignedMessageMetadataSize + + final_payload_json.size()); + base::SpanWriter message_writer(full_signed_message); + if (!message_writer.WriteU8BigEndian(kSignedMessageId) || + !message_writer.Write(base::span( + reinterpret_cast(final_payload_json.data()), + final_payload_json.size())) || + !message_writer.Write(base::DoubleToBigEndian(basename_count)) || + !message_writer.Write(*signature)) { + VLOG(1) << "Failed to pack signed message"; + request_queue_.NotifyRequestComplete(true); + return; + } + sequenced_task_runner_->PostTaskAndReplyWithResult( + FROM_HERE, + base::BindOnce(&CompressAndEncrypt, full_signed_message, pub_key->second), + base::BindOnce(&Reporter::OnRequestCompressedAndEncrypted, + + weak_ptr_factory_.GetWeakPtr(), count_tag_hash, + basename_count)); +} + +void Reporter::OnRequestCompressedAndEncrypted( + uint32_t count_tag_hash, + size_t basename_count, + std::optional result) { + if (!result) { + request_queue_.NotifyRequestComplete(true); + return; + } + auto request = CreateResourceRequest(submit_url_); + request->method = net::HttpRequestHeaders::kPostMethod; + request->headers.SetHeader(kKeyDateHeader, + FormatServerDate(base::Time::Now())); + request->headers.SetHeader(kEncryptionHeader, + result->encoded_public_component_and_iv); + request->headers.SetHeader(kVersionHeader, + base::NumberToString(kCurrentVersion)); + + VLOG(1) << "Sending message"; + url_loader_ = network::SimpleURLLoader::Create( + std::move(request), kSubmitNetworkTrafficAnnotation); + url_loader_->AttachStringForUpload( + std::string(result->data.begin(), result->data.end()), + kMessageContentType); + url_loader_->DownloadHeadersOnly( + shared_url_loader_factory_.get(), + base::BindOnce(&Reporter::OnRequestComplete, base::Unretained(this), + count_tag_hash, basename_count)); +} + +void Reporter::OnRequestComplete( + uint32_t count_tag_hash, + size_t basename_count, + scoped_refptr headers) { + auto result = ValidateResponse(headers); + VLOG(1) << "Submission result: " << result; + if (result) { + SaveBasenameCount(profile_prefs_, count_tag_hash, basename_count); + } + request_queue_.NotifyRequestComplete(result); +} + +bool Reporter::ValidateResponse( + scoped_refptr headers) { + if (!headers) { + return false; + } + auto response_code = headers->response_code(); + if (response_code < 200 || response_code >= 300) { + if (response_code >= 500) { + // Only retry failures due to server error + return false; + } + } + return true; +} + +} // namespace web_discovery diff --git a/components/web_discovery/browser/reporter.h b/components/web_discovery/browser/reporter.h new file mode 100644 index 000000000000..54c318d48fec --- /dev/null +++ b/components/web_discovery/browser/reporter.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_REPORTER_H_ +#define BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_REPORTER_H_ + +#include +#include +#include +#include + +#include "base/memory/raw_ptr.h" +#include "base/values.h" +#include "brave/components/web_discovery/browser/credential_signer.h" +#include "brave/components/web_discovery/browser/ecdh_aes.h" +#include "brave/components/web_discovery/browser/regex_util.h" +#include "brave/components/web_discovery/browser/request_queue.h" +#include "brave/components/web_discovery/browser/server_config_loader.h" +#include "net/http/http_response_headers.h" + +class PrefService; + +namespace network { +class SharedURLLoaderFactory; +class SimpleURLLoader; +} // namespace network + +namespace web_discovery { + +// Handles all functions required for reporting generated payloads: +// - zlib compression +// - ECDH key derivation + key exchange +// - AES encryption (to prevent eavesdropping by the server proxy) +// - signing the request using anonymous credentials from the +// `CredentialManager` (to prevent Sybil attacks on the server) +// - performing the request for submission +// Uses `RequestQueue` to persist and schedule submissions. Reports +// will be processed on somewhat random intervals averaging to a minute. +class Reporter { + public: + Reporter(PrefService* profile_prefs, + network::SharedURLLoaderFactory* shared_url_loader_factory, + CredentialSigner* credential_signer, + RegexUtil* regex_util, + const ServerConfigLoader* server_config_loader); + ~Reporter(); + + Reporter(const Reporter&) = delete; + Reporter& operator=(const Reporter&) = delete; + + // Schedule a generated payload for submission. + void ScheduleSend(base::Value::Dict payload); + + private: + void PrepareRequest(const base::Value& request_data); + void OnRequestSigned(std::string final_payload_json, + uint32_t count_tag_hash, + size_t basename_count, + std::optional> signature); + void OnRequestCompressedAndEncrypted(uint32_t count_tag_hash, + size_t basename_count, + std::optional result); + void OnRequestComplete(uint32_t count_tag_hash, + size_t basename_count, + scoped_refptr headers); + bool ValidateResponse(scoped_refptr headers); + + GURL submit_url_; + + raw_ptr profile_prefs_; + raw_ptr shared_url_loader_factory_; + + raw_ptr credential_signer_; + raw_ptr regex_util_; + raw_ptr server_config_loader_; + + scoped_refptr sequenced_task_runner_; + + RequestQueue request_queue_; + + std::unique_ptr url_loader_; + + base::WeakPtrFactory weak_ptr_factory_{this}; +}; + +} // namespace web_discovery + +#endif // BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_REPORTER_H_ diff --git a/components/web_discovery/browser/reporter_unittest.cc b/components/web_discovery/browser/reporter_unittest.cc new file mode 100644 index 000000000000..57a81ca0c045 --- /dev/null +++ b/components/web_discovery/browser/reporter_unittest.cc @@ -0,0 +1,259 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/web_discovery/browser/reporter.h" + +#include +#include + +#include "base/base64.h" +#include "base/test/task_environment.h" +#include "brave/components/web_discovery/browser/regex_util.h" +#include "brave/components/web_discovery/browser/server_config_loader.h" +#include "brave/components/web_discovery/browser/util.h" +#include "brave/components/web_discovery/browser/web_discovery_service.h" +#include "components/prefs/testing_pref_service.h" +#include "net/http/http_status_code.h" +#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" +#include "services/network/test/test_url_loader_factory.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace web_discovery { + +namespace { +constexpr char kTestPubKey[] = + "BECQDFoOR0DE3wLaDidGAC/2Mpgjasf9QgJDGGLTkTdll+pW2S/" + "RgX0pkFyDjQZc6efyX3RGQKJ2cq8HOB8vZOo="; +} + +class WebDiscoveryReporterTest : public testing::Test { + public: + WebDiscoveryReporterTest() + : task_environment_(base::test::TaskEnvironment::TimeSource::MOCK_TIME), + shared_url_loader_factory_( + base::MakeRefCounted( + &url_loader_factory_)) {} + ~WebDiscoveryReporterTest() override = default; + + // testing::Test: + void SetUp() override { + WebDiscoveryService::RegisterProfilePrefs(profile_prefs_.registry()); + auto server_config = std::make_unique(); + + auto action_config = std::make_unique(); + action_config->keys.push_back("q->url"); + action_config->period = 24; + action_config->limit = 3; + server_config->source_map_actions["query"] = std::move(action_config); + + for (size_t i = 0; i < 3; i++) { + base::Time date = base::Time::Now() + base::Days(i); + server_config->pub_keys[FormatServerDate(date)] = kTestPubKey; + } + + server_config_loader_ = std::make_unique( + nullptr, base::FilePath(), nullptr, base::DoNothing(), + base::DoNothing()); + server_config_loader_->SetLastServerConfigForTesting( + std::move(server_config)); + + url_loader_factory_.SetInterceptor(base::BindRepeating( + &WebDiscoveryReporterTest::HandleRequest, base::Unretained(this))); + + AddCredentialForToday(); + SetupReporter(); + } + + protected: + class TestCredentialSigner : public CredentialSigner { + public: + bool CredentialExistsForToday() override { + std::string today = FormatServerDate(base::Time::Now()); + return allowed_credentials_.contains(today); + } + + bool Sign(std::vector msg, + std::vector basename, + SignCallback callback) override { + if (CredentialExistsForToday()) { + std::vector dummy_signature( + {static_cast(sign_count_ + 1)}); + std::move(callback).Run(std::move(dummy_signature)); + sign_count_++; + } else { + std::move(callback).Run(std::nullopt); + } + return true; + } + + size_t sign_count_ = 0; + base::flat_set allowed_credentials_; + }; + + void SetupReporter() { + reporter_ = std::make_unique( + &profile_prefs_, shared_url_loader_factory_.get(), &credential_signer_, + ®ex_util_, server_config_loader_.get()); + } + + void AddCredentialForToday() { + std::string today = FormatServerDate(base::Time::Now()); + credential_signer_.allowed_credentials_.insert(today); + } + + base::Value::Dict GenerateTestPayload() { + base::Value::Dict payload; + base::Value::Dict inner_payload; + inner_payload.Set("q", "test query"); + payload.Set("payload", std::move(inner_payload)); + payload.Set("action", "query"); + return payload; + } + + base::test::TaskEnvironment task_environment_; + std::unique_ptr reporter_; + TestCredentialSigner credential_signer_; + size_t report_requests_made_ = 0; + net::HttpStatusCode submit_status_code_ = net::HTTP_OK; + + private: + void HandleRequest(const network::ResourceRequest& request) { + url_loader_factory_.ClearResponses(); + + EXPECT_EQ(request.url.spec(), GetAnonymousHPNHost() + "/"); + EXPECT_EQ(request.method, net::HttpRequestHeaders::kPostMethod); + std::string key_date, encryption, version; + request.headers.GetHeader("Key-Date", &key_date); + request.headers.GetHeader("Encryption", &encryption); + request.headers.GetHeader(kVersionHeader, &version); + EXPECT_EQ(key_date, FormatServerDate(base::Time::Now())); + auto decoded_pubkey_and_iv = base::Base64Decode(encryption); + ASSERT_TRUE(decoded_pubkey_and_iv); + EXPECT_EQ(decoded_pubkey_and_iv->size(), 78u); + EXPECT_EQ(version, base::NumberToString(kCurrentVersion)); + + std::string response; + const auto* elements = request.request_body->elements(); + ASSERT_EQ(elements->size(), 1u); + ASSERT_EQ(elements->at(0).type(), network::DataElement::Tag::kBytes); + auto body = elements->at(0).As().bytes(); + EXPECT_FALSE(body.empty()); + + url_loader_factory_.AddResponse(request.url.spec(), "", + submit_status_code_); + report_requests_made_++; + } + + std::unique_ptr server_config_loader_; + RegexUtil regex_util_; + TestingPrefServiceSimple profile_prefs_; + network::TestURLLoaderFactory url_loader_factory_; + scoped_refptr shared_url_loader_factory_; +}; + +TEST_F(WebDiscoveryReporterTest, BasicReport) { + reporter_->ScheduleSend(GenerateTestPayload()); + reporter_->ScheduleSend(GenerateTestPayload()); + EXPECT_EQ(report_requests_made_, 0u); + EXPECT_EQ(credential_signer_.sign_count_, 0u); + + task_environment_.FastForwardBy(base::Seconds(30)); + + EXPECT_EQ(report_requests_made_, 0u); + EXPECT_EQ(credential_signer_.sign_count_, 0u); + + task_environment_.FastForwardBy(base::Seconds(60)); + + EXPECT_EQ(report_requests_made_, 1u); + EXPECT_EQ(credential_signer_.sign_count_, 1u); + + task_environment_.FastForwardBy(base::Seconds(80)); + + EXPECT_EQ(report_requests_made_, 2u); + EXPECT_EQ(credential_signer_.sign_count_, 2u); + report_requests_made_ = 0; + credential_signer_.sign_count_ = 0; + + task_environment_.FastForwardBy(base::Minutes(5)); + + EXPECT_EQ(report_requests_made_, 0u); + EXPECT_EQ(credential_signer_.sign_count_, 0u); +} + +TEST_F(WebDiscoveryReporterTest, LoadReportFromStorage) { + reporter_->ScheduleSend(GenerateTestPayload()); + EXPECT_EQ(report_requests_made_, 0u); + EXPECT_EQ(credential_signer_.sign_count_, 0u); + + SetupReporter(); + + task_environment_.FastForwardBy(base::Seconds(30)); + + EXPECT_EQ(report_requests_made_, 0u); + EXPECT_EQ(credential_signer_.sign_count_, 0u); + + task_environment_.FastForwardBy(base::Seconds(60)); + + EXPECT_EQ(report_requests_made_, 1u); + EXPECT_EQ(credential_signer_.sign_count_, 1u); +} + +TEST_F(WebDiscoveryReporterTest, CredentialUnavailableRetry) { + task_environment_.FastForwardBy(base::Days(1)); + + reporter_->ScheduleSend(GenerateTestPayload()); + EXPECT_EQ(report_requests_made_, 0u); + EXPECT_EQ(credential_signer_.sign_count_, 0u); + + task_environment_.FastForwardBy(base::Seconds(150)); + EXPECT_EQ(report_requests_made_, 0u); + EXPECT_EQ(credential_signer_.sign_count_, 0u); + + AddCredentialForToday(); + task_environment_.FastForwardBy(base::Seconds(120)); + + EXPECT_EQ(report_requests_made_, 1u); + EXPECT_EQ(credential_signer_.sign_count_, 1u); + report_requests_made_ = 0; + credential_signer_.sign_count_ = 0; + + task_environment_.FastForwardBy(base::Minutes(5)); + + EXPECT_EQ(report_requests_made_, 0u); + EXPECT_EQ(credential_signer_.sign_count_, 0u); +} + +TEST_F(WebDiscoveryReporterTest, ServerUnavailableRetry) { + submit_status_code_ = net::HTTP_INTERNAL_SERVER_ERROR; + reporter_->ScheduleSend(GenerateTestPayload()); + + task_environment_.FastForwardBy(base::Seconds(80)); + EXPECT_GE(report_requests_made_, 1u); + EXPECT_GE(credential_signer_.sign_count_, 1u); + + size_t prev_report_requests_made = report_requests_made_; + size_t prev_sign_count = credential_signer_.sign_count_; + task_environment_.FastForwardBy(base::Seconds(100)); + + EXPECT_GT(report_requests_made_, prev_report_requests_made); + EXPECT_GT(credential_signer_.sign_count_, prev_sign_count); + report_requests_made_ = 0; + credential_signer_.sign_count_ = 0; + + submit_status_code_ = net::HTTP_OK; + task_environment_.FastForwardBy(base::Seconds(100)); + + EXPECT_EQ(report_requests_made_, 1u); + EXPECT_EQ(credential_signer_.sign_count_, 1u); + report_requests_made_ = 0; + credential_signer_.sign_count_ = 0; + + task_environment_.FastForwardBy(base::Minutes(5)); + + EXPECT_EQ(report_requests_made_, 0u); + EXPECT_EQ(credential_signer_.sign_count_, 0u); +} + +} // namespace web_discovery diff --git a/components/web_discovery/browser/request_queue.cc b/components/web_discovery/browser/request_queue.cc new file mode 100644 index 000000000000..f5743a3500c6 --- /dev/null +++ b/components/web_discovery/browser/request_queue.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/web_discovery/browser/request_queue.h" + +#include + +#include "base/rand_util.h" +#include "brave/components/web_discovery/browser/util.h" +#include "components/prefs/scoped_user_pref_update.h" + +namespace web_discovery { + +namespace { + +constexpr char kRequestTimeKey[] = "request_time"; +constexpr char kRetriesKey[] = "retries"; +constexpr char kDataKey[] = "data"; + +} // namespace + +RequestQueue::RequestQueue( + PrefService* profile_prefs, + const char* list_pref_name, + base::TimeDelta request_max_age, + base::TimeDelta min_request_interval, + base::TimeDelta max_request_interval, + size_t max_retries, + base::RepeatingCallback start_request_callback) + : profile_prefs_(profile_prefs), + list_pref_name_(list_pref_name), + backoff_entry_(&kBackoffPolicy), + request_max_age_(request_max_age), + min_request_interval_(min_request_interval), + max_request_interval_(max_request_interval), + max_retries_(max_retries), + start_request_callback_(start_request_callback) { + StartFetchTimer(false); +} + +RequestQueue::~RequestQueue() = default; + +void RequestQueue::ScheduleRequest(base::Value request_data) { + base::Value::Dict fetch_dict; + fetch_dict.Set(kDataKey, std::move(request_data)); + fetch_dict.Set(kRequestTimeKey, + static_cast(base::Time::Now().ToTimeT())); + + ScopedListPrefUpdate update(profile_prefs_, list_pref_name_); + update->Append(std::move(fetch_dict)); + + if (!fetch_timer_.IsRunning()) { + StartFetchTimer(false); + } +} + +std::optional RequestQueue::NotifyRequestComplete(bool success) { + backoff_entry_.InformOfRequest(success); + + ScopedListPrefUpdate update(profile_prefs_, list_pref_name_); + auto& request_dict = update->front().GetDict(); + + std::optional removed_value; + bool use_backoff_delta = false; + bool should_remove = success; + + if (!success) { + use_backoff_delta = true; + auto retries = request_dict.FindInt(kRetriesKey); + if (retries && static_cast(*retries + 1) >= max_retries_) { + should_remove = true; + } else { + request_dict.Set(kRetriesKey, retries.value_or(0) + 1); + } + } + + if (should_remove) { + auto* data = request_dict.Find(kDataKey); + removed_value = data ? data->Clone() : base::Value(); + update->erase(update->begin()); + } + + StartFetchTimer(use_backoff_delta); + return removed_value; +} + +void RequestQueue::OnFetchTimer() { + ScopedListPrefUpdate update(profile_prefs_, list_pref_name_); + for (auto it = update->begin(); it != update->end();) { + const auto* fetch_dict = it->GetIfDict(); + const auto request_time = + fetch_dict ? fetch_dict->FindDouble(kRequestTimeKey) : std::nullopt; + const auto* data = fetch_dict ? fetch_dict->Find(kDataKey) : nullptr; + if (!request_time || + (base::Time::Now() - base::Time::FromTimeT(static_cast( + *request_time))) > request_max_age_ || + !data) { + it = update->erase(it); + continue; + } + start_request_callback_.Run(*data); + return; + } +} + +void RequestQueue::StartFetchTimer(bool use_backoff_delta) { + base::TimeDelta delta; + if (use_backoff_delta) { + delta = backoff_entry_.GetTimeUntilRelease(); + } else { + delta = base::RandTimeDelta(min_request_interval_, max_request_interval_); + } + fetch_timer_.Start( + FROM_HERE, delta, + base::BindOnce(&RequestQueue::OnFetchTimer, base::Unretained(this))); +} + +} // namespace web_discovery diff --git a/components/web_discovery/browser/request_queue.h b/components/web_discovery/browser/request_queue.h new file mode 100644 index 000000000000..fd385bcee22a --- /dev/null +++ b/components/web_discovery/browser/request_queue.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_REQUEST_QUEUE_H_ +#define BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_REQUEST_QUEUE_H_ + +#include "base/functional/callback.h" +#include "base/memory/raw_ptr.h" +#include "base/timer/timer.h" +#include "base/values.h" +#include "net/base/backoff_entry.h" + +class PrefService; + +namespace web_discovery { + +// Persists and schedules requests on randomized intervals within +// an interval range. If request failures exceed the threshold defined in +// `max_retries`, the request will be dropped from the list. If a persisted +// request age exceeds `request_max_age`, the request will be dropped. +class RequestQueue { + public: + RequestQueue( + PrefService* profile_prefs, + const char* list_pref_name, + base::TimeDelta request_max_age, + base::TimeDelta min_request_interval, + base::TimeDelta max_request_interval, + size_t max_retries, + base::RepeatingCallback start_request_callback); + ~RequestQueue(); + + RequestQueue(const RequestQueue&) = delete; + RequestQueue& operator=(const RequestQueue&) = delete; + + // Persist and schedule a request. The arbitrary data will be passed + // to `start_request_callback` on the scheduled interval. + void ScheduleRequest(base::Value request_data); + // Returns data value if request is deleted from queue, due to the retry limit + // or success + std::optional NotifyRequestComplete(bool success); + + private: + void OnFetchTimer(); + void StartFetchTimer(bool use_backoff_delta); + + raw_ptr profile_prefs_; + const char* list_pref_name_; + + net::BackoffEntry backoff_entry_; + + base::TimeDelta request_max_age_; + base::TimeDelta min_request_interval_; + base::TimeDelta max_request_interval_; + size_t max_retries_; + base::RepeatingCallback start_request_callback_; + + base::OneShotTimer fetch_timer_; +}; + +} // namespace web_discovery + +#endif // BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_REQUEST_QUEUE_H_ diff --git a/components/web_discovery/browser/signature_basename.cc b/components/web_discovery/browser/signature_basename.cc new file mode 100644 index 000000000000..06cc4127ebe7 --- /dev/null +++ b/components/web_discovery/browser/signature_basename.cc @@ -0,0 +1,239 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/web_discovery/browser/signature_basename.h" + +#include +#include + +#include "base/hash/hash.h" +#include "base/json/json_writer.h" +#include "base/rand_util.h" +#include "base/strings/string_split.h" +#include "base/strings/string_util.h" +#include "brave/components/web_discovery/browser/payload_generator.h" +#include "brave/components/web_discovery/browser/pref_names.h" +#include "brave/components/web_discovery/browser/server_config_loader.h" +#include "components/prefs/scoped_user_pref_update.h" +#include "crypto/sha2.h" + +namespace web_discovery { + +namespace { + +constexpr char kUrlNormalizationFunc[] = "url"; +constexpr char kFlattenObjNormalizationFunc[] = "obj"; +constexpr size_t kMsInHour = 60 * 60 * 1000; + +constexpr char kExpiresAtKey[] = "expires_at"; +constexpr char kUsedCountsKey[] = "counts"; + +void RecurseFlattenObject(const base::Value& value, + const base::Value::List& parent_keys, + base::Value::List& output) { + if (value.is_dict()) { + const auto& dict = value.GetDict(); + base::flat_set keys; + // insert into set so we can sort keys + for (const auto [key, _] : dict) { + keys.insert(key); + } + for (const auto& key : keys) { + base::Value::List next_parent_keys = parent_keys.Clone(); + next_parent_keys.Append(key); + RecurseFlattenObject(*dict.Find(key), next_parent_keys, output); + } + } else if (value.is_list()) { + const auto& list = value.GetList(); + for (size_t i = 0; i < list.size(); i++) { + base::Value::List next_parent_keys = parent_keys.Clone(); + next_parent_keys.Append(base::NumberToString(i)); + RecurseFlattenObject(list[i], next_parent_keys, output); + } + } else { + base::Value::List flattened_value; + flattened_value.Append(parent_keys.Clone()); + flattened_value.Append(value.Clone()); + output.Append(std::move(flattened_value)); + } +} + +base::Value FlattenObject(const base::Value& obj) { + base::Value::List result; + RecurseFlattenObject(obj, base::Value::List(), result); + return base::Value(std::move(result)); +} + +base::Value CleanURL(RegexUtil& regex_util, const base::Value& url) { + if (!url.is_string()) { + return base::Value(); + } + auto url_str = base::ToLowerASCII(url.GetString()); + base::RemoveChars(url_str, " ", &url_str); + base::ReplaceSubstringsAfterOffset(&url_str, 0, "https://", ""); + base::ReplaceSubstringsAfterOffset(&url_str, 0, "http://", ""); + base::ReplaceSubstringsAfterOffset(&url_str, 0, "www.", ""); + + regex_util.RemovePunctuation(url_str); + return base::Value(std::move(url_str)); +} + +int GetPeriodHoursSinceEpoch(size_t period_hours) { + auto hours_since_epoch = + base::Time::Now().InMillisecondsSinceUnixEpoch() / kMsInHour; + auto epoch_period_hours = period_hours * (hours_since_epoch / period_hours); + return epoch_period_hours; +} + +std::optional GetBasenameCount(PrefService* profile_prefs, + uint32_t count_tag_hash, + const SourceMapActionConfig& action_config, + size_t period_hours) { + // clean up expired counts + ScopedDictPrefUpdate update(profile_prefs, kUsedBasenameCounts); + base::Time now = base::Time::Now(); + for (auto it = update->begin(); it != update->end();) { + const auto* value_dict = it->second.GetIfDict(); + if (!value_dict) { + it = update->erase(it); + continue; + } + const auto expire_time = value_dict->FindDouble(kExpiresAtKey); + if (!expire_time || + now >= base::Time::FromTimeT(static_cast(*expire_time))) { + it = update->erase(it); + continue; + } + it++; + } + + auto count_tag_hash_str = base::NumberToString(count_tag_hash); + auto* count_dict = update->EnsureDict(count_tag_hash_str); + if (!count_dict->contains(kExpiresAtKey)) { + auto expire_time = + base::Time::FromMillisecondsSinceUnixEpoch(static_cast( + (period_hours + action_config.period) * kMsInHour)); + count_dict->Set(kExpiresAtKey, static_cast(expire_time.ToTimeT())); + } + + auto* used_counts_list = count_dict->EnsureList(kUsedCountsKey); + if (used_counts_list->size() >= action_config.limit) { + VLOG(1) << "No basename counts left"; + return std::nullopt; + } + + while (true) { + auto count = base::RandInt(0, action_config.limit - 1); + if (base::ranges::find(used_counts_list->begin(), used_counts_list->end(), + count) != used_counts_list->end()) { + continue; + } + return count; + } +} + +} // namespace + +BasenameResult::BasenameResult(std::vector basename, + size_t count, + uint32_t count_tag_hash) + : basename(basename), count(count), count_tag_hash(count_tag_hash) {} + +BasenameResult::~BasenameResult() = default; + +std::optional GenerateBasename( + PrefService* profile_prefs, + const ServerConfig& server_config, + RegexUtil& regex_util, + const base::Value::Dict& payload) { + const std::string* action = payload.FindString(kActionKey); + std::string json; + base::JSONWriter::Write(payload, &json); + if (!action || action->empty()) { + VLOG(1) << "No action"; + return std::nullopt; + } + const auto action_config = server_config.source_map_actions.find(*action); + if (action_config == server_config.source_map_actions.end()) { + VLOG(1) << "No action config for " << action; + return std::nullopt; + } + const auto* inner_payload = payload.FindDict(kInnerPayloadKey); + if (!inner_payload) { + VLOG(1) << "No inner payload"; + return std::nullopt; + } + base::Value::List tag_list; + tag_list.Append(*action); + tag_list.Append(static_cast(action_config->second->period)); + tag_list.Append(static_cast(action_config->second->limit)); + + base::Value::List key_values; + for (const auto& key : action_config->second->keys) { + auto parts = base::SplitStringUsingSubstr( + key, "->", base::WhitespaceHandling::TRIM_WHITESPACE, + base::SPLIT_WANT_ALL); + if (parts.empty()) { + continue; + } + base::Value value; + if (parts[0].empty()) { + value = base::Value(inner_payload->Clone()); + } else if (const auto* found_value = + inner_payload->FindByDottedPath(parts[0])) { + value = found_value->Clone(); + } + if (parts.size() > 1) { + if (parts[1] == kUrlNormalizationFunc) { + value = CleanURL(regex_util, value); + } else if (parts[1] == kFlattenObjNormalizationFunc) { + value = FlattenObject(value); + } + } + key_values.Append(std::move(value)); + } + + auto period_hours = GetPeriodHoursSinceEpoch(action_config->second->period); + tag_list.Append(std::move(key_values)); + tag_list.Append(period_hours); + + std::string interim_tag_json; + if (!base::JSONWriter::Write(base::Value(tag_list.Clone()), + &interim_tag_json)) { + return std::nullopt; + } + auto count_tag_hash = base::PersistentHash(interim_tag_json); + auto basename_count = GetBasenameCount(profile_prefs, count_tag_hash, + *action_config->second, period_hours); + if (!basename_count) { + VLOG(1) << "No basename count available"; + return std::nullopt; + } + tag_list.Append(*basename_count); + + std::string tag_json; + if (!base::JSONWriter::Write(base::Value(std::move(tag_list)), &tag_json)) { + return std::nullopt; + } + + auto tag_hash = crypto::SHA256HashString(tag_json); + std::vector tag_hash_vector(tag_hash.begin(), tag_hash.end()); + return std::make_optional(tag_hash_vector, *basename_count, + count_tag_hash); +} + +void SaveBasenameCount(PrefService* profile_prefs, + uint32_t count_tag_hash, + size_t count) { + ScopedDictPrefUpdate update(profile_prefs, kUsedBasenameCounts); + + auto count_tag_hash_str = base::NumberToString(count_tag_hash); + auto* count_dict = update->EnsureDict(count_tag_hash_str); + + auto* used_counts_list = count_dict->EnsureList(kUsedCountsKey); + used_counts_list->Append(static_cast(count)); +} + +} // namespace web_discovery diff --git a/components/web_discovery/browser/signature_basename.h b/components/web_discovery/browser/signature_basename.h new file mode 100644 index 000000000000..1a07cf027e25 --- /dev/null +++ b/components/web_discovery/browser/signature_basename.h @@ -0,0 +1,59 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#ifndef BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_SIGNATURE_BASENAME_H_ +#define BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_SIGNATURE_BASENAME_H_ + +#include +#include + +#include "base/values.h" +#include "brave/components/web_discovery/browser/regex_util.h" +#include "brave/components/web_discovery/browser/server_config_loader.h" + +class PrefService; + +namespace web_discovery { + +struct BasenameResult { + BasenameResult(std::vector basename, + size_t count, + uint32_t count_tag_hash); + ~BasenameResult(); + + BasenameResult(const BasenameResult&) = delete; + BasenameResult& operator=(const BasenameResult&) = delete; + + std::vector basename; + // The count index for a given "pre-tag". It should be under the limit for a + // given action + size_t count; + uint32_t count_tag_hash; +}; + +// Generates a basename used for the signature. The basename is a sha hash +// of the message "action" (i.e. "query"), the settings for that action +// (defined in the server's "source map"), cherry-picked attributes from the +// payload and the count index for the given message. The count will be under +// the limit defined for the action; the function will return nullopt if the +// limit for the action is exceeded. +std::optional GenerateBasename( + PrefService* profile_prefs, + const ServerConfig& server_config, + RegexUtil& regex_util, + const base::Value::Dict& payload); + +// Saves the count returned from `GenerateBasename` in the prefs. +// This ensures that the count index cannot be used for future messages +// within the defined action limit period (default is 24 hours). +// This should be called after a submission is successfully sent to +// the server. +void SaveBasenameCount(PrefService* profile_prefs, + uint32_t count_tag_hash, + size_t count); + +} // namespace web_discovery + +#endif // BRAVE_COMPONENTS_WEB_DISCOVERY_BROWSER_SIGNATURE_BASENAME_H_ diff --git a/components/web_discovery/browser/signature_basename_unittest.cc b/components/web_discovery/browser/signature_basename_unittest.cc new file mode 100644 index 000000000000..7f995a49f7fd --- /dev/null +++ b/components/web_discovery/browser/signature_basename_unittest.cc @@ -0,0 +1,249 @@ +/* Copyright (c) 2024 The Brave Authors. All rights reserved. + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#include "brave/components/web_discovery/browser/signature_basename.h" + +#include +#include +#include + +#include "base/json/json_reader.h" +#include "base/json/json_writer.h" +#include "base/test/task_environment.h" +#include "brave/components/web_discovery/browser/regex_util.h" +#include "brave/components/web_discovery/browser/server_config_loader.h" +#include "brave/components/web_discovery/browser/web_discovery_service.h" +#include "components/prefs/testing_pref_service.h" +#include "crypto/sha2.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace web_discovery { + +namespace { + +constexpr size_t kMsInHour = 60 * 60 * 1000; + +int GetPeriodHoursSinceEpoch(size_t period_hours) { + auto hours_since_epoch = + base::Time::Now().InMillisecondsSinceUnixEpoch() / kMsInHour; + auto epoch_period_hours = period_hours * (hours_since_epoch / period_hours); + return epoch_period_hours; +} + +base::TimeDelta TimeUntilNextPeriod(int period, int epoch_period_hours) { + return base::Time::FromMillisecondsSinceUnixEpoch( + (epoch_period_hours + period) * static_cast(kMsInHour)) - + base::Time::Now(); +} + +std::vector GenerateExpectedBasename(std::string action, + int period, + int limit, + base::Value::List key_list, + size_t actual_count, + int epoch_period_hours) { + base::Value::List expected_tag_list; + expected_tag_list.Append(action); + expected_tag_list.Append(period); + expected_tag_list.Append(limit); + expected_tag_list.Append(std::move(key_list)); + expected_tag_list.Append(static_cast(epoch_period_hours)); + expected_tag_list.Append(static_cast(actual_count)); + + std::string tag_json; + EXPECT_TRUE(base::JSONWriter::Write(base::Value(std::move(expected_tag_list)), + &tag_json)); + + auto tag_hash = crypto::SHA256HashString(tag_json); + return std::vector(tag_hash.begin(), tag_hash.end()); +} + +base::Value::Dict GeneratePayload(std::string action, + base::Value::Dict inner_payload) { + base::Value::Dict payload; + payload.Set("action", action); + payload.Set("payload", std::move(inner_payload)); + return payload; +} + +} // namespace + +class WebDiscoverySignatureBasenameTest : public testing::Test { + public: + WebDiscoverySignatureBasenameTest() + : task_environment_(base::test::TaskEnvironment::TimeSource::MOCK_TIME) {} + ~WebDiscoverySignatureBasenameTest() override = default; + + // testing::Test: + void SetUp() override { + WebDiscoveryService::RegisterProfilePrefs(profile_prefs_.registry()); + + auto action_config = std::make_unique(); + action_config->keys.push_back("q->url"); + action_config->period = 24; + action_config->limit = 3; + server_config_.source_map_actions["query"] = std::move(action_config); + + action_config = std::make_unique(); + action_config->keys.push_back("field->obj"); + action_config->period = 12; + action_config->limit = 1; + server_config_.source_map_actions["img"] = std::move(action_config); + + action_config = std::make_unique(); + action_config->keys.push_back("field"); + action_config->period = 12; + action_config->limit = 1; + server_config_.source_map_actions["basic"] = std::move(action_config); + } + + protected: + base::test::TaskEnvironment task_environment_; + ServerConfig server_config_; + TestingPrefServiceSimple profile_prefs_; + RegexUtil regex_util_; +}; + +TEST_F(WebDiscoverySignatureBasenameTest, BasenameForURL) { + base::flat_set used_counts; + + base::Value::List key_list; + key_list.Append("examplecomtesttestpage"); + + base::Value::Dict inner_payload; + inner_payload.Set("q", "https://www.EXample.com/test test/page"); + auto payload = GeneratePayload("query", std::move(inner_payload)); + + auto epoch_period_hours = GetPeriodHoursSinceEpoch(24); + for (size_t i = 0; i < 3; i++) { + auto actual_basename = + GenerateBasename(&profile_prefs_, server_config_, regex_util_, payload); + ASSERT_TRUE(actual_basename); + EXPECT_LT(actual_basename->count, 3u); + EXPECT_FALSE(used_counts.contains(actual_basename->count)); + + auto expected_basename = + GenerateExpectedBasename("query", 24, 3, key_list.Clone(), + actual_basename->count, epoch_period_hours); + + EXPECT_EQ(actual_basename->basename, expected_basename); + used_counts.insert(actual_basename->count); + + SaveBasenameCount(&profile_prefs_, actual_basename->count_tag_hash, + actual_basename->count); + } + + EXPECT_FALSE( + GenerateBasename(&profile_prefs_, server_config_, regex_util_, payload)); +} + +TEST_F(WebDiscoverySignatureBasenameTest, BasenameNotSaved) { + base::Value::Dict inner_payload; + inner_payload.Set("q", "https://www.example.com/test/page"); + auto payload = GeneratePayload("query", std::move(inner_payload)); + + for (size_t i = 0; i < 10; i++) { + EXPECT_TRUE(GenerateBasename(&profile_prefs_, server_config_, regex_util_, + payload)); + } +} + +TEST_F(WebDiscoverySignatureBasenameTest, BasenameLimitExpiry) { + base::Value::Dict inner_payload; + inner_payload.Set("q", "https://www.example.com/test/page"); + auto payload = GeneratePayload("query", std::move(inner_payload)); + + for (size_t i = 0; i < 3; i++) { + auto epoch_period_hours = GetPeriodHoursSinceEpoch(24); + for (size_t j = 0; j < 3; j++) { + auto basename = GenerateBasename(&profile_prefs_, server_config_, + regex_util_, payload); + ASSERT_TRUE(basename); + SaveBasenameCount(&profile_prefs_, basename->count_tag_hash, + basename->count); + } + + auto time_until_next_period = TimeUntilNextPeriod(24, epoch_period_hours); + task_environment_.AdvanceClock(time_until_next_period / 2); + EXPECT_FALSE(GenerateBasename(&profile_prefs_, server_config_, regex_util_, + payload)); + task_environment_.AdvanceClock(time_until_next_period / 2); + } +} + +TEST_F(WebDiscoverySignatureBasenameTest, BasenameForFlattenedObj) { + auto field_obj = base::JSONReader::Read(R"({ + "this": { + "is": { + "test": "object" + } + }, + "example1": [ 1, 2 ], + "example2": { "abc": "def" } + })"); + ASSERT_TRUE(field_obj); + auto expected_flattened_obj = base::JSONReader::Read(R"([ + [ + [["example1", "0"], 1], + [["example1", "1"], 2], + [["example2", "abc"], "def"], + [["this", "is", "test"], "object"] + ] + ])"); + ASSERT_TRUE(expected_flattened_obj); + + base::Value::Dict inner_payload; + inner_payload.Set("field", std::move(*field_obj)); + auto payload = GeneratePayload("img", std::move(inner_payload)); + + auto actual_basename = + GenerateBasename(&profile_prefs_, server_config_, regex_util_, payload); + ASSERT_TRUE(actual_basename); + EXPECT_EQ(actual_basename->count, 0u); + + auto epoch_period_hours = GetPeriodHoursSinceEpoch(24); + auto expected_basename = GenerateExpectedBasename( + "img", 12, 1, expected_flattened_obj->GetList().Clone(), 0u, + epoch_period_hours); + + EXPECT_EQ(actual_basename->basename, expected_basename); + + SaveBasenameCount(&profile_prefs_, actual_basename->count_tag_hash, + actual_basename->count); + + EXPECT_FALSE( + GenerateBasename(&profile_prefs_, server_config_, regex_util_, payload)); +} + +TEST_F(WebDiscoverySignatureBasenameTest, BasenameSimple) { + base::Value::List key_list; + key_list.Append("test"); + + base::Value::Dict inner_payload; + inner_payload.Set("field", "test"); + auto payload = GeneratePayload("basic", std::move(inner_payload)); + + auto actual_basename = + GenerateBasename(&profile_prefs_, server_config_, regex_util_, payload); + ASSERT_TRUE(actual_basename); + EXPECT_EQ(actual_basename->count, 0u); + + auto epoch_period_hours = GetPeriodHoursSinceEpoch(24); + auto expected_basename = GenerateExpectedBasename( + "basic", 12, 1, std::move(key_list), 0u, epoch_period_hours); + + EXPECT_EQ(actual_basename->basename, expected_basename); +} + +TEST_F(WebDiscoverySignatureBasenameTest, BasenameNoAction) { + base::Value::Dict inner_payload; + inner_payload.Set("field", "test"); + auto payload = GeneratePayload("bad_action", std::move(inner_payload)); + + ASSERT_FALSE( + GenerateBasename(&profile_prefs_, server_config_, regex_util_, payload)); +} + +} // namespace web_discovery diff --git a/components/web_discovery/browser/web_discovery_service.cc b/components/web_discovery/browser/web_discovery_service.cc index f4e03a43243a..5e6d11362a7b 100644 --- a/components/web_discovery/browser/web_discovery_service.cc +++ b/components/web_discovery/browser/web_discovery_service.cc @@ -28,6 +28,11 @@ namespace web_discovery { +namespace { +constexpr base::TimeDelta kAliveCheckInterval = base::Minutes(1); +constexpr size_t kMinPageCountForAliveMessage = 2; +} // namespace + WebDiscoveryService::WebDiscoveryService( PrefService* local_state, PrefService* profile_prefs, @@ -67,6 +72,10 @@ void WebDiscoveryService::RegisterProfilePrefs(PrefRegistrySimple* registry) { registry->RegisterDictionaryPref(kAnonymousCredentialsDict); registry->RegisterStringPref(kCredentialRSAPrivateKey, {}); registry->RegisterStringPref(kCredentialRSAPublicKey, {}); + registry->RegisterListPref(kScheduledDoubleFetches); + registry->RegisterListPref(kScheduledReports); + registry->RegisterDictionaryPref(kUsedBasenameCounts); + registry->RegisterDictionaryPref(kPageCounts); } void WebDiscoveryService::SetExtensionPrefIfNativeDisabled( @@ -97,6 +106,9 @@ void WebDiscoveryService::Start() { } void WebDiscoveryService::Stop() { + alive_message_timer_.Stop(); + reporter_ = nullptr; + double_fetcher_ = nullptr; content_scraper_ = nullptr; server_config_loader_ = nullptr; credential_manager_ = nullptr; @@ -105,6 +117,10 @@ void WebDiscoveryService::Stop() { profile_prefs_->ClearPref(kAnonymousCredentialsDict); profile_prefs_->ClearPref(kCredentialRSAPrivateKey); profile_prefs_->ClearPref(kCredentialRSAPublicKey); + profile_prefs_->ClearPref(kScheduledDoubleFetches); + profile_prefs_->ClearPref(kScheduledReports); + profile_prefs_->ClearPref(kUsedBasenameCounts); + profile_prefs_->ClearPref(kPageCounts); } void WebDiscoveryService::OnEnabledChange() { @@ -124,6 +140,35 @@ void WebDiscoveryService::OnPatternsLoaded() { content_scraper_ = std::make_unique( server_config_loader_.get(), ®ex_util_); } + if (!double_fetcher_) { + double_fetcher_ = std::make_unique( + profile_prefs_.get(), shared_url_loader_factory_.get(), + base::BindRepeating(&WebDiscoveryService::OnDoubleFetched, + base::Unretained(this))); + } + if (!reporter_) { + reporter_ = std::make_unique( + profile_prefs_.get(), shared_url_loader_factory_.get(), + credential_manager_.get(), ®ex_util_, server_config_loader_.get()); + } + MaybeSendAliveMessage(); +} + +void WebDiscoveryService::OnDoubleFetched( + const GURL& url, + const base::Value& associated_data, + std::optional response_body) { + if (!response_body) { + return; + } + auto prev_scrape_result = PageScrapeResult::FromValue(associated_data); + if (!prev_scrape_result) { + return; + } + content_scraper_->ParseAndScrapePage( + url, true, std::move(prev_scrape_result), *response_body, + base::BindOnce(&WebDiscoveryService::OnContentScraped, + base::Unretained(this), true)); } void WebDiscoveryService::DidFinishLoad( @@ -135,6 +180,14 @@ void WebDiscoveryService::DidFinishLoad( const auto* matching_url_details = server_config_loader_->GetLastPatterns().GetMatchingURLPattern(url, false); + if (!matching_url_details || !matching_url_details->is_search_engine) { + if (!current_page_count_hour_key_.empty()) { + ScopedDictPrefUpdate page_count_update(profile_prefs_, kPageCounts); + auto existing_count = + page_count_update->FindInt(current_page_count_hour_key_).value_or(0); + page_count_update->Set(current_page_count_hour_key_, existing_count + 1); + } + } if (!matching_url_details) { return; } @@ -164,9 +217,73 @@ void WebDiscoveryService::OnContentScraped( if (!original_url_details) { return; } + if (!is_strict && original_url_details->is_search_engine) { + auto* strict_url_details = + patterns.GetMatchingURLPattern(result->url, true); + if (strict_url_details) { + auto url = result->url; + if (!result->query) { + return; + } + if (IsPrivateQueryLikely(regex_util_, *result->query)) { + return; + } + url = GeneratePrivateSearchURL(url, *result->query, *strict_url_details); + VLOG(1) << "Double fetching search page: " << url; + double_fetcher_->ScheduleDoubleFetch(url, result->SerializeToValue()); + } + } auto payloads = GenerateQueryPayloads( server_config_loader_->GetLastServerConfig(), regex_util_, original_url_details, std::move(result)); + for (auto& payload : payloads) { + reporter_->ScheduleSend(std::move(payload)); + } +} + +bool WebDiscoveryService::UpdatePageCountStartTime() { + auto now = base::Time::Now(); + if (!current_page_count_start_time_.is_null() && + (now - current_page_count_start_time_) < base::Hours(1)) { + return false; + } + base::Time::Exploded exploded; + now.UTCExplode(&exploded); + exploded.millisecond = 0; + exploded.second = 0; + exploded.minute = 0; + if (!base::Time::FromUTCExploded(exploded, ¤t_page_count_start_time_)) { + return false; + } + current_page_count_hour_key_ = + base::StringPrintf("%04d%02d%02d%02d", exploded.year, exploded.month, + exploded.day_of_month, exploded.hour); + return true; +} + +void WebDiscoveryService::MaybeSendAliveMessage() { + if (!alive_message_timer_.IsRunning()) { + alive_message_timer_.Start( + FROM_HERE, kAliveCheckInterval, + base::BindRepeating(&WebDiscoveryService::MaybeSendAliveMessage, + base::Unretained(this))); + } + if (!UpdatePageCountStartTime()) { + return; + } + ScopedDictPrefUpdate update(profile_prefs_, kPageCounts); + for (auto it = update->begin(); it != update->end();) { + if (it->first == current_page_count_hour_key_) { + it++; + continue; + } + if (it->second.is_int() && static_cast(it->second.GetInt()) >= + kMinPageCountForAliveMessage) { + reporter_->ScheduleSend(GenerateAlivePayload( + server_config_loader_->GetLastServerConfig(), it->first)); + } + it = update->erase(it); + } } } // namespace web_discovery diff --git a/components/web_discovery/browser/web_discovery_service.h b/components/web_discovery/browser/web_discovery_service.h index 8c9f15cefa5a..e0e476359437 100644 --- a/components/web_discovery/browser/web_discovery_service.h +++ b/components/web_discovery/browser/web_discovery_service.h @@ -14,7 +14,9 @@ #include "base/memory/raw_ptr.h" #include "brave/components/web_discovery/browser/content_scraper.h" #include "brave/components/web_discovery/browser/credential_manager.h" +#include "brave/components/web_discovery/browser/double_fetcher.h" #include "brave/components/web_discovery/browser/regex_util.h" +#include "brave/components/web_discovery/browser/reporter.h" #include "brave/components/web_discovery/browser/server_config_loader.h" #include "components/keyed_service/core/keyed_service.h" #include "components/prefs/pref_change_registrar.h" @@ -69,6 +71,12 @@ class WebDiscoveryService : public KeyedService { void OnPatternsLoaded(); void OnContentScraped(bool is_strict, std::unique_ptr result); + void OnDoubleFetched(const GURL& url, + const base::Value& associated_data, + std::optional response_body); + + bool UpdatePageCountStartTime(); + void MaybeSendAliveMessage(); raw_ptr local_state_; raw_ptr profile_prefs_; @@ -85,6 +93,12 @@ class WebDiscoveryService : public KeyedService { std::unique_ptr server_config_loader_; std::unique_ptr credential_manager_; std::unique_ptr content_scraper_; + std::unique_ptr double_fetcher_; + std::unique_ptr reporter_; + + base::Time current_page_count_start_time_; + std::string current_page_count_hour_key_; + base::RepeatingTimer alive_message_timer_; }; } // namespace web_discovery