diff --git a/corehq/apps/geospatial/const.py b/corehq/apps/geospatial/const.py
index 56a79349f06b..087d6d5fd053 100644
--- a/corehq/apps/geospatial/const.py
+++ b/corehq/apps/geospatial/const.py
@@ -140,3 +140,5 @@
 
 DEFAULT_QUERY_LIMIT = 10_000
 DEFAULT_CHUNK_SIZE = 100
+
+MAPBOX_DIRECTIONS_MATRIX_API_COORDINATES_LIMIT = 25
diff --git a/corehq/apps/geospatial/management/commands/poc_road_network_disbursement_sequential.py b/corehq/apps/geospatial/management/commands/poc_road_network_disbursement_sequential.py
new file mode 100644
index 000000000000..0504f281ba0d
--- /dev/null
+++ b/corehq/apps/geospatial/management/commands/poc_road_network_disbursement_sequential.py
@@ -0,0 +1,148 @@
+import math
+import time
+from itertools import islice
+
+from django.core.management import BaseCommand
+
+import numpy as np
+from jsonobject.exceptions import BadValueError
+from sklearn.cluster import KMeans
+
+from corehq.apps.es import CaseSearchES
+from corehq.apps.es.case_search import case_property_missing, wrap_case_search_hit
+from corehq.apps.es.users import missing_or_empty_user_data_property
+from corehq.apps.geospatial.utils import get_geo_case_property, get_geo_user_property
+from corehq.apps.geospatial.tasks import clusters_disbursement_task
+from corehq.apps.users.models import CouchUser, CommCareUser
+from couchforms.geopoint import GeoPoint
+from dimagi.utils.couch.database import iter_docs
+
+ES_QUERY_CHUNK_SIZE = 10000
+
+
+class Command(BaseCommand):
+    help = ('(POC) Test performance of the road network disbursement algorithm, using k-means '
+            'clustering and sequential requests to stay within the Mapbox API rate limit '
+            '(60 requests/min)')
+
+    def add_arguments(self, parser):
+        parser.add_argument('domain')
+        parser.add_argument('--cluster_chunk_size', required=False, default=10000, type=int)
+        parser.add_argument('--dry_run', action='store_true', help="skips running the disbursement task")
+        parser.add_argument(
+            '--cluster_solve_percent',
+            required=False,
+            default=10,
+            type=int,
+            help="solves disbursement for the specified percentage of clusters",
+        )
+
+    def handle(self, *args, **options):
+        domain = options['domain']
+        cluster_chunk_size = options['cluster_chunk_size']
+        print(f"Cluster chunk size: {cluster_chunk_size}")
+
+        geo_case_property = get_geo_case_property(domain)
+
+        gps_users_data = self.get_users_with_gps(domain)
+        print(f"Total GPS mobile workers: {len(gps_users_data)}")
+
+        total_cases = CaseSearchES().domain(domain).NOT(case_property_missing(geo_case_property)).count()
+        print(f"Total GPS cases: {total_cases}")
+        cases_data = []
+        batch_count = math.ceil(total_cases / ES_QUERY_CHUNK_SIZE)
+        for i in range(batch_count):
+            print(f"Fetching cases: processing batch {i + 1} of {batch_count}...")
+            cases_data.extend(
+                self.get_cases_with_gps(domain, geo_case_property, offset=i * ES_QUERY_CHUNK_SIZE)
+            )
+        print("All cases fetched successfully")
+
+        start_time = time.time()
+        n_clusters = max(len(gps_users_data), len(cases_data)) // cluster_chunk_size + 1
+        print(f"Creating {n_clusters} clusters for {len(gps_users_data)} users and {len(cases_data)} cases...")
+        clusters = self.create_clusters(gps_users_data, cases_data, n_clusters)
+        print(f"Time taken for creating clusters: {time.time() - start_time}")
+
+        if not options['dry_run']:
+            cluster_solve_percent = options['cluster_solve_percent']
+            number_of_clusters_to_disburse = int(cluster_solve_percent / 100 * len(clusters))
+            clusters_to_disburse = dict(islice(clusters.items(), number_of_clusters_to_disburse))
+            clusters_disbursement_task.delay(domain, clusters_to_disburse)
+
+    def get_users_with_gps(self, domain):
+        """Mostly copied over from corehq.apps.geospatial.views.get_users_with_gps"""
+        location_prop_name = get_geo_user_property(domain)
+        from corehq.apps.es import UserES
+        query = (
+            UserES()
+            .domain(domain)
+            .mobile_users()
+            .NOT(missing_or_empty_user_data_property(location_prop_name))
+            .fields(['location_id', '_id'])
+        )
+
+        user_ids = []
+        for user_doc in query.run().hits:
+            user_ids.append(user_doc['_id'])
+
+        users = map(CouchUser.wrap_correctly, iter_docs(CommCareUser.get_db(), user_ids))
+        users_data = []
+        for user in users:
+            location = user.get_user_data(domain).get(location_prop_name, '')
+            coordinates = self._get_location_from_string(location) if location else None
+            if coordinates:
+                users_data.append(
+                    {
+                        'id': user.user_id,
+                        'lon': coordinates['lng'],
+                        'lat': coordinates['lat'],
+                    }
+                )
+        return users_data
+
+    def _get_location_from_string(self, data):
+        try:
+            geo_point = GeoPoint.from_string(data, flexible=True)
+            return {"lat": geo_point.latitude, "lng": geo_point.longitude}
+        except BadValueError:
+            return None
+
+    def get_cases_with_gps(self, domain, geo_case_property, offset):
+        query = CaseSearchES().domain(domain).size(ES_QUERY_CHUNK_SIZE).start(offset)
+        query = query.NOT(case_property_missing(geo_case_property))
+
+        cases_data = []
+        for row in query.run().raw['hits'].get('hits', []):
+            case = wrap_case_search_hit(row)
+            coordinates = self.get_case_geo_location(case, geo_case_property)
+            if coordinates:
+                cases_data.append({
+                    'id': case.case_id,
+                    'lon': coordinates['lng'],
+                    'lat': coordinates['lat'],
+                })
+        return cases_data
+
+    def get_case_geo_location(self, case, geo_case_property):
+        geo_point = case.get_case_property(geo_case_property)
+        return self._get_location_from_string(geo_point)
+ """ + n_users = len(users) + locations = users + cases + coordinates = np.array([[loc['lat'], loc['lon']] for loc in locations]) + kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(coordinates) + clusters = {i: {'users': [], 'cases': []} for i in range(n_clusters)} + for idx, label in enumerate(kmeans.labels_): + if idx < n_users: + clusters[label]['users'].append(users[idx]) + else: + clusters[label]['cases'].append(cases[idx - n_users]) + for key in clusters.keys(): + print(f"cluster index: {key}, users: {len(clusters[key]['users'])}," + f" cases: {len(clusters[key]['cases'])}") + return clusters diff --git a/corehq/apps/geospatial/routing_solvers/pulp.py b/corehq/apps/geospatial/routing_solvers/pulp.py index 0b11f49ad694..cd570eed52d4 100644 --- a/corehq/apps/geospatial/routing_solvers/pulp.py +++ b/corehq/apps/geospatial/routing_solvers/pulp.py @@ -1,11 +1,16 @@ +import time + import haversine import requests import pulp import copy from dataclasses import dataclass + +from dimagi.utils.chunked import chunked from .mapbox_utils import validate_routing_request from corehq.apps.geospatial.routing_solvers.base import DisbursementAlgorithmSolverInterface +from ..const import MAPBOX_DIRECTIONS_MATRIX_API_COORDINATES_LIMIT @dataclass @@ -184,40 +189,66 @@ class RoadNetworkSolver(RadialDistanceSolver): """ Solves user-case location assignment based on driving distance """ - def calculate_distance_matrix(self, config): - # Todo; support more than Mapbox limit by chunking - if len(self.user_locations + self.case_locations) > 25: - raise Exception("This is more than Mapbox matrix API limit (25)") - - coordinates = ';'.join([ - f'{loc["lon"]},{loc["lat"]}' - for loc in self.user_locations + self.case_locations] - ) - sources_count = len(self.user_locations) - destinations_count = len(self.case_locations) - - sources = ";".join(map(str, list(range(sources_count)))) - destinations = ";".join(map(str, list(range(sources_count, sources_count + destinations_count)))) - - url = f'https://api.mapbox.com/directions-matrix/v1/mapbox/{config.travel_mode}/{coordinates}' - - if config.max_case_travel_time: - annotations = "distance,duration" - else: - annotations = "distance" - - params = { - 'sources': sources, - 'destinations': destinations, - 'annotations': annotations, - 'access_token': config.plaintext_api_token, - } - - response = requests.get(url, params=params) - response.raise_for_status() - - return self.sanitize_response(response.json()) + # We need at least one case along with users, hence the below limit for users. 
diff --git a/corehq/apps/geospatial/routing_solvers/pulp.py b/corehq/apps/geospatial/routing_solvers/pulp.py
index 0b11f49ad694..cd570eed52d4 100644
--- a/corehq/apps/geospatial/routing_solvers/pulp.py
+++ b/corehq/apps/geospatial/routing_solvers/pulp.py
@@ -1,11 +1,16 @@
+import time
+
 import haversine
 import requests
 import pulp
 import copy
 
 from dataclasses import dataclass
+
+from dimagi.utils.chunked import chunked
 from .mapbox_utils import validate_routing_request
 from corehq.apps.geospatial.routing_solvers.base import DisbursementAlgorithmSolverInterface
+from ..const import MAPBOX_DIRECTIONS_MATRIX_API_COORDINATES_LIMIT
 
 
 @dataclass
@@ -184,40 +189,66 @@ class RoadNetworkSolver(RadialDistanceSolver):
     """
     Solves user-case location assignment based on driving distance
     """
     def calculate_distance_matrix(self, config):
-        # Todo; support more than Mapbox limit by chunking
-        if len(self.user_locations + self.case_locations) > 25:
-            raise Exception("This is more than Mapbox matrix API limit (25)")
-
-        coordinates = ';'.join([
-            f'{loc["lon"]},{loc["lat"]}'
-            for loc in self.user_locations + self.case_locations]
-        )
-        sources_count = len(self.user_locations)
-        destinations_count = len(self.case_locations)
-
-        sources = ";".join(map(str, list(range(sources_count))))
-        destinations = ";".join(map(str, list(range(sources_count, sources_count + destinations_count))))
-
-        url = f'https://api.mapbox.com/directions-matrix/v1/mapbox/{config.travel_mode}/{coordinates}'
-
-        if config.max_case_travel_time:
-            annotations = "distance,duration"
-        else:
-            annotations = "distance"
-
-        params = {
-            'sources': sources,
-            'destinations': destinations,
-            'annotations': annotations,
-            'access_token': config.plaintext_api_token,
-        }
-
-        response = requests.get(url, params=params)
-        response.raise_for_status()
-
-        return self.sanitize_response(response.json())
+        # At least one case must fit into each request alongside all the users,
+        # hence the limit on the user count below.
+        if len(self.user_locations) > (MAPBOX_DIRECTIONS_MATRIX_API_COORDINATES_LIMIT - 1):
+            raise Exception(f"Error: Users count for cluster exceeds the limit of "
+                            f"{MAPBOX_DIRECTIONS_MATRIX_API_COORDINATES_LIMIT - 1}")
+
+        cases_chunk_size = MAPBOX_DIRECTIONS_MATRIX_API_COORDINATES_LIMIT - len(self.user_locations)
+        result = {}
+        count = 1
+        for case_locations_chunk in chunked(self.case_locations, cases_chunk_size):
+            print(f"Fetching distance matrix for chunk: {count}...")
+            start_time = time.time()
+            case_locations_chunk = list(case_locations_chunk)
+            coordinates = ';'.join(
+                f'{loc["lon"]},{loc["lat"]}'
+                for loc in self.user_locations + case_locations_chunk
+            )
+            sources_count = len(self.user_locations)
+            destinations_count = len(case_locations_chunk)
+
+            sources = ";".join(map(str, range(sources_count)))
+            destinations = ";".join(map(str, range(sources_count, sources_count + destinations_count)))
+
+            url = f'https://api.mapbox.com/directions-matrix/v1/mapbox/{config.travel_mode}/{coordinates}'
+
+            if config.max_case_travel_time:
+                annotations = "distance,duration"
+            else:
+                annotations = "distance"
+
+            params = {
+                'sources': sources,
+                'destinations': destinations,
+                'annotations': annotations,
+                'access_token': config.plaintext_api_token,
+            }
+
+            response = requests.get(url, params=params)
+            response.raise_for_status()
+
+            if not result:
+                result = response.json()
+            else:
+                self.append_chunk_result(result, response.json())
+
+            count += 1
+            print("Distance matrix fetched successfully...")
+            # Stay under the Mapbox rate limit of 60 requests per minute by
+            # ensuring each loop iteration takes at least one second.
+            time_elapsed = time.time() - start_time
+            if time_elapsed < 1:
+                time.sleep(1 - time_elapsed)
+
+        return self.sanitize_response(result)
+
+    def append_chunk_result(self, result, chunk_result):
+        for idx, row in enumerate(result['distances']):
+            row.extend(chunk_result['distances'][idx])
+        if result.get('durations'):
+            for idx, row in enumerate(result['durations']):
+                row.extend(chunk_result['durations'][idx])
 
     def sanitize_response(self, response):
         distances_km = self._convert_m_to_km(response['distances'])
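To make the chunking above concrete: every Matrix API call uses the same users as sources, so each chunk response is a users x chunk-cases matrix, and `append_chunk_result` appends each chunk's columns onto the accumulated rows. A toy example (not part of the diff; the distances are invented):

```python
# Accumulated response for two users and the first chunk of two cases
result = {'distances': [[100.0, 200.0],     # user 0 -> case 0, case 1
                        [150.0, 250.0]]}    # user 1 -> case 0, case 1
# Response for the next chunk, containing one more case
chunk_result = {'distances': [[300.0],      # user 0 -> case 2
                              [350.0]]}     # user 1 -> case 2

# Same merge as append_chunk_result(): extend each user's row
for idx, row in enumerate(result['distances']):
    row.extend(chunk_result['distances'][idx])

assert result['distances'] == [[100.0, 200.0, 300.0],
                               [150.0, 250.0, 350.0]]
```

The end-of-loop sleep then guarantees each iteration takes at least one second, which keeps a single sequential worker under the 60 requests/minute limit.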
diff --git a/corehq/apps/geospatial/tasks.py b/corehq/apps/geospatial/tasks.py
index ef7e6856071f..bd396ec31854 100644
--- a/corehq/apps/geospatial/tasks.py
+++ b/corehq/apps/geospatial/tasks.py
@@ -1,5 +1,8 @@
 import math
+import time
 
+from corehq.apps.geospatial.models import GeoConfig
+from corehq.apps.geospatial.routing_solvers.pulp import RoadNetworkSolver
 from dimagi.utils.logging import notify_exception
 
 from corehq.apps.celery import task
@@ -68,3 +71,32 @@ def index_es_docs_with_location_props(domain):
         )
     else:
         celery_task_tracker.mark_completed()
+
+
+@task(queue='geospatial_queue', ignore_result=True)
+def clusters_disbursement_task(domain, clusters):
+    config = GeoConfig.objects.get(domain=domain)
+
+    print(f"Processing disbursement for {len(clusters)} clusters ...")
+    start_time = time.time()
+    assignments = []
+    for cluster_id, cluster in clusters.items():
+        users_chunk = cluster['users']
+        cases_chunk = cluster['cases']
+        if users_chunk and cases_chunk:
+            print(f"Starting disbursement for cluster: {cluster_id}, total users: {len(users_chunk)},"
+                  f" total cases: {len(cases_chunk)}")
+            try:
+                solver = RoadNetworkSolver(cluster)
+                result = solver.solve(config)
+                assignments.append(result)
+            except Exception as e:
+                print(f"Error occurred during disbursement for cluster: {cluster_id}: {str(e)}")
+                continue
+            print(f"Completed disbursement for cluster: {cluster_id}")
+        elif users_chunk:
+            print(f"No cases available for mobile workers in cluster: {cluster_id}")
+        elif cases_chunk:
+            print(f"No mobile workers available for cases in cluster: {cluster_id}")
+    print(f"Total time for solving disbursements: {time.time() - start_time}")
+    return assignments
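For reference, a sketch of how this task is meant to be invoked (the domain and IDs are placeholders; the payload shape mirrors what `create_clusters` in the POC command produces, and the domain must already have a `GeoConfig`):

```python
from corehq.apps.geospatial.tasks import clusters_disbursement_task

clusters = {
    0: {
        'users': [{'id': 'user-1', 'lat': 12.97, 'lon': 77.59}],
        'cases': [{'id': 'case-1', 'lat': 12.93, 'lon': 77.62}],
    },
    # Clusters missing either users or cases are logged and skipped
    1: {
        'users': [],
        'cases': [{'id': 'case-2', 'lat': 28.70, 'lon': 77.10}],
    },
}
clusters_disbursement_task.delay('example-domain', clusters)
```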
diff --git a/requirements/base-requirements.in b/requirements/base-requirements.in
index 6444fc9bcd20..8e8d78dddf98 100644
--- a/requirements/base-requirements.in
+++ b/requirements/base-requirements.in
@@ -66,6 +66,7 @@ kafka-python
 looseversion
 lxml
 markdown
+numpy  # Temporary requirement for POC.
 oic
 pulp  # Used in Geospatial features to solve routing problems - SolTech
 openpyxl
@@ -103,6 +104,7 @@ sh
 simpleeval @ git+https://github.com/dimagi/simpleeval.git@d85c5a9f972c0f0416a1716bb06d1a3ebc83e7ec
 simplejson
 six
+scikit-learn  # Temporary requirement for POC.
 socketpool
 sqlagg
 SQLAlchemy
diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt
index b5bf14589a69..edae32be8787 100644
--- a/requirements/dev-requirements.txt
+++ b/requirements/dev-requirements.txt
@@ -381,6 +381,8 @@ jmespath==0.10.0
     # via
     #   boto3
     #   botocore
+joblib==1.4.2
+    # via scikit-learn
 jsonfield==3.1.0
     # via
     #   -r base-requirements.in
@@ -445,6 +447,11 @@ multidict==6.0.5
     #   yarl
 myst-parser==2.0.0
     # via -r docs-requirements.in
+numpy==2.0.2
+    # via
+    #   -r base-requirements.in
+    #   scikit-learn
+    #   scipy
 oauthlib==3.1.0
     # via
     #   django-oauth-toolkit
@@ -670,6 +677,10 @@ s3transfer==0.6.0
     # via boto3
 schema==0.7.5
     # via -r base-requirements.in
+scikit-learn==1.6.0
+    # via -r base-requirements.in
+scipy==1.13.1
+    # via scikit-learn
 sentry-sdk==2.8.0
     # via -r base-requirements.in
 sh==2.0.3
@@ -767,6 +778,8 @@ testil==1.1
     # via -r test-requirements.in
 text-unidecode==1.3
     # via -r base-requirements.in
+threadpoolctl==3.5.0
+    # via scikit-learn
 tinycss2==1.2.1
     # via bleach
 tomli==2.0.1
diff --git a/requirements/docs-requirements.txt b/requirements/docs-requirements.txt
index a4a696f124f3..339f79f42c7e 100644
--- a/requirements/docs-requirements.txt
+++ b/requirements/docs-requirements.txt
@@ -329,6 +329,8 @@ jmespath==0.10.0
     # via
     #   boto3
     #   botocore
+joblib==1.4.2
+    # via scikit-learn
 jsonfield==3.1.0
     # via
     #   -r base-requirements.in
@@ -382,6 +384,11 @@ multidict==6.0.5
     #   yarl
 myst-parser==2.0.0
     # via -r docs-requirements.in
+numpy==2.0.2
+    # via
+    #   -r base-requirements.in
+    #   scikit-learn
+    #   scipy
 oauthlib==3.1.0
     # via
     #   django-oauth-toolkit
@@ -561,6 +568,10 @@ s3transfer==0.6.0
     # via boto3
 schema==0.7.5
     # via -r base-requirements.in
+scikit-learn==1.6.0
+    # via -r base-requirements.in
+scipy==1.13.1
+    # via scikit-learn
 sentry-sdk==2.8.0
     # via -r base-requirements.in
 sh==2.0.3
@@ -643,6 +654,8 @@ suds-py3==1.4.5.0
     # via -r base-requirements.in
 text-unidecode==1.3
     # via -r base-requirements.in
+threadpoolctl==3.5.0
+    # via scikit-learn
 tinycss2==1.2.1
     # via bleach
 toolz==0.12.1
diff --git a/requirements/prod-requirements.txt b/requirements/prod-requirements.txt
index 34de55728e29..b5152ee65487 100644
--- a/requirements/prod-requirements.txt
+++ b/requirements/prod-requirements.txt
@@ -330,6 +330,8 @@ jmespath==0.10.0
     # via
     #   boto3
     #   botocore
+joblib==1.4.2
+    # via scikit-learn
 jsonfield==3.1.0
     # via
     #   -r base-requirements.in
@@ -379,6 +381,11 @@ multidict==6.0.5
     # via
     #   aiohttp
     #   yarl
+numpy==2.0.2
+    # via
+    #   -r base-requirements.in
+    #   scikit-learn
+    #   scipy
 oauthlib==3.1.0
     # via
     #   django-oauth-toolkit
@@ -568,6 +575,10 @@ s3transfer==0.6.0
     # via boto3
 schema==0.7.5
     # via -r base-requirements.in
+scikit-learn==1.6.0
+    # via -r base-requirements.in
+scipy==1.13.1
+    # via scikit-learn
 sentry-sdk==2.8.0
     # via -r base-requirements.in
 setproctitle==1.2.2
@@ -628,6 +639,8 @@ suds-py3==1.4.5.0
     # via -r base-requirements.in
 text-unidecode==1.3
     # via -r base-requirements.in
+threadpoolctl==3.5.0
+    # via scikit-learn
 tinycss2==1.2.1
     # via bleach
 toolz==0.12.1
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 79533c61769c..f1fd98f5671a 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -312,6 +312,8 @@ jmespath==0.10.0
     # via
     #   boto3
     #   botocore
+joblib==1.4.2
+    # via scikit-learn
 jsonfield==3.1.0
     # via
     #   -r base-requirements.in
@@ -359,6 +361,11 @@ multidict==6.0.5
     # via
     #   aiohttp
     #   yarl
+numpy==2.0.2
+    # via
+    #   -r base-requirements.in
+    #   scikit-learn
+    #   scipy
 oauthlib==3.1.0
     # via
     #   django-oauth-toolkit
@@ -531,6 +538,10 @@ s3transfer==0.6.0
     # via boto3
 schema==0.7.5
     # via -r base-requirements.in
+scikit-learn==1.6.0
+    # via -r base-requirements.in
+scipy==1.13.1
+    # via scikit-learn
 sentry-sdk==2.8.0
     # via -r base-requirements.in
 sh==2.0.3
@@ -587,6 +598,8 @@ suds-py3==1.4.5.0
     # via -r base-requirements.in
 text-unidecode==1.3
     # via -r base-requirements.in
+threadpoolctl==3.5.0
+    # via scikit-learn
 tinycss2==1.2.1
     # via bleach
 toolz==0.12.1
diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt
index c4b3141ec372..ecd2f051043a 100644
--- a/requirements/test-requirements.txt
+++ b/requirements/test-requirements.txt
@@ -332,6 +332,8 @@ jmespath==0.10.0
     # via
     #   boto3
     #   botocore
+joblib==1.4.2
+    # via scikit-learn
 jsonfield==3.1.0
     # via
     #   -r base-requirements.in
@@ -381,6 +383,11 @@ multidict==6.0.5
     # via
     #   aiohttp
     #   yarl
+numpy==2.0.2
+    # via
+    #   -r base-requirements.in
+    #   scikit-learn
+    #   scipy
 oauthlib==3.1.0
     # via
     #   django-oauth-toolkit
@@ -581,6 +588,10 @@ s3transfer==0.6.0
     # via boto3
 schema==0.7.5
     # via -r base-requirements.in
+scikit-learn==1.6.0
+    # via -r base-requirements.in
+scipy==1.13.1
+    # via scikit-learn
 sentry-sdk==2.8.0
     # via -r base-requirements.in
 sh==2.0.3
@@ -645,6 +656,8 @@ testil==1.1
     # via -r test-requirements.in
 text-unidecode==1.3
     # via -r base-requirements.in
+threadpoolctl==3.5.0
+    # via scikit-learn
 tinycss2==1.2.1
     # via bleach
 tomli==2.0.1
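Finally, a hypothetical smoke run of the POC from a Django shell (`example-domain` is a placeholder; `--dry_run` stops before the Celery task is queued, so only the ES fetches and k-means clustering are exercised):

```python
from django.core.management import call_command

call_command(
    'poc_road_network_disbursement_sequential',
    'example-domain',
    cluster_chunk_size=20,  # small chunks produce many clusters, useful for testing
    dry_run=True,
)
```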