diff --git a/.ds.baseline b/.ds.baseline index 1c279e018..3779d8edb 100644 --- a/.ds.baseline +++ b/.ds.baseline @@ -239,7 +239,7 @@ "filename": "tests/app/dao/test_services_dao.py", "hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8", "is_verified": false, - "line_number": 265, + "line_number": 289, "is_secret": false } ], @@ -249,7 +249,7 @@ "filename": "tests/app/dao/test_users_dao.py", "hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8", "is_verified": false, - "line_number": 52, + "line_number": 69, "is_secret": false }, { @@ -257,7 +257,7 @@ "filename": "tests/app/dao/test_users_dao.py", "hashed_secret": "f2c57870308dc87f432e5912d4de6f8e322721ba", "is_verified": false, - "line_number": 176, + "line_number": 199, "is_secret": false } ], @@ -384,5 +384,5 @@ } ] }, - "generated_at": "2024-09-27T16:42:53Z" + "generated_at": "2024-10-28T20:26:27Z" } diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 22c7f9c89..bcf0861e4 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -63,7 +63,7 @@ jobs: NOTIFY_E2E_TEST_PASSWORD: ${{ secrets.NOTIFY_E2E_TEST_PASSWORD }} - name: Check coverage threshold # TODO get this back up to 95 - run: poetry run coverage report -m --fail-under=91 + run: poetry run coverage report -m --fail-under=93 validate-new-relic-config: runs-on: ubuntu-latest diff --git a/Makefile b/Makefile index 88cf6f814..76c38d94e 100644 --- a/Makefile +++ b/Makefile @@ -84,7 +84,7 @@ test: ## Run tests and create coverage report poetry run coverage run --omit=*/migrations/*,*/tests/* -m pytest --maxfail=10 ## TODO set this back to 95 asap - poetry run coverage report -m --fail-under=91 + poetry run coverage report -m --fail-under=93 poetry run coverage html -d .coverage_cache .PHONY: py-lock diff --git a/app/aws/s3.py b/app/aws/s3.py index d4d704632..703b917f0 100644 --- a/app/aws/s3.py +++ b/app/aws/s3.py @@ -476,23 +476,7 @@ def get_personalisation_from_s3(service_id, job_id, job_row_number): set_job_cache(job_cache, f"{job_id}_personalisation", extract_personalisation(job)) - # If we can find the quick dictionary, use it - if job_cache.get(f"{job_id}_personalisation") is not None: - personalisation_to_return = job_cache.get(f"{job_id}_personalisation")[0].get( - job_row_number - ) - if personalisation_to_return: - return personalisation_to_return - else: - current_app.logger.warning( - f"Was unable to retrieve personalisation from lookup dictionary for job {job_id}" - ) - return {} - else: - current_app.logger.error( - f"Was unable to construct lookup dictionary for job {job_id}" - ) - return {} + return job_cache.get(f"{job_id}_personalisation")[0].get(job_row_number) def get_job_metadata_from_s3(service_id, job_id): diff --git a/app/commands.py b/app/commands.py index 45fce9211..5580e7632 100644 --- a/app/commands.py +++ b/app/commands.py @@ -24,12 +24,6 @@ dao_create_or_update_annual_billing_for_year, set_default_free_allowance_for_service, ) -from app.dao.fact_billing_dao import ( - delete_billing_data_for_service_for_day, - fetch_billing_data_for_day, - get_service_ids_that_need_billing_populated, - update_fact_billing, -) from app.dao.jobs_dao import dao_get_job_by_id from app.dao.organization_dao import ( dao_add_service_to_organization, @@ -63,7 +57,7 @@ TemplateHistory, User, ) -from app.utils import get_midnight_in_utc, utc_now +from app.utils import utc_now from notifications_utils.recipients import RecipientCSV from notifications_utils.template import SMSMessageTemplate from tests.app.db import ( @@ -167,6 +161,7 @@ def purge_functional_test_data(user_email_prefix): delete_model_user(usr) +# TODO maintainability what is the purpose of this command? Who would use it and why? @notify_command(name="insert-inbound-numbers") @click.option( "-f", @@ -175,7 +170,6 @@ def purge_functional_test_data(user_email_prefix): help="""Full path of the file to upload, file is a contains inbound numbers, one number per line.""", ) def insert_inbound_numbers_from_file(file_name): - # TODO maintainability what is the purpose of this command? Who would use it and why? current_app.logger.info(f"Inserting inbound numbers from {file_name}") with open(file_name) as file: @@ -195,50 +189,6 @@ def setup_commands(application): application.cli.add_command(command_group) -@notify_command(name="rebuild-ft-billing-for-day") -@click.option("-s", "--service_id", required=False, type=click.UUID) -@click.option( - "-d", - "--day", - help="The date to recalculate, as YYYY-MM-DD", - required=True, - type=click_dt(format="%Y-%m-%d"), -) -def rebuild_ft_billing_for_day(service_id, day): - # TODO maintainability what is the purpose of this command? Who would use it and why? - - """ - Rebuild the data in ft_billing for the given service_id and date - """ - - def rebuild_ft_data(process_day, service): - deleted_rows = delete_billing_data_for_service_for_day(process_day, service) - current_app.logger.info( - f"deleted {deleted_rows} existing billing rows for {service} on {process_day}" - ) - transit_data = fetch_billing_data_for_day( - process_day=process_day, service_id=service - ) - # transit_data = every row that should exist - for data in transit_data: - # upsert existing rows - update_fact_billing(data, process_day) - current_app.logger.info( - f"added/updated {len(transit_data)} billing rows for {service} on {process_day}" - ) - - if service_id: - # confirm the service exists - dao_fetch_service_by_id(service_id) - rebuild_ft_data(day, service_id) - else: - services = get_service_ids_that_need_billing_populated( - get_midnight_in_utc(day), get_midnight_in_utc(day + timedelta(days=1)) - ) - for row in services: - rebuild_ft_data(day, row.service_id) - - @notify_command(name="bulk-invite-user-to-service") @click.option( "-f", @@ -472,31 +422,6 @@ def associate_services_to_organizations(): current_app.logger.info("finished associating services to organizations") -@notify_command(name="populate-service-volume-intentions") -@click.option( - "-f", - "--file_name", - required=True, - help="Pipe delimited file containing service_id, SMS, email", -) -def populate_service_volume_intentions(file_name): - # [0] service_id - # [1] SMS:: volume intentions for service - # [2] Email:: volume intentions for service - - # TODO maintainability what is the purpose of this command? Who would use it and why? - - with open(file_name, "r") as f: - for line in itertools.islice(f, 1, None): - columns = line.split(",") - current_app.logger.info(columns) - service = dao_fetch_service_by_id(columns[0]) - service.volume_sms = columns[1] - service.volume_email = columns[2] - dao_update_service(service) - current_app.logger.info("populate-service-volume-intentions complete") - - @notify_command(name="populate-go-live") @click.option( "-f", "--file_name", required=True, help="CSV file containing live service data" diff --git a/app/dao/notifications_dao.py b/app/dao/notifications_dao.py index f7150d08f..1d07473c1 100644 --- a/app/dao/notifications_dao.py +++ b/app/dao/notifications_dao.py @@ -1,7 +1,7 @@ from datetime import timedelta from flask import current_app -from sqlalchemy import asc, desc, or_, select, text, union +from sqlalchemy import asc, delete, desc, func, or_, select, text, union, update from sqlalchemy.orm import joinedload from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.sql import functions @@ -109,11 +109,12 @@ def _update_notification_status( def update_notification_status_by_id( notification_id, status, sent_by=None, provider_response=None, carrier=None ): - notification = ( - Notification.query.with_for_update() + stmt = ( + select(Notification) + .with_for_update() .filter(Notification.id == notification_id) - .first() ) + notification = db.session.execute(stmt).scalars().first() if not notification: current_app.logger.info( @@ -156,9 +157,8 @@ def update_notification_status_by_id( @autocommit def update_notification_status_by_reference(reference, status): # this is used to update emails - notification = Notification.query.filter( - Notification.reference == reference - ).first() + stmt = select(Notification).filter(Notification.reference == reference) + notification = db.session.execute(stmt).scalars().first() if not notification: current_app.logger.error( @@ -200,19 +200,20 @@ def get_notifications_for_job( def dao_get_notification_count_for_job_id(*, job_id): - return Notification.query.filter_by(job_id=job_id).count() + stmt = select(func.count(Notification.id)).filter_by(job_id=job_id) + return db.session.execute(stmt).scalar() def dao_get_notification_count_for_service(*, service_id): - notification_count = Notification.query.filter_by(service_id=service_id).count() - return notification_count + stmt = select(func.count(Notification.id)).filter_by(service_id=service_id) + return db.session.execute(stmt).scalar() def dao_get_failed_notification_count(): - failed_count = Notification.query.filter_by( + stmt = select(func.count(Notification.id)).filter_by( status=NotificationStatus.FAILED - ).count() - return failed_count + ) + return db.session.execute(stmt).scalar() def get_notification_with_personalisation(service_id, notification_id, key_type): @@ -220,11 +221,12 @@ def get_notification_with_personalisation(service_id, notification_id, key_type) if key_type: filter_dict["key_type"] = key_type - return ( - Notification.query.filter_by(**filter_dict) + stmt = ( + select(Notification) + .filter_by(**filter_dict) .options(joinedload(Notification.template)) - .one() ) + return db.session.execute(stmt).scalars().one() def get_notification_by_id(notification_id, service_id=None, _raise=False): @@ -233,9 +235,13 @@ def get_notification_by_id(notification_id, service_id=None, _raise=False): if service_id: filters.append(Notification.service_id == service_id) - query = Notification.query.filter(*filters) + stmt = select(Notification).filter(*filters) - return query.one() if _raise else query.first() + return ( + db.session.execute(stmt).scalars().one() + if _raise + else db.session.execute(stmt).scalars().first() + ) def get_notifications_for_service( @@ -415,12 +421,13 @@ def move_notifications_to_notification_history( deleted += delete_count_per_call # Deleting test Notifications, test notifications are not persisted to NotificationHistory - Notification.query.filter( + stmt = delete(Notification).filter( Notification.notification_type == notification_type, Notification.service_id == service_id, Notification.created_at < timestamp_to_delete_backwards_from, Notification.key_type == KeyType.TEST, - ).delete(synchronize_session=False) + ) + db.session.execute(stmt) db.session.commit() return deleted @@ -442,8 +449,9 @@ def dao_timeout_notifications(cutoff_time, limit=100000): current_statuses = [NotificationStatus.SENDING, NotificationStatus.PENDING] new_status = NotificationStatus.TEMPORARY_FAILURE - notifications = ( - Notification.query.filter( + stmt = ( + select(Notification) + .filter( Notification.created_at < cutoff_time, Notification.status.in_(current_statuses), Notification.notification_type.in_( @@ -451,14 +459,15 @@ def dao_timeout_notifications(cutoff_time, limit=100000): ), ) .limit(limit) - .all() ) + notifications = db.session.execute(stmt).scalars().all() - Notification.query.filter( - Notification.id.in_([n.id for n in notifications]), - ).update( - {"status": new_status, "updated_at": updated_at}, synchronize_session=False + stmt = ( + update(Notification) + .filter(Notification.id.in_([n.id for n in notifications])) + .values({"status": new_status, "updated_at": updated_at}) ) + db.session.execute(stmt) db.session.commit() return notifications @@ -466,15 +475,23 @@ def dao_timeout_notifications(cutoff_time, limit=100000): @autocommit def dao_update_notifications_by_reference(references, update_dict): - updated_count = Notification.query.filter( - Notification.reference.in_(references) - ).update(update_dict, synchronize_session=False) + stmt = ( + update(Notification) + .filter(Notification.reference.in_(references)) + .values(update_dict) + ) + result = db.session.execute(stmt) + updated_count = result.rowcount updated_history_count = 0 if updated_count != len(references): - updated_history_count = NotificationHistory.query.filter( - NotificationHistory.reference.in_(references) - ).update(update_dict, synchronize_session=False) + stmt = ( + update(NotificationHistory) + .filter(NotificationHistory.reference.in_(references)) + .values(update_dict) + ) + result = db.session.execute(stmt) + updated_history_count = result.rowcount return updated_count, updated_history_count @@ -541,18 +558,21 @@ def dao_get_notifications_by_recipient_or_reference( def dao_get_notification_by_reference(reference): - return Notification.query.filter(Notification.reference == reference).one() + stmt = select(Notification).filter(Notification.reference == reference) + return db.session.execute(stmt).scalars().one() def dao_get_notification_history_by_reference(reference): try: # This try except is necessary because in test keys and research mode does not create notification history. # Otherwise we could just search for the NotificationHistory object - return Notification.query.filter(Notification.reference == reference).one() + stmt = select(Notification).filter(Notification.reference == reference) + return db.session.execute(stmt).scalars().one() except NoResultFound: - return NotificationHistory.query.filter( + stmt = select(NotificationHistory).filter( NotificationHistory.reference == reference - ).one() + ) + return db.session.execute(stmt).scalars().one() def dao_get_notifications_processing_time_stats(start_date, end_date): @@ -590,11 +610,12 @@ def dao_get_notifications_processing_time_stats(start_date, end_date): def dao_get_last_notification_added_for_job_id(job_id): - last_notification_added = ( - Notification.query.filter(Notification.job_id == job_id) + stmt = ( + select(Notification) + .filter(Notification.job_id == job_id) .order_by(Notification.job_row_number.desc()) - .first() ) + last_notification_added = db.session.execute(stmt).scalars().first() return last_notification_added @@ -602,11 +623,12 @@ def dao_get_last_notification_added_for_job_id(job_id): def notifications_not_yet_sent(should_be_sending_after_seconds, notification_type): older_than_date = utc_now() - timedelta(seconds=should_be_sending_after_seconds) - notifications = Notification.query.filter( + stmt = select(Notification).filter( Notification.created_at <= older_than_date, Notification.notification_type == notification_type, Notification.status == NotificationStatus.CREATED, - ).all() + ) + notifications = db.session.execute(stmt).scalars().all() return notifications diff --git a/app/dao/organization_dao.py b/app/dao/organization_dao.py index 9e44bcdd5..668ac6c25 100644 --- a/app/dao/organization_dao.py +++ b/app/dao/organization_dao.py @@ -1,3 +1,4 @@ +from sqlalchemy import delete, select, update from sqlalchemy.sql.expression import func from app import db @@ -6,55 +7,57 @@ def dao_get_organizations(): - return Organization.query.order_by( + stmt = select(Organization).order_by( Organization.active.desc(), Organization.name.asc() - ).all() + ) + return db.session.execute(stmt).scalars().all() def dao_count_organizations_with_live_services(): - return ( - db.session.query(Organization.id) + stmt = ( + select(func.count(func.distinct(Organization.id))) .join(Organization.services) .filter( Service.active.is_(True), Service.restricted.is_(False), Service.count_as_live.is_(True), ) - .distinct() - .count() ) + return db.session.execute(stmt).scalar() or 0 def dao_get_organization_services(organization_id): - return Organization.query.filter_by(id=organization_id).one().services + stmt = select(Organization).filter_by(id=organization_id) + return db.session.execute(stmt).scalars().one().services def dao_get_organization_live_services(organization_id): - return Service.query.filter_by( - organization_id=organization_id, restricted=False - ).all() + stmt = select(Service).filter_by(organization_id=organization_id, restricted=False) + return db.session.execute(stmt).scalars().all() def dao_get_organization_by_id(organization_id): - return Organization.query.filter_by(id=organization_id).one() + stmt = select(Organization).filter_by(id=organization_id) + return db.session.execute(stmt).scalars().one() def dao_get_organization_by_email_address(email_address): email_address = email_address.lower().replace(".gsi.gov.uk", ".gov.uk") - - for domain in Domain.query.order_by(func.char_length(Domain.domain).desc()).all(): + stmt = select(Domain).order_by(func.char_length(Domain.domain).desc()) + domains = db.session.execute(stmt).scalars().all() + for domain in domains: if email_address.endswith( "@{}".format(domain.domain) ) or email_address.endswith(".{}".format(domain.domain)): - return Organization.query.filter_by(id=domain.organization_id).one() + stmt = select(Organization).filter_by(id=domain.organization_id) + return db.session.execute(stmt).scalars().one() return None def dao_get_organization_by_service_id(service_id): - return ( - Organization.query.join(Organization.services).filter_by(id=service_id).first() - ) + stmt = select(Organization).join(Organization.services).filter_by(id=service_id) + return db.session.execute(stmt).scalars().first() @autocommit @@ -65,10 +68,14 @@ def dao_create_organization(organization): @autocommit def dao_update_organization(organization_id, **kwargs): domains = kwargs.pop("domains", None) - num_updated = Organization.query.filter_by(id=organization_id).update(kwargs) + stmt = ( + update(Organization).where(Organization.id == organization_id).values(**kwargs) + ) + num_updated = db.session.execute(stmt).rowcount if isinstance(domains, list): - Domain.query.filter_by(organization_id=organization_id).delete() + stmt = delete(Domain).filter_by(organization_id=organization_id) + db.session.execute(stmt) db.session.bulk_save_objects( [ Domain(domain=domain.lower(), organization_id=organization_id) @@ -76,7 +83,7 @@ def dao_update_organization(organization_id, **kwargs): ] ) - organization = Organization.query.get(organization_id) + organization = db.session.get(Organization, organization_id) if "organization_type" in kwargs: _update_organization_services( organization, "organization_type", only_where_none=False @@ -101,7 +108,8 @@ def _update_organization_services(organization, attribute, only_where_none=True) @autocommit @version_class(Service) def dao_add_service_to_organization(service, organization_id): - organization = Organization.query.filter_by(id=organization_id).one() + stmt = select(Organization).filter_by(id=organization_id) + organization = db.session.execute(stmt).scalars().one() service.organization_id = organization_id service.organization_type = organization.organization_type @@ -122,7 +130,8 @@ def dao_get_users_for_organization(organization_id): @autocommit def dao_add_user_to_organization(organization_id, user_id): organization = dao_get_organization_by_id(organization_id) - user = User.query.filter_by(id=user_id).one() + stmt = select(User).filter_by(id=user_id) + user = db.session.execute(stmt).scalars().one() user.organizations.append(organization) db.session.add(organization) return user diff --git a/app/dao/service_permissions_dao.py b/app/dao/service_permissions_dao.py index e459b6e56..0793b35b6 100644 --- a/app/dao/service_permissions_dao.py +++ b/app/dao/service_permissions_dao.py @@ -1,12 +1,14 @@ +from sqlalchemy import delete, select + from app import db from app.dao.dao_utils import autocommit from app.models import ServicePermission def dao_fetch_service_permissions(service_id): - return ServicePermission.query.filter( - ServicePermission.service_id == service_id - ).all() + + stmt = select(ServicePermission).filter(ServicePermission.service_id == service_id) + return db.session.execute(stmt).scalars().all() @autocommit @@ -16,9 +18,11 @@ def dao_add_service_permission(service_id, permission): def dao_remove_service_permission(service_id, permission): - deleted = ServicePermission.query.filter( + + stmt = delete(ServicePermission).where( ServicePermission.service_id == service_id, ServicePermission.permission == permission, - ).delete() + ) + result = db.session.execute(stmt) db.session.commit() - return deleted + return result.rowcount diff --git a/app/dao/service_sms_sender_dao.py b/app/dao/service_sms_sender_dao.py index 9224cf09d..82796b05f 100644 --- a/app/dao/service_sms_sender_dao.py +++ b/app/dao/service_sms_sender_dao.py @@ -1,4 +1,4 @@ -from sqlalchemy import desc +from sqlalchemy import desc, select from app import db from app.dao.dao_utils import autocommit @@ -17,17 +17,20 @@ def insert_service_sms_sender(service, sms_sender): def dao_get_service_sms_senders_by_id(service_id, service_sms_sender_id): - return ServiceSmsSender.query.filter_by( + stmt = select(ServiceSmsSender).filter_by( id=service_sms_sender_id, service_id=service_id, archived=False - ).one() + ) + return db.session.execute(stmt).scalars().one() def dao_get_sms_senders_by_service_id(service_id): - return ( - ServiceSmsSender.query.filter_by(service_id=service_id, archived=False) + + stmt = ( + select(ServiceSmsSender) + .filter_by(service_id=service_id, archived=False) .order_by(desc(ServiceSmsSender.is_default)) - .all() ) + return db.session.execute(stmt).scalars().all() @autocommit diff --git a/app/dao/service_user_dao.py b/app/dao/service_user_dao.py index 0b991a4fc..d60c92ba6 100644 --- a/app/dao/service_user_dao.py +++ b/app/dao/service_user_dao.py @@ -1,25 +1,23 @@ +from sqlalchemy import select + from app import db from app.dao.dao_utils import autocommit from app.models import ServiceUser, User def dao_get_service_user(user_id, service_id): - # TODO: This has been changed to account for the test case failure - # that used this method but have any service user to return. Somehow, this - # started to throw an error with one() method in sqlalchemy 2.0 unlike 1.4 - return ServiceUser.query.filter_by( - user_id=user_id, service_id=service_id - ).one_or_none() + stmt = select(ServiceUser).filter_by(user_id=user_id, service_id=service_id) + return db.session.execute(stmt).scalars().one_or_none() def dao_get_active_service_users(service_id): - query = ( - db.session.query(ServiceUser) + + stmt = ( + select(ServiceUser) .join(User, User.id == ServiceUser.user_id) .filter(User.state == "active", ServiceUser.service_id == service_id) ) - - return query.all() + return db.session.execute(stmt).scalars().all() def dao_get_service_users_by_user_id(user_id): diff --git a/app/dao/services_dao.py b/app/dao/services_dao.py index 19755edfe..1f8956865 100644 --- a/app/dao/services_dao.py +++ b/app/dao/services_dao.py @@ -2,7 +2,7 @@ from datetime import timedelta from flask import current_app -from sqlalchemy import Float, cast, select +from sqlalchemy import Float, cast, delete, select from sqlalchemy.orm import joinedload from sqlalchemy.sql.expression import and_, asc, case, func @@ -51,34 +51,42 @@ def dao_fetch_all_services(only_active=False): - query = Service.query.order_by(asc(Service.created_at)).options( - joinedload(Service.users) - ) + + stmt = select(Service) if only_active: - query = query.filter(Service.active) + stmt = stmt.where(Service.active) + + stmt = stmt.order_by(asc(Service.created_at)).options(joinedload(Service.users)) - return query.all() + result = db.session.execute(stmt) + return result.unique().scalars().all() def get_services_by_partial_name(service_name): service_name = escape_special_characters(service_name) - return Service.query.filter(Service.name.ilike("%{}%".format(service_name))).all() + stmt = select(Service).where(Service.name.ilike("%{}%".format(service_name))) + result = db.session.execute(stmt) + return result.scalars().all() def dao_count_live_services(): - return Service.query.filter_by( - active=True, - restricted=False, - count_as_live=True, - ).count() + stmt = ( + select(func.count()) + .select_from(Service) + .where( + Service.active, Service.count_as_live, Service.restricted == False # noqa + ) + ) + result = db.session.execute(stmt) + return result.scalar() # Retrieves the count def dao_fetch_live_services_data(): year_start_date, year_end_date = get_current_calendar_year() most_recent_annual_billing = ( - db.session.query( + select( AnnualBilling.service_id, func.max(AnnualBilling.financial_year_start).label("year"), ) @@ -86,13 +94,17 @@ def dao_fetch_live_services_data(): .subquery() ) - this_year_ft_billing = FactBilling.query.filter( - FactBilling.local_date >= year_start_date, - FactBilling.local_date <= year_end_date, - ).subquery() + this_year_ft_billing = ( + select(FactBilling) + .filter( + FactBilling.local_date >= year_start_date, + FactBilling.local_date <= year_end_date, + ) + .subquery() + ) - data = ( - db.session.query( + stmt = ( + select( Service.id.label("service_id"), Service.name.label("service_name"), Organization.name.label("organization_name"), @@ -156,8 +168,9 @@ def dao_fetch_live_services_data(): AnnualBilling.free_sms_fragment_limit, ) .order_by(asc(Service.go_live_at)) - .all() ) + + data = db.session.execute(stmt).all() results = [] for row in data: existing_service = next( @@ -183,48 +196,55 @@ def dao_fetch_service_by_id(service_id, only_active=False): stmt = stmt.where(Service.active) result = db.session.execute(stmt) - return result.unique().scalars().one() + return result.unique().scalars().unique().one() def dao_fetch_service_by_inbound_number(number): - inbound_number = InboundNumber.query.filter( + stmt = select(InboundNumber).where( InboundNumber.number == number, InboundNumber.active - ).first() + ) + result = db.session.execute(stmt) + inbound_number = result.scalars().first() if not inbound_number: return None - return Service.query.filter(Service.id == inbound_number.service_id).first() + stmt = select(Service).where(Service.id == inbound_number.service_id) + result = db.session.execute(stmt) + return result.scalars().first() def dao_fetch_service_by_id_with_api_keys(service_id, only_active=False): - query = Service.query.filter_by(id=service_id).options(joinedload(Service.api_keys)) - + stmt = ( + select(Service).filter_by(id=service_id).options(joinedload(Service.api_keys)) + ) if only_active: - query = query.filter(Service.active) - - return query.one() + stmt = stmt.filter(Service.active) + return db.session.execute(stmt).scalars().unique().one() def dao_fetch_all_services_by_user(user_id, only_active=False): - query = ( - Service.query.filter(Service.users.any(id=user_id)) + + stmt = ( + select(Service) + .filter(Service.users.any(id=user_id)) .order_by(asc(Service.created_at)) .options(joinedload(Service.users)) ) - if only_active: - query = query.filter(Service.active) - - return query.all() + stmt = stmt.filter(Service.active) + return db.session.execute(stmt).scalars().unique().all() def dao_fetch_all_services_created_by_user(user_id): - query = Service.query.filter_by(created_by_id=user_id).order_by( - asc(Service.created_at) + + stmt = ( + select(Service) + .filter_by(created_by_id=user_id) + .order_by(asc(Service.created_at)) ) - return query.all() + return db.session.execute(stmt).scalars().all() @autocommit @@ -234,16 +254,15 @@ def dao_fetch_all_services_created_by_user(user_id): VersionOptions(Template, history_class=TemplateHistory, must_write_history=False), ) def dao_archive_service(service_id): - # have to eager load templates and api keys so that we don't flush when we loop through them - # to ensure that db.session still contains the models when it comes to creating history objects - service = ( - Service.query.options( + stmt = ( + select(Service) + .options( joinedload(Service.templates).subqueryload(Template.template_redacted), joinedload(Service.api_keys), ) .filter(Service.id == service_id) - .one() ) + service = db.session.execute(stmt).scalars().unique().one() service.active = False service.name = get_archived_db_column_value(service.name) @@ -259,11 +278,14 @@ def dao_archive_service(service_id): def dao_fetch_service_by_id_and_user(service_id, user_id): - return ( - Service.query.filter(Service.users.any(id=user_id), Service.id == service_id) + + stmt = ( + select(Service) + .filter(Service.users.any(id=user_id), Service.id == service_id) .options(joinedload(Service.users)) - .one() ) + result = db.session.execute(stmt).scalar_one() + return result @autocommit @@ -366,39 +388,40 @@ def dao_remove_user_from_service(service, user): def delete_service_and_all_associated_db_objects(service): - def _delete_commit(query): - query.delete(synchronize_session=False) + def _delete_commit(stmt): + db.session.execute(stmt) db.session.commit() - subq = db.session.query(Template.id).filter_by(service=service).subquery() - _delete_commit( - TemplateRedacted.query.filter(TemplateRedacted.template_id.in_(subq)) - ) + subq = select(Template.id).filter_by(service=service).subquery() + + stmt = delete(TemplateRedacted).filter(TemplateRedacted.template_id.in_(subq)) + _delete_commit(stmt) + + _delete_commit(delete(ServiceSmsSender).filter_by(service=service)) + _delete_commit(delete(ServiceEmailReplyTo).filter_by(service=service)) + _delete_commit(delete(InvitedUser).filter_by(service=service)) + _delete_commit(delete(Permission).filter_by(service=service)) + _delete_commit(delete(NotificationHistory).filter_by(service=service)) + _delete_commit(delete(Notification).filter_by(service=service)) + _delete_commit(delete(Job).filter_by(service=service)) + _delete_commit(delete(Template).filter_by(service=service)) + _delete_commit(delete(TemplateHistory).filter_by(service_id=service.id)) + _delete_commit(delete(ServicePermission).filter_by(service_id=service.id)) + _delete_commit(delete(ApiKey).filter_by(service=service)) + _delete_commit(delete(ApiKey.get_history_model()).filter_by(service_id=service.id)) + _delete_commit(delete(AnnualBilling).filter_by(service_id=service.id)) - _delete_commit(ServiceSmsSender.query.filter_by(service=service)) - _delete_commit(ServiceEmailReplyTo.query.filter_by(service=service)) - _delete_commit(InvitedUser.query.filter_by(service=service)) - _delete_commit(Permission.query.filter_by(service=service)) - _delete_commit(NotificationHistory.query.filter_by(service=service)) - _delete_commit(Notification.query.filter_by(service=service)) - _delete_commit(Job.query.filter_by(service=service)) - _delete_commit(Template.query.filter_by(service=service)) - _delete_commit(TemplateHistory.query.filter_by(service_id=service.id)) - _delete_commit(ServicePermission.query.filter_by(service_id=service.id)) - _delete_commit(ApiKey.query.filter_by(service=service)) - _delete_commit(ApiKey.get_history_model().query.filter_by(service_id=service.id)) - _delete_commit(AnnualBilling.query.filter_by(service_id=service.id)) - - verify_codes = VerifyCode.query.join(User).filter( - User.id.in_([x.id for x in service.users]) + stmt = ( + select(VerifyCode).join(User).filter(User.id.in_([x.id for x in service.users])) ) + verify_codes = db.session.execute(stmt).scalars().all() list(map(db.session.delete, verify_codes)) db.session.commit() users = [x for x in service.users] for user in users: user.organizations = [] service.users.remove(user) - _delete_commit(Service.get_history_model().query.filter_by(id=service.id)) + _delete_commit(delete(Service.get_history_model()).filter_by(id=service.id)) db.session.delete(service) db.session.commit() for user in users: @@ -409,8 +432,8 @@ def _delete_commit(query): def dao_fetch_todays_stats_for_service(service_id): today = utc_now().date() start_date = get_midnight_in_utc(today) - return ( - db.session.query( + stmt = ( + select( Notification.notification_type, Notification.status, func.count(Notification.id).label("count"), @@ -424,16 +447,16 @@ def dao_fetch_todays_stats_for_service(service_id): Notification.notification_type, Notification.status, ) - .all() ) + return db.session.execute(stmt).all() def dao_fetch_stats_for_service_from_days(service_id, start_date, end_date): start_date = get_midnight_in_utc(start_date) end_date = get_midnight_in_utc(end_date + timedelta(days=1)) - return ( - db.session.query( + stmt = ( + select( NotificationAllTimeView.notification_type, NotificationAllTimeView.status, func.date_trunc("day", NotificationAllTimeView.created_at).label("day"), @@ -450,8 +473,8 @@ def dao_fetch_stats_for_service_from_days(service_id, start_date, end_date): NotificationAllTimeView.status, func.date_trunc("day", NotificationAllTimeView.created_at), ) - .all() ) + return db.session.execute(stmt).scalars().all() def dao_fetch_stats_for_service_from_days_for_user( @@ -460,13 +483,14 @@ def dao_fetch_stats_for_service_from_days_for_user( start_date = get_midnight_in_utc(start_date) end_date = get_midnight_in_utc(end_date + timedelta(days=1)) - return ( - db.session.query( + stmt = ( + select( NotificationAllTimeView.notification_type, NotificationAllTimeView.status, func.date_trunc("day", NotificationAllTimeView.created_at).label("day"), func.count(NotificationAllTimeView.id).label("count"), ) + .select_from(NotificationAllTimeView) .filter( NotificationAllTimeView.service_id == service_id, NotificationAllTimeView.key_type != KeyType.TEST, @@ -479,8 +503,8 @@ def dao_fetch_stats_for_service_from_days_for_user( NotificationAllTimeView.status, func.date_trunc("day", NotificationAllTimeView.created_at), ) - .all() ) + return db.session.execute(stmt).scalars().all() def dao_fetch_todays_stats_for_all_services( @@ -491,7 +515,7 @@ def dao_fetch_todays_stats_for_all_services( end_date = get_midnight_in_utc(today + timedelta(days=1)) subquery = ( - db.session.query( + select( Notification.notification_type, Notification.status, Notification.service_id, @@ -510,8 +534,8 @@ def dao_fetch_todays_stats_for_all_services( subquery = subquery.subquery() - query = ( - db.session.query( + stmt = ( + select( Service.id.label("service_id"), Service.name, Service.restricted, @@ -526,9 +550,9 @@ def dao_fetch_todays_stats_for_all_services( ) if only_active: - query = query.filter(Service.active) + stmt = stmt.filter(Service.active) - return query.all() + return db.session.execute(stmt).all() @autocommit @@ -537,15 +561,13 @@ def dao_fetch_todays_stats_for_all_services( VersionOptions(Service), ) def dao_suspend_service(service_id): - # have to eager load api keys so that we don't flush when we loop through them - # to ensure that db.session still contains the models when it comes to creating history objects - service = ( - Service.query.options( - joinedload(Service.api_keys), - ) + + stmt = ( + select(Service) + .options(joinedload(Service.api_keys)) .filter(Service.id == service_id) - .one() ) + service = db.session.execute(stmt).scalars().unique().one() for api_key in service.api_keys: if not api_key.expiry_date: @@ -557,19 +579,22 @@ def dao_suspend_service(service_id): @autocommit @version_class(Service) def dao_resume_service(service_id): - service = Service.query.get(service_id) + service = db.session.get(Service, service_id) + service.active = True def dao_fetch_active_users_for_service(service_id): - query = User.query.filter(User.services.any(id=service_id), User.state == "active") - return query.all() + stmt = select(User).where(User.services.any(id=service_id), User.state == "active") + result = db.session.execute(stmt) + return result.scalars().all() def dao_find_services_sending_to_tv_numbers(start_date, end_date, threshold=500): - return ( - db.session.query( + + stmt = ( + select( Notification.service_id.label("service_id"), func.count(Notification.id).label("notification_count"), ) @@ -587,13 +612,13 @@ def dao_find_services_sending_to_tv_numbers(start_date, end_date, threshold=500) Notification.service_id, ) .having(func.count(Notification.id) > threshold) - .all() ) + return db.session.execute(stmt).all() def dao_find_services_with_high_failure_rates(start_date, end_date, threshold=10000): subquery = ( - db.session.query( + select( func.count(Notification.id).label("total_count"), Notification.service_id.label("service_id"), ) @@ -614,8 +639,8 @@ def dao_find_services_with_high_failure_rates(start_date, end_date, threshold=10 subquery = subquery.subquery() - query = ( - db.session.query( + stmt = ( + select( Notification.service_id.label("service_id"), func.count(Notification.id).label("permanent_failure_count"), subquery.c.total_count.label("total_count"), @@ -643,17 +668,19 @@ def dao_find_services_with_high_failure_rates(start_date, end_date, threshold=10 ) ) - return query.all() + return db.session.execute(stmt).all() def get_live_services_with_organization(): - query = ( - db.session.query( + + stmt = ( + select( Service.id.label("service_id"), Service.name.label("service_name"), Organization.id.label("organization_id"), Organization.name.label("organization_name"), ) + .select_from(Service) .outerjoin(Service.organization) .filter( Service.count_as_live.is_(True), @@ -663,14 +690,15 @@ def get_live_services_with_organization(): .order_by(Organization.name, Service.name) ) - return query.all() + return db.session.execute(stmt).all() def fetch_notification_stats_for_service_by_month_by_user( start_date, end_date, service_id, user_id ): - return ( - db.session.query( + + stmt = ( + select( func.date_trunc("month", NotificationAllTimeView.created_at).label("month"), NotificationAllTimeView.notification_type, (NotificationAllTimeView.status).label("notification_status"), @@ -688,8 +716,8 @@ def fetch_notification_stats_for_service_by_month_by_user( NotificationAllTimeView.notification_type, NotificationAllTimeView.status, ) - .all() ) + return db.session.execute(stmt).all() def get_specific_days_stats(data, start_date, days=None, end_date=None): diff --git a/app/dao/template_folder_dao.py b/app/dao/template_folder_dao.py index ae1224179..269f407e0 100644 --- a/app/dao/template_folder_dao.py +++ b/app/dao/template_folder_dao.py @@ -1,16 +1,20 @@ +from sqlalchemy import select + from app import db from app.dao.dao_utils import autocommit from app.models import TemplateFolder def dao_get_template_folder_by_id_and_service_id(template_folder_id, service_id): - return TemplateFolder.query.filter( + stmt = select(TemplateFolder).filter( TemplateFolder.id == template_folder_id, TemplateFolder.service_id == service_id - ).one() + ) + return db.session.execute(stmt).scalars().one() def dao_get_valid_template_folders_by_id(folder_ids): - return TemplateFolder.query.filter(TemplateFolder.id.in_(folder_ids)).all() + stmt = select(TemplateFolder).filter(TemplateFolder.id.in_(folder_ids)) + return db.session.execute(stmt).scalars().all() @autocommit diff --git a/app/dao/templates_dao.py b/app/dao/templates_dao.py index 55d4363d6..7c5d7459e 100644 --- a/app/dao/templates_dao.py +++ b/app/dao/templates_dao.py @@ -1,6 +1,6 @@ import uuid -from sqlalchemy import asc, desc +from sqlalchemy import asc, desc, select from app import db from app.dao.dao_utils import VersionOptions, autocommit, version_class @@ -46,24 +46,29 @@ def dao_redact_template(template, user_id): def dao_get_template_by_id_and_service_id(template_id, service_id, version=None): if version is not None: - return TemplateHistory.query.filter_by( + stmt = select(TemplateHistory).filter_by( id=template_id, hidden=False, service_id=service_id, version=version - ).one() - return Template.query.filter_by( + ) + return db.session.execute(stmt).scalars().one() + stmt = select(Template).filter_by( id=template_id, hidden=False, service_id=service_id - ).one() + ) + return db.session.execute(stmt).scalars().one() def dao_get_template_by_id(template_id, version=None): if version is not None: - return TemplateHistory.query.filter_by(id=template_id, version=version).one() - return Template.query.filter_by(id=template_id).one() + stmt = select(TemplateHistory).filter_by(id=template_id, version=version) + return db.session.execute(stmt).scalars().one() + stmt = select(Template).filter_by(id=template_id) + return db.session.execute(stmt).scalars().one() def dao_get_all_templates_for_service(service_id, template_type=None): if template_type is not None: - return ( - Template.query.filter_by( + stmt = ( + select(Template) + .filter_by( service_id=service_id, template_type=template_type, hidden=False, @@ -73,26 +78,27 @@ def dao_get_all_templates_for_service(service_id, template_type=None): asc(Template.name), asc(Template.template_type), ) - .all() ) - - return ( - Template.query.filter_by(service_id=service_id, hidden=False, archived=False) + return db.session.execute(stmt).scalars().all() + stmt = ( + select(Template) + .filter_by(service_id=service_id, hidden=False, archived=False) .order_by( asc(Template.name), asc(Template.template_type), ) - .all() ) + return db.session.execute(stmt).scalars().all() def dao_get_template_versions(service_id, template_id): - return ( - TemplateHistory.query.filter_by( + stmt = ( + select(TemplateHistory) + .filter_by( service_id=service_id, id=template_id, hidden=False, ) .order_by(desc(TemplateHistory.version)) - .all() ) + return db.session.execute(stmt).scalars().all() diff --git a/app/dao/users_dao.py b/app/dao/users_dao.py index 897bb1b9e..690ecc7f9 100644 --- a/app/dao/users_dao.py +++ b/app/dao/users_dao.py @@ -4,7 +4,7 @@ import sqlalchemy from flask import current_app -from sqlalchemy import func, text +from sqlalchemy import delete, func, select, text from sqlalchemy.orm import joinedload from app import db @@ -37,8 +37,8 @@ def get_login_gov_user(login_uuid, email_address): login.gov uuids are. Eventually the code that checks by email address should be removed. """ - - user = User.query.filter_by(login_uuid=login_uuid).first() + stmt = select(User).filter_by(login_uuid=login_uuid) + user = db.session.execute(stmt).scalars().first() if user: if user.email_address != email_address: try: @@ -54,7 +54,8 @@ def get_login_gov_user(login_uuid, email_address): return user # Remove this 1 July 2025, all users should have login.gov uuids by now - user = User.query.filter(User.email_address.ilike(email_address)).first() + stmt = select(User).filter(User.email_address.ilike(email_address)) + user = db.session.execute(stmt).scalars().first() if user: save_user_attribute(user, {"login_uuid": login_uuid}) @@ -102,24 +103,27 @@ def create_user_code(user, code, code_type): def get_user_code(user, code, code_type): # Get the most recent codes to try and reduce the # time searching for the correct code. - codes = VerifyCode.query.filter_by(user=user, code_type=code_type).order_by( - VerifyCode.created_at.desc() + stmt = ( + select(VerifyCode) + .filter_by(user=user, code_type=code_type) + .order_by(VerifyCode.created_at.desc()) ) + codes = db.session.execute(stmt).scalars().all() return next((x for x in codes if x.check_code(code)), None) def delete_codes_older_created_more_than_a_day_ago(): - deleted = ( - db.session.query(VerifyCode) - .filter(VerifyCode.created_at < utc_now() - timedelta(hours=24)) - .delete() + stmt = delete(VerifyCode).filter( + VerifyCode.created_at < utc_now() - timedelta(hours=24) ) + + deleted = db.session.execute(stmt) db.session.commit() return deleted def use_user_code(id): - verify_code = VerifyCode.query.get(id) + verify_code = db.session.get(VerifyCode, id) verify_code.code_used = True db.session.add(verify_code) db.session.commit() @@ -131,36 +135,42 @@ def delete_model_user(user): def delete_user_verify_codes(user): - VerifyCode.query.filter_by(user=user).delete() + stmt = delete(VerifyCode).filter_by(user=user) + db.session.execute(stmt) db.session.commit() def count_user_verify_codes(user): - query = VerifyCode.query.filter( + stmt = select(func.count(VerifyCode.id)).filter( VerifyCode.user == user, VerifyCode.expiry_datetime > utc_now(), VerifyCode.code_used.is_(False), ) - return query.count() + result = db.session.execute(stmt).scalar() + return result or 0 def get_user_by_id(user_id=None): if user_id: - return User.query.filter_by(id=user_id).one() - return User.query.filter_by().all() + stmt = select(User).filter_by(id=user_id) + return db.session.execute(stmt).scalars().one() + return get_users() def get_users(): - return User.query.all() + stmt = select(User) + return db.session.execute(stmt).scalars().all() def get_user_by_email(email): - return User.query.filter(func.lower(User.email_address) == func.lower(email)).one() + stmt = select(User).filter(func.lower(User.email_address) == func.lower(email)) + return db.session.execute(stmt).scalars().one() def get_users_by_partial_email(email): email = escape_special_characters(email) - return User.query.filter(User.email_address.ilike("%{}%".format(email))).all() + stmt = select(User).filter(User.email_address.ilike("%{}%".format(email))) + return db.session.execute(stmt).scalars().all() def increment_failed_login_count(user): @@ -188,16 +198,17 @@ def get_user_and_accounts(user_id): # TODO: With sqlalchemy 2.0 change as below because of the breaking change # at User.organizations.services, we need to verify that the below subqueryload # that we have put is functionally doing the same thing as before - return ( - User.query.filter(User.id == user_id) + stmt = ( + select(User) + .filter(User.id == user_id) .options( # eagerly load the user's services and organizations, and also the service's org and vice versa # (so we can see if the user knows about it) joinedload(User.services).joinedload(Service.organization), joinedload(User.organizations).subqueryload(Organization.services), ) - .one() ) + return db.session.execute(stmt).scalars().unique().one() @autocommit diff --git a/app/delivery/send_to_providers.py b/app/delivery/send_to_providers.py index 745b46cab..07763823f 100644 --- a/app/delivery/send_to_providers.py +++ b/app/delivery/send_to_providers.py @@ -98,17 +98,7 @@ def send_sms_to_provider(notification): # TODO This is temporary to test the capability of validating phone numbers # The future home of the validation is TBD - if "+" not in recipient: - recipient_lookup = f"+{recipient}" - else: - recipient_lookup = recipient - if recipient_lookup in current_app.config[ - "SIMULATED_SMS_NUMBERS" - ] and os.getenv("NOTIFY_ENVIRONMENT") in ["development", "test"]: - current_app.logger.info(hilite("#validate-phone-number fired")) - aws_pinpoint_client.validate_phone_number("01", recipient) - else: - current_app.logger.info(hilite("#validate-phone-number not fired")) + _experimentally_validate_phone_numbers(recipient) sender_numbers = get_sender_numbers(notification) if notification.reply_to_text not in sender_numbers: @@ -145,6 +135,18 @@ def send_sms_to_provider(notification): return message_id +def _experimentally_validate_phone_numbers(recipient): + if "+" not in recipient: + recipient_lookup = f"+{recipient}" + else: + recipient_lookup = recipient + if recipient_lookup in current_app.config["SIMULATED_SMS_NUMBERS"] and os.getenv( + "NOTIFY_ENVIRONMENT" + ) in ["development", "test"]: + current_app.logger.info(hilite("#validate-phone-number fired")) + aws_pinpoint_client.validate_phone_number("01", recipient) + + def _get_verify_code(notification): key = f"2facode-{notification.id}".replace(" ", "") recipient = redis_store.get(key) diff --git a/app/service/rest.py b/app/service/rest.py index 070f13457..6441b74b7 100644 --- a/app/service/rest.py +++ b/app/service/rest.py @@ -453,16 +453,6 @@ def get_all_notifications_for_service(service_id): data = notifications_filter_schema.load(MultiDict(request.get_json())) current_app.logger.debug(f"use POST, request {request.get_json()} data {data}") - if data.get("to"): - notification_type = ( - data.get("template_type")[0] if data.get("template_type") else None - ) - return search_for_notification_by_to_field( - service_id=service_id, - search_term=data["to"], - statuses=data.get("status"), - notification_type=notification_type, - ) page = data["page"] if "page" in data else 1 page_size = ( data["page_size"] @@ -583,53 +573,6 @@ def get_notification_for_service(service_id, notification_id): ) -def search_for_notification_by_to_field( - service_id, search_term, statuses, notification_type -): - results = notifications_dao.dao_get_notifications_by_recipient_or_reference( - service_id=service_id, - search_term=search_term, - statuses=statuses, - notification_type=notification_type, - page=1, - page_size=current_app.config["PAGE_SIZE"], - ) - - # We try and get the next page of results to work out if we need provide a pagination link to the next page - # in our response. Note, this was previously be done by having - # notifications_dao.dao_get_notifications_by_recipient_or_reference use count=False when calling - # Flask-Sqlalchemys `paginate'. But instead we now use this way because it is much more performant for - # services with many results (unlike using Flask SqlAlchemy `paginate` with `count=True`, this approach - # doesn't do an additional query to count all the results of which there could be millions but instead only - # asks for a single extra page of results). - next_page_of_pagination = notifications_dao.dao_get_notifications_by_recipient_or_reference( - service_id=service_id, - search_term=search_term, - statuses=statuses, - notification_type=notification_type, - page=2, - page_size=current_app.config["PAGE_SIZE"], - error_out=False, # False so that if there are no results, it doesn't end in aborting with a 404 - ) - - return ( - jsonify( - notifications=notification_with_template_schema.dump( - results.items, many=True - ), - links=get_prev_next_pagination_links( - 1, - len(next_page_of_pagination.items), - ".get_all_notifications_for_service", - statuses=statuses, - notification_type=notification_type, - service_id=service_id, - ), - ), - 200, - ) - - @service_blueprint.route("//notifications/monthly", methods=["GET"]) def get_monthly_notification_stats(service_id): # check service_id validity diff --git a/app/service_invite/rest.py b/app/service_invite/rest.py index 5728b3ed5..e7d0d4b20 100644 --- a/app/service_invite/rest.py +++ b/app/service_invite/rest.py @@ -32,7 +32,7 @@ register_errors(service_invite) -def _create_service_invite(invited_user, invite_link_host): +def _create_service_invite(invited_user, nonce): template_id = current_app.config["INVITATION_EMAIL_TEMPLATE_ID"] @@ -40,12 +40,6 @@ def _create_service_invite(invited_user, invite_link_host): service = Service.query.get(current_app.config["NOTIFY_SERVICE_ID"]) - token = generate_token( - str(invited_user.email_address), - current_app.config["SECRET_KEY"], - current_app.config["DANGEROUS_SALT"], - ) - # The raw permissions are in the form "a,b,c,d" # but need to be in the form ["a", "b", "c", "d"] data = {} @@ -59,7 +53,8 @@ def _create_service_invite(invited_user, invite_link_host): data["invited_user_email"] = invited_user.email_address url = os.environ["LOGIN_DOT_GOV_REGISTRATION_URL"] - url = url.replace("NONCE", token) + + url = url.replace("NONCE", nonce) # handed from data sent from admin. user_data_url_safe = get_user_data_url_safe(data) @@ -94,10 +89,16 @@ def _create_service_invite(invited_user, invite_link_host): @service_invite.route("/service//invite", methods=["POST"]) def create_invited_user(service_id): request_json = request.get_json() + try: + nonce = request_json.pop("nonce") + except KeyError: + current_app.logger.exception("nonce not found in submitted data.") + raise + invited_user = invited_user_schema.load(request_json) save_invited_user(invited_user) - _create_service_invite(invited_user, request_json.get("invite_link_host")) + _create_service_invite(invited_user, nonce) return jsonify(data=invited_user_schema.dump(invited_user)), 201 diff --git a/migrations/versions/0044_jobs_to_notification_hist.py b/migrations/versions/0044_jobs_to_notification_hist.py index e813833b4..3312d9a49 100644 --- a/migrations/versions/0044_jobs_to_notification_hist.py +++ b/migrations/versions/0044_jobs_to_notification_hist.py @@ -31,10 +31,10 @@ def upgrade(): # # go_live = datetime.datetime.strptime('2016-05-18', '%Y-%m-%d') # notifications_history_start_date = datetime.datetime.strptime('2016-06-26 23:21:55', '%Y-%m-%d %H:%M:%S') - # jobs = session.query(Job).join(Template).filter(Job.service_id == '95316ff0-e555-462d-a6e7-95d26fbfd091', + # stmt = select(Job).join(Template).filter(Job.service_id == '95316ff0-e555-462d-a6e7-95d26fbfd091', # Job.created_at >= go_live, # Job.created_at < notifications_history_start_date).all() - # + # jobs = db.session.execute(stmt).scalars().all() # for job in jobs: # for i in range(0, job.notifications_delivered): # notification = NotificationHistory(id=uuid.uuid4(), @@ -76,12 +76,11 @@ def downgrade(): # # go_live = datetime.datetime.strptime('2016-05-18', '%Y-%m-%d') # notifications_history_start_date = datetime.datetime.strptime('2016-06-26 23:21:55', '%Y-%m-%d %H:%M:%S') - # - # session.query(NotificationHistory).filter( + # stmt = delete(NotificationHistory).where( # NotificationHistory.created_at >= go_live, # NotificationHistory.service_id == '95316ff0-e555-462d-a6e7-95d26fbfd091', - # NotificationHistory.created_at < notifications_history_start_date).delete() - # + # NotificationHistory.created_at < notifications_history_start_date) + # session.execute(stmt) # session.commit() # ### end Alembic commands ### pass diff --git a/notifications_utils/sanitise_text.py b/notifications_utils/sanitise_text.py index 3e9da0764..750a2e49b 100644 --- a/notifications_utils/sanitise_text.py +++ b/notifications_utils/sanitise_text.py @@ -122,19 +122,15 @@ def is_arabic(cls, value): def is_punjabi(cls, value): # Gukmukhi script or Shahmukhi script - if regex.search(r"[\u0A00-\u0A7F]+", value): - return True - elif regex.search(r"[\u0600-\u06FF]+", value): - return True - elif regex.search(r"[\u0750-\u077F]+", value): - return True - elif regex.search(r"[\u08A0-\u08FF]+", value): - return True - elif regex.search(r"[\uFB50-\uFDFF]+", value): - return True - elif regex.search(r"[\uFE70-\uFEFF]+", value): - return True - elif regex.search(r"[\u0900-\u097F]+", value): + if ( + regex.search(r"[\u0A00-\u0A7F]+", value) + or regex.search(r"[\u0600-\u06FF]+", value) + or regex.search(r"[\u0750-\u077F]+", value) + or regex.search(r"[\u08A0-\u08FF]+", value) + or regex.search(r"[\uFB50-\uFDFF]+", value) + or regex.search(r"[\uFE70-\uFEFF]+", value) + or regex.search(r"[\u0900-\u097F]+", value) + ): return True return False @@ -156,33 +152,27 @@ def _is_extended_language_group_one(cls, value): @classmethod def _is_extended_language_group_two(cls, value): - if regex.search(r"\p{IsBuhid}", value): - return True - if regex.search(r"\p{IsCanadian_Aboriginal}", value): - return True - if regex.search(r"\p{IsCherokee}", value): - return True - if regex.search(r"\p{IsDevanagari}", value): - return True - if regex.search(r"\p{IsEthiopic}", value): - return True - if regex.search(r"\p{IsGeorgian}", value): + if ( + regex.search(r"\p{IsBuhid}", value) + or regex.search(r"\p{IsCanadian_Aboriginal}", value) + or regex.search(r"\p{IsCherokee}", value) + or regex.search(r"\p{IsDevanagari}", value) + or regex.search(r"\p{IsEthiopic}", value) + or regex.search(r"\p{IsGeorgian}", value) + ): return True return False @classmethod def _is_extended_language_group_three(cls, value): - if regex.search(r"\p{IsGreek}", value): - return True - if regex.search(r"\p{IsGujarati}", value): - return True - if regex.search(r"\p{IsHanunoo}", value): - return True - if regex.search(r"\p{IsHebrew}", value): - return True - if regex.search(r"\p{IsLimbu}", value): - return True - if regex.search(r"\p{IsKannada}", value): + if ( + regex.search(r"\p{IsGreek}", value) + or regex.search(r"\p{IsGujarati}", value) + or regex.search(r"\p{IsHanunoo}", value) + or regex.search(r"\p{IsHebrew}", value) + or regex.search(r"\p{IsLimbu}", value) + or regex.search(r"\p{IsKannada}", value) + ): return True return False diff --git a/poetry.lock b/poetry.lock index 60ce4d0ae..dcdb5290b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4519,13 +4519,13 @@ test = ["websockets"] [[package]] name = "werkzeug" -version = "3.0.3" +version = "3.0.6" description = "The comprehensive WSGI web application library." optional = false python-versions = ">=3.8" files = [ - {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, - {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, + {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"}, + {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"}, ] [package.dependencies] @@ -4803,4 +4803,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.12.2" -content-hash = "42172a923e16c5b0965ab06f717d41e8491ee35f7be674091b38014c48b7a89e" +content-hash = "cf18ae74630e47eec18cc6c5fea9e554476809d20589d82c54a8d761bb2c3de0" diff --git a/pyproject.toml b/pyproject.toml index 3e3a78aed..99858c09e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ psycopg2-binary = "==2.9.9" pyjwt = "==2.8.0" python-dotenv = "==1.0.1" sqlalchemy = "==2.0.31" -werkzeug = "^3.0.3" +werkzeug = "^3.0.6" faker = "^26.0.0" async-timeout = "^4.0.3" bleach = "^6.1.0" diff --git a/tests/app/dao/notification_dao/test_notification_dao.py b/tests/app/dao/notification_dao/test_notification_dao.py index 4bc1ce5ba..e2ac10032 100644 --- a/tests/app/dao/notification_dao/test_notification_dao.py +++ b/tests/app/dao/notification_dao/test_notification_dao.py @@ -4,9 +4,11 @@ import pytest from freezegun import freeze_time +from sqlalchemy import func, select from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.orm.exc import NoResultFound +from app import db from app.dao.notifications_dao import ( dao_create_notification, dao_delete_notifications_by_id, @@ -55,7 +57,10 @@ def test_should_by_able_to_update_status_by_reference( notification = Notification(**data) dao_create_notification(notification) - assert Notification.query.get(notification.id).status == NotificationStatus.SENDING + assert ( + db.session.get(Notification, notification.id).status + == NotificationStatus.SENDING + ) notification.reference = "reference" dao_update_notification(notification) @@ -64,7 +69,8 @@ def test_should_by_able_to_update_status_by_reference( ) assert updated.status == NotificationStatus.DELIVERED assert ( - Notification.query.get(notification.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification.id).status + == NotificationStatus.DELIVERED ) @@ -81,7 +87,10 @@ def test_should_by_able_to_update_status_by_id( dao_create_notification(notification) assert notification.status == NotificationStatus.SENDING - assert Notification.query.get(notification.id).status == NotificationStatus.SENDING + assert ( + db.session.get(Notification, notification.id).status + == NotificationStatus.SENDING + ) with freeze_time("2000-01-02 12:00:00"): updated = update_notification_status_by_id( @@ -92,7 +101,8 @@ def test_should_by_able_to_update_status_by_id( assert updated.status == NotificationStatus.DELIVERED assert updated.updated_at == datetime(2000, 1, 2, 12, 0, 0) assert ( - Notification.query.get(notification.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification.id).status + == NotificationStatus.DELIVERED ) assert notification.updated_at == datetime(2000, 1, 2, 12, 0, 0) assert notification.status == NotificationStatus.DELIVERED @@ -107,15 +117,17 @@ def test_should_not_update_status_by_id_if_not_sending_and_does_not_update_job( job=sample_job, ) assert ( - Notification.query.get(notification.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification.id).status + == NotificationStatus.DELIVERED ) assert not update_notification_status_by_id( notification.id, NotificationStatus.FAILED ) assert ( - Notification.query.get(notification.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification.id).status + == NotificationStatus.DELIVERED ) - assert sample_job == Job.query.get(notification.job_id) + assert sample_job == db.session.get(Job, notification.job_id) def test_should_not_update_status_by_reference_if_not_sending_and_does_not_update_job( @@ -128,20 +140,22 @@ def test_should_not_update_status_by_reference_if_not_sending_and_does_not_updat job=sample_job, ) assert ( - Notification.query.get(notification.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification.id).status + == NotificationStatus.DELIVERED ) assert not update_notification_status_by_reference( "reference", NotificationStatus.FAILED ) assert ( - Notification.query.get(notification.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification.id).status + == NotificationStatus.DELIVERED ) - assert sample_job == Job.query.get(notification.job_id) + assert sample_job == db.session.get(Job, notification.job_id) def test_should_update_status_by_id_if_created(sample_template, sample_notification): assert ( - Notification.query.get(sample_notification.id).status + db.session.get(Notification, sample_notification.id).status == NotificationStatus.CREATED ) updated = update_notification_status_by_id( @@ -149,7 +163,7 @@ def test_should_update_status_by_id_if_created(sample_template, sample_notificat NotificationStatus.FAILED, ) assert ( - Notification.query.get(sample_notification.id).status + db.session.get(Notification, sample_notification.id).status == NotificationStatus.FAILED ) assert updated.status == NotificationStatus.FAILED @@ -244,11 +258,17 @@ def test_should_not_update_status_by_reference_if_not_sending(sample_template): status=NotificationStatus.CREATED, reference="reference", ) - assert Notification.query.get(notification.id).status == NotificationStatus.CREATED + assert ( + db.session.get(Notification, notification.id).status + == NotificationStatus.CREATED + ) updated = update_notification_status_by_reference( "reference", NotificationStatus.FAILED ) - assert Notification.query.get(notification.id).status == NotificationStatus.CREATED + assert ( + db.session.get(Notification, notification.id).status + == NotificationStatus.CREATED + ) assert not updated @@ -264,14 +284,18 @@ def test_should_by_able_to_update_status_by_id_from_pending_to_delivered( assert update_notification_status_by_id( notification_id=notification.id, status=NotificationStatus.PENDING ) - assert Notification.query.get(notification.id).status == NotificationStatus.PENDING + assert ( + db.session.get(Notification, notification.id).status + == NotificationStatus.PENDING + ) assert update_notification_status_by_id( notification.id, NotificationStatus.DELIVERED, ) assert ( - Notification.query.get(notification.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification.id).status + == NotificationStatus.DELIVERED ) @@ -289,7 +313,10 @@ def test_should_by_able_to_update_status_by_id_from_pending_to_temporary_failure notification_id=notification.id, status=NotificationStatus.PENDING, ) - assert Notification.query.get(notification.id).status == NotificationStatus.PENDING + assert ( + db.session.get(Notification, notification.id).status + == NotificationStatus.PENDING + ) assert update_notification_status_by_id( notification.id, @@ -297,7 +324,7 @@ def test_should_by_able_to_update_status_by_id_from_pending_to_temporary_failure ) assert ( - Notification.query.get(notification.id).status + db.session.get(Notification, notification.id).status == NotificationStatus.TEMPORARY_FAILURE ) @@ -312,14 +339,17 @@ def test_should_by_able_to_update_status_by_id_from_sending_to_permanent_failure ) notification = Notification(**data) dao_create_notification(notification) - assert Notification.query.get(notification.id).status == NotificationStatus.SENDING + assert ( + db.session.get(Notification, notification.id).status + == NotificationStatus.SENDING + ) assert update_notification_status_by_id( notification.id, status=NotificationStatus.PERMANENT_FAILURE, ) assert ( - Notification.query.get(notification.id).status + db.session.get(Notification, notification.id).status == NotificationStatus.PERMANENT_FAILURE ) @@ -331,7 +361,10 @@ def test_should_not_update_status_once_notification_status_is_delivered( template=sample_email_template, status=NotificationStatus.SENDING, ) - assert Notification.query.get(notification.id).status == NotificationStatus.SENDING + assert ( + db.session.get(Notification, notification.id).status + == NotificationStatus.SENDING + ) notification.reference = "reference" dao_update_notification(notification) @@ -340,7 +373,8 @@ def test_should_not_update_status_once_notification_status_is_delivered( NotificationStatus.DELIVERED, ) assert ( - Notification.query.get(notification.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification.id).status + == NotificationStatus.DELIVERED ) update_notification_status_by_reference( @@ -348,7 +382,8 @@ def test_should_not_update_status_once_notification_status_is_delivered( NotificationStatus.FAILED, ) assert ( - Notification.query.get(notification.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification.id).status + == NotificationStatus.DELIVERED ) @@ -370,7 +405,7 @@ def test_create_notification_creates_notification_with_personalisation( sample_template_with_placeholders, sample_job, ): - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = create_notification( template=sample_template_with_placeholders, @@ -379,8 +414,8 @@ def test_create_notification_creates_notification_with_personalisation( status=NotificationStatus.CREATED, ) - assert Notification.query.count() == 1 - notification_from_db = Notification.query.all()[0] + assert _get_notification_query_count() == 1 + notification_from_db = _get_notification_query_all()[0] assert notification_from_db.id assert data.to == notification_from_db.to assert data.job_id == notification_from_db.job_id @@ -393,15 +428,15 @@ def test_create_notification_creates_notification_with_personalisation( def test_save_notification_creates_sms(sample_template, sample_job): - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = _notification_json(sample_template, job_id=sample_job.id) notification = Notification(**data) dao_create_notification(notification) - assert Notification.query.count() == 1 - notification_from_db = Notification.query.all()[0] + assert _get_notification_query_count() == 1 + notification_from_db = _get_notification_query_all()[0] assert notification_from_db.id assert "1" == notification_from_db.to assert data["job_id"] == notification_from_db.job_id @@ -412,16 +447,36 @@ def test_save_notification_creates_sms(sample_template, sample_job): assert notification_from_db.status == NotificationStatus.CREATED +def _get_notification_query_all(): + stmt = select(Notification) + return db.session.execute(stmt).scalars().all() + + +def _get_notification_query_one(): + stmt = select(Notification) + return db.session.execute(stmt).scalars().one() + + +def _get_notification_query_count(): + stmt = select(func.count(Notification.id)) + return db.session.execute(stmt).scalar() or 0 + + +def _get_notification_history_query_count(): + stmt = select(func.count(NotificationHistory.id)) + return db.session.execute(stmt).scalar() or 0 + + def test_save_notification_and_create_email(sample_email_template, sample_job): - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = _notification_json(sample_email_template, job_id=sample_job.id) notification = Notification(**data) dao_create_notification(notification) - assert Notification.query.count() == 1 - notification_from_db = Notification.query.all()[0] + assert _get_notification_query_count() == 1 + notification_from_db = _get_notification_query_all()[0] assert notification_from_db.id assert "1" == notification_from_db.to assert data["job_id"] == notification_from_db.job_id @@ -433,29 +488,29 @@ def test_save_notification_and_create_email(sample_email_template, sample_job): def test_save_notification(sample_email_template, sample_job): - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = _notification_json(sample_email_template, job_id=sample_job.id) notification_1 = Notification(**data) notification_2 = Notification(**data) dao_create_notification(notification_1) - assert Notification.query.count() == 1 + assert _get_notification_query_count() == 1 dao_create_notification(notification_2) - assert Notification.query.count() == 2 + assert _get_notification_query_count() == 2 def test_save_notification_does_not_creates_history(sample_email_template, sample_job): - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = _notification_json(sample_email_template, job_id=sample_job.id) notification_1 = Notification(**data) dao_create_notification(notification_1) - assert Notification.query.count() == 1 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 1 + assert _get_notification_history_query_count() == 0 def test_update_notification_with_research_mode_service_does_not_create_or_update_history( @@ -464,14 +519,14 @@ def test_update_notification_with_research_mode_service_does_not_create_or_updat sample_template.service.research_mode = True notification = create_notification(template=sample_template) - assert Notification.query.count() == 1 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 1 + assert _get_notification_history_query_count() == 0 notification.status = NotificationStatus.DELIVERED dao_update_notification(notification) - assert Notification.query.one().status == NotificationStatus.DELIVERED - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_one().status == NotificationStatus.DELIVERED + assert _get_notification_history_query_count() == 0 def test_not_save_notification_and_not_create_stats_on_commit_error( @@ -479,26 +534,26 @@ def test_not_save_notification_and_not_create_stats_on_commit_error( ): random_id = str(uuid.uuid4()) - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = _notification_json(sample_template, job_id=random_id) notification = Notification(**data) with pytest.raises(SQLAlchemyError): dao_create_notification(notification) - assert Notification.query.count() == 0 - assert Job.query.get(sample_job.id).notifications_sent == 0 + assert _get_notification_query_count() == 0 + assert db.session.get(Job, sample_job.id).notifications_sent == 0 def test_save_notification_and_increment_job(sample_template, sample_job, sns_provider): - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = _notification_json(sample_template, job_id=sample_job.id) notification = Notification(**data) dao_create_notification(notification) - assert Notification.query.count() == 1 - notification_from_db = Notification.query.all()[0] + assert _get_notification_query_count() == 1 + notification_from_db = _get_notification_query_all()[0] assert notification_from_db.id assert "1" == notification_from_db.to assert data["job_id"] == notification_from_db.job_id @@ -510,21 +565,21 @@ def test_save_notification_and_increment_job(sample_template, sample_job, sns_pr notification_2 = Notification(**data) dao_create_notification(notification_2) - assert Notification.query.count() == 2 + assert _get_notification_query_count() == 2 def test_save_notification_and_increment_correct_job(sample_template, sns_provider): job_1 = create_job(sample_template) job_2 = create_job(sample_template) - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = _notification_json(sample_template, job_id=job_1.id) notification = Notification(**data) dao_create_notification(notification) - assert Notification.query.count() == 1 - notification_from_db = Notification.query.all()[0] + assert _get_notification_query_count() == 1 + notification_from_db = _get_notification_query_all()[0] assert notification_from_db.id assert "1" == notification_from_db.to assert data["job_id"] == notification_from_db.job_id @@ -537,14 +592,14 @@ def test_save_notification_and_increment_correct_job(sample_template, sns_provid def test_save_notification_with_no_job(sample_template, sns_provider): - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = _notification_json(sample_template) notification = Notification(**data) dao_create_notification(notification) - assert Notification.query.count() == 1 - notification_from_db = Notification.query.all()[0] + assert _get_notification_query_count() == 1 + notification_from_db = _get_notification_query_all()[0] assert notification_from_db.id assert "1" == notification_from_db.to assert data["service"] == notification_from_db.service @@ -592,7 +647,7 @@ def test_get_notification_by_id_when_notification_exists_for_different_service( def test_get_notifications_by_reference(sample_template): client_reference = "some-client-ref" - assert len(Notification.query.all()) == 0 + assert len(_get_notification_query_all()) == 0 create_notification(sample_template, client_reference=client_reference) create_notification(sample_template, client_reference=client_reference) create_notification(sample_template, client_reference="other-ref") @@ -603,14 +658,14 @@ def test_get_notifications_by_reference(sample_template): def test_save_notification_no_job_id(sample_template): - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 data = _notification_json(sample_template) notification = Notification(**data) dao_create_notification(notification) - assert Notification.query.count() == 1 - notification_from_db = Notification.query.all()[0] + assert _get_notification_query_count() == 1 + notification_from_db = _get_notification_query_all()[0] assert notification_from_db.id assert "1" == notification_from_db.to assert data["service"] == notification_from_db.service @@ -687,13 +742,13 @@ def test_update_notification_sets_status(sample_notification): assert sample_notification.status == NotificationStatus.CREATED sample_notification.status = NotificationStatus.FAILED dao_update_notification(sample_notification) - notification_from_db = Notification.query.get(sample_notification.id) + notification_from_db = db.session.get(Notification, sample_notification.id) assert notification_from_db.status == NotificationStatus.FAILED @freeze_time("2016-01-10") def test_should_limit_notifications_return_by_day_limit_plus_one(sample_template): - assert len(Notification.query.all()) == 0 + assert len(_get_notification_query_all()) == 0 # create one notification a day between 1st and 9th, # with assumption that the local timezone is EST @@ -706,7 +761,7 @@ def test_should_limit_notifications_return_by_day_limit_plus_one(sample_template status=NotificationStatus.FAILED, ) - all_notifications = Notification.query.all() + all_notifications = _get_notification_query_all() assert len(all_notifications) == 10 all_notifications = get_notifications_for_service( @@ -722,19 +777,19 @@ def test_should_limit_notifications_return_by_day_limit_plus_one(sample_template def test_creating_notification_does_not_add_notification_history(sample_template): create_notification(template=sample_template) - assert Notification.query.count() == 1 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 1 + assert _get_notification_history_query_count() == 0 def test_should_delete_notification_for_id(sample_template): notification = create_notification(template=sample_template) - assert Notification.query.count() == 1 - assert NotificationHistory.query.count() == 0 + assert _get_notification_query_count() == 1 + assert _get_notification_history_query_count() == 0 dao_delete_notifications_by_id(notification.id) - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 def test_should_delete_notification_and_ignore_history_for_research_mode( @@ -744,31 +799,32 @@ def test_should_delete_notification_and_ignore_history_for_research_mode( notification = create_notification(template=sample_template) - assert Notification.query.count() == 1 + assert _get_notification_query_count() == 1 dao_delete_notifications_by_id(notification.id) - assert Notification.query.count() == 0 + assert _get_notification_query_count() == 0 def test_should_delete_only_notification_with_id(sample_template): notification_1 = create_notification(template=sample_template) notification_2 = create_notification(template=sample_template) - assert Notification.query.count() == 2 + assert _get_notification_query_count() == 2 dao_delete_notifications_by_id(notification_1.id) - assert Notification.query.count() == 1 - assert Notification.query.first().id == notification_2.id + assert _get_notification_query_count() == 1 + stmt = select(Notification) + assert db.session.execute(stmt).scalars().first().id == notification_2.id def test_should_delete_no_notifications_if_no_matching_ids(sample_template): create_notification(template=sample_template) - assert Notification.query.count() == 1 + assert _get_notification_query_count() == 1 dao_delete_notifications_by_id(uuid.uuid4()) - assert Notification.query.count() == 1 + assert _get_notification_query_count() == 1 def _notification_json(sample_template, job_id=None, id=None, status=None): @@ -814,16 +870,19 @@ def test_dao_timeout_notifications(sample_template): temporary_failure_notifications = dao_timeout_notifications(utc_now()) assert len(temporary_failure_notifications) == 2 - assert Notification.query.get(created.id).status == NotificationStatus.CREATED + assert db.session.get(Notification, created.id).status == NotificationStatus.CREATED assert ( - Notification.query.get(sending.id).status + db.session.get(Notification, sending.id).status == NotificationStatus.TEMPORARY_FAILURE ) assert ( - Notification.query.get(pending.id).status + db.session.get(Notification, pending.id).status == NotificationStatus.TEMPORARY_FAILURE ) - assert Notification.query.get(delivered.id).status == NotificationStatus.DELIVERED + assert ( + db.session.get(Notification, delivered.id).status + == NotificationStatus.DELIVERED + ) def test_dao_timeout_notifications_only_updates_for_older_notifications( @@ -842,8 +901,8 @@ def test_dao_timeout_notifications_only_updates_for_older_notifications( temporary_failure_notifications = dao_timeout_notifications(utc_now()) assert len(temporary_failure_notifications) == 0 - assert Notification.query.get(sending.id).status == NotificationStatus.SENDING - assert Notification.query.get(pending.id).status == NotificationStatus.PENDING + assert db.session.get(Notification, sending.id).status == NotificationStatus.SENDING + assert db.session.get(Notification, pending.id).status == NotificationStatus.PENDING def test_should_return_notifications_excluding_jobs_by_default( @@ -935,7 +994,7 @@ def test_get_notifications_created_by_api_or_csv_are_returned_correctly_excludin key_type=sample_test_api_key.key_type, ) - all_notifications = Notification.query.all() + all_notifications = _get_notification_query_all() assert len(all_notifications) == 4 # returns all real API derived notifications @@ -982,7 +1041,7 @@ def test_get_notifications_with_a_live_api_key_type( key_type=sample_test_api_key.key_type, ) - all_notifications = Notification.query.all() + all_notifications = _get_notification_query_all() assert len(all_notifications) == 4 # only those created with normal API key, no jobs @@ -1114,7 +1173,7 @@ def test_should_exclude_test_key_notifications_by_default( key_type=sample_test_api_key.key_type, ) - all_notifications = Notification.query.all() + all_notifications = _get_notification_query_all() assert len(all_notifications) == 4 all_notifications = get_notifications_for_service( @@ -1757,10 +1816,10 @@ def test_dao_update_notifications_by_reference_updated_notifications(sample_temp update_dict={"status": NotificationStatus.DELIVERED, "billable_units": 2}, ) assert updated_count == 2 - updated_1 = Notification.query.get(notification_1.id) + updated_1 = db.session.get(Notification, notification_1.id) assert updated_1.billable_units == 2 assert updated_1.status == NotificationStatus.DELIVERED - updated_2 = Notification.query.get(notification_2.id) + updated_2 = db.session.get(Notification, notification_2.id) assert updated_2.billable_units == 2 assert updated_2.status == NotificationStatus.DELIVERED @@ -1823,10 +1882,11 @@ def test_dao_update_notifications_by_reference_updates_history_when_one_of_two_n assert updated_count == 1 assert updated_history_count == 1 assert ( - Notification.query.get(notification2.id).status == NotificationStatus.DELIVERED + db.session.get(Notification, notification2.id).status + == NotificationStatus.DELIVERED ) assert ( - NotificationHistory.query.get(notification1.id).status + db.session.get(NotificationHistory, notification1.id).status == NotificationStatus.DELIVERED ) diff --git a/tests/app/dao/test_organization_dao.py b/tests/app/dao/test_organization_dao.py index edffdd1d4..fb2e01d85 100644 --- a/tests/app/dao/test_organization_dao.py +++ b/tests/app/dao/test_organization_dao.py @@ -1,6 +1,7 @@ import uuid import pytest +from sqlalchemy import select from sqlalchemy.exc import IntegrityError, SQLAlchemyError from app import db @@ -57,7 +58,8 @@ def test_get_organization_by_id_gets_correct_organization(notify_db_session): def test_update_organization(notify_db_session): create_organization() - organization = Organization.query.one() + stmt = select(Organization) + organization = db.session.execute(stmt).scalars().one() user = create_user() email_branding = create_email_branding() @@ -78,7 +80,8 @@ def test_update_organization(notify_db_session): dao_update_organization(organization.id, **data) - organization = Organization.query.one() + stmt = select(Organization) + organization = db.session.execute(stmt).scalars().one() for attribute, value in data.items(): assert getattr(organization, attribute) == value @@ -102,7 +105,8 @@ def test_update_organization_domains_lowercases( ): create_organization() - organization = Organization.query.one() + stmt = select(Organization) + organization = db.session.execute(stmt).scalars().one() # Seed some domains dao_update_organization(organization.id, domains=["123", "456"]) @@ -121,7 +125,8 @@ def test_update_organization_domains_lowercases_integrity_error( ): create_organization() - organization = Organization.query.one() + stmt = select(Organization) + organization = db.session.execute(stmt).scalars().one() # Seed some domains dao_update_organization(organization.id, domains=["123", "456"]) @@ -175,11 +180,11 @@ def test_update_organization_updates_the_service_org_type_if_org_type_is_provide assert sample_organization.organization_type == OrganizationType.FEDERAL assert sample_service.organization_type == OrganizationType.FEDERAL + stmt = select(Service.get_history_model()).filter_by( + id=sample_service.id, version=2 + ) assert ( - Service.get_history_model() - .query.filter_by(id=sample_service.id, version=2) - .one() - .organization_type + db.session.execute(stmt).scalars().one().organization_type == OrganizationType.FEDERAL ) @@ -229,11 +234,11 @@ def test_add_service_to_organization(sample_service, sample_organization): assert sample_organization.services[0].id == sample_service.id assert sample_service.organization_type == sample_organization.organization_type + stmt = select(Service.get_history_model()).filter_by( + id=sample_service.id, version=2 + ) assert ( - Service.get_history_model() - .query.filter_by(id=sample_service.id, version=2) - .one() - .organization_type + db.session.execute(stmt).scalars().one().organization_type == sample_organization.organization_type ) assert sample_service.organization_id == sample_organization.id diff --git a/tests/app/dao/test_service_sms_sender_dao.py b/tests/app/dao/test_service_sms_sender_dao.py index 9ca05e711..10bfd21f4 100644 --- a/tests/app/dao/test_service_sms_sender_dao.py +++ b/tests/app/dao/test_service_sms_sender_dao.py @@ -1,8 +1,10 @@ import uuid import pytest +from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError +from app import db from app.dao.service_sms_sender_dao import ( archive_sms_sender, dao_add_sms_sender_for_service, @@ -97,10 +99,8 @@ def test_dao_add_sms_sender_for_service(notify_db_session): is_default=False, inbound_number_id=None, ) - - service_sms_senders = ServiceSmsSender.query.order_by( - ServiceSmsSender.created_at - ).all() + stmt = select(ServiceSmsSender).order_by(ServiceSmsSender.created_at) + service_sms_senders = db.session.execute(stmt).scalars().all() assert len(service_sms_senders) == 2 assert service_sms_senders[0].sms_sender == "testing" assert service_sms_senders[0].is_default @@ -116,10 +116,8 @@ def test_dao_add_sms_sender_for_service_switches_default(notify_db_session): is_default=True, inbound_number_id=None, ) - - service_sms_senders = ServiceSmsSender.query.order_by( - ServiceSmsSender.created_at - ).all() + stmt = select(ServiceSmsSender).order_by(ServiceSmsSender.created_at) + service_sms_senders = db.session.execute(stmt).scalars().all() assert len(service_sms_senders) == 2 assert service_sms_senders[0].sms_sender == "testing" assert not service_sms_senders[0].is_default @@ -128,7 +126,8 @@ def test_dao_add_sms_sender_for_service_switches_default(notify_db_session): def test_dao_update_service_sms_sender(notify_db_session): service = create_service() - service_sms_senders = ServiceSmsSender.query.filter_by(service_id=service.id).all() + stmt = select(ServiceSmsSender).filter_by(service_id=service.id) + service_sms_senders = db.session.execute(stmt).scalars().all() assert len(service_sms_senders) == 1 sms_sender_to_update = service_sms_senders[0] @@ -138,7 +137,8 @@ def test_dao_update_service_sms_sender(notify_db_session): is_default=True, sms_sender="updated", ) - sms_senders = ServiceSmsSender.query.filter_by(service_id=service.id).all() + stmt = select(ServiceSmsSender).filter_by(service_id=service.id) + sms_senders = db.session.execute(stmt).scalars().all() assert len(sms_senders) == 1 assert sms_senders[0].is_default assert sms_senders[0].sms_sender == "updated" @@ -159,7 +159,8 @@ def test_dao_update_service_sms_sender_switches_default(notify_db_session): is_default=True, sms_sender="updated", ) - sms_senders = ServiceSmsSender.query.filter_by(service_id=service.id).all() + stmt = select(ServiceSmsSender).filter_by(service_id=service.id) + sms_senders = db.session.execute(stmt).scalars().all() expected = {("testing", False), ("updated", True)} results = {(sender.sms_sender, sender.is_default) for sender in sms_senders} @@ -190,7 +191,8 @@ def test_update_existing_sms_sender_with_inbound_number(notify_db_session): service = create_service() inbound_number = create_inbound_number(number="12345", service_id=service.id) - existing_sms_sender = ServiceSmsSender.query.filter_by(service_id=service.id).one() + stmt = select(ServiceSmsSender).filter_by(service_id=service.id) + existing_sms_sender = db.session.execute(stmt).scalars().one() sms_sender = update_existing_sms_sender_with_inbound_number( service_sms_sender=existing_sms_sender, sms_sender=inbound_number.number, @@ -206,7 +208,8 @@ def test_update_existing_sms_sender_with_inbound_number_raises_exception_if_inbo notify_db_session, ): service = create_service() - existing_sms_sender = ServiceSmsSender.query.filter_by(service_id=service.id).one() + stmt = select(ServiceSmsSender).filter_by(service_id=service.id) + existing_sms_sender = db.session.execute(stmt).scalars().one() with pytest.raises(expected_exception=SQLAlchemyError): update_existing_sms_sender_with_inbound_number( service_sms_sender=existing_sms_sender, diff --git a/tests/app/dao/test_services_dao.py b/tests/app/dao/test_services_dao.py index e590eb5b4..61fe99419 100644 --- a/tests/app/dao/test_services_dao.py +++ b/tests/app/dao/test_services_dao.py @@ -6,6 +6,7 @@ import pytest import sqlalchemy from freezegun import freeze_time +from sqlalchemy import func, select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import NoResultFound @@ -89,9 +90,32 @@ ) +def _get_service_query_count(): + stmt = select(func.count(Service.id)) + return db.session.execute(stmt).scalar() or 0 + + +def _get_service_history_query_count(): + stmt = select(func.count(Service.get_history_model().id)) + return db.session.execute(stmt).scalar() or 0 + + +def _get_first_service(): + stmt = select(Service).limit(1) + service = db.session.execute(stmt).scalars().first() + return service + + +def _get_service_by_id(service_id): + stmt = select(Service).filter(Service.id == service_id) + + service = db.session.execute(stmt).scalars().one() + return service + + def test_create_service(notify_db_session): user = create_user() - assert Service.query.count() == 0 + assert _get_service_query_count() == 0 service = Service( name="service_name", email_from="email_from", @@ -101,8 +125,8 @@ def test_create_service(notify_db_session): created_by=user, ) dao_create_service(service, user) - assert Service.query.count() == 1 - service_db = Service.query.one() + assert _get_service_query_count() == 1 + service_db = _get_first_service() assert service_db.name == "service_name" assert service_db.id == service.id assert service_db.email_from == "email_from" @@ -120,7 +144,7 @@ def test_create_service_with_organization(notify_db_session): organization_type=OrganizationType.STATE, domains=["local-authority.gov.uk"], ) - assert Service.query.count() == 0 + assert _get_service_query_count() == 0 service = Service( name="service_name", email_from="email_from", @@ -130,9 +154,9 @@ def test_create_service_with_organization(notify_db_session): created_by=user, ) dao_create_service(service, user) - assert Service.query.count() == 1 - service_db = Service.query.one() - organization = Organization.query.get(organization.id) + assert _get_service_query_count() == 1 + service_db = _get_first_service() + organization = db.session.get(Organization, organization.id) assert service_db.name == "service_name" assert service_db.id == service.id assert service_db.email_from == "email_from" @@ -151,7 +175,7 @@ def test_fetch_service_by_id_with_api_keys(notify_db_session): organization_type=OrganizationType.STATE, domains=["local-authority.gov.uk"], ) - assert Service.query.count() == 0 + assert _get_service_query_count() == 0 service = Service( name="service_name", email_from="email_from", @@ -161,9 +185,9 @@ def test_fetch_service_by_id_with_api_keys(notify_db_session): created_by=user, ) dao_create_service(service, user) - assert Service.query.count() == 1 - service_db = Service.query.one() - organization = Organization.query.get(organization.id) + assert _get_service_query_count() == 1 + service_db = _get_first_service() + organization = db.session.get(Organization, organization.id) assert service_db.name == "service_name" assert service_db.id == service.id assert service_db.email_from == "email_from" @@ -183,7 +207,7 @@ def test_fetch_service_by_id_with_api_keys(notify_db_session): def test_cannot_create_two_services_with_same_name(notify_db_session): user = create_user() - assert Service.query.count() == 0 + assert _get_service_query_count() == 0 service1 = Service( name="service_name", email_from="email_from1", @@ -209,7 +233,7 @@ def test_cannot_create_two_services_with_same_name(notify_db_session): def test_cannot_create_two_services_with_same_email_from(notify_db_session): user = create_user() - assert Service.query.count() == 0 + assert _get_service_query_count() == 0 service1 = Service( name="service_name1", email_from="email_from", @@ -235,7 +259,7 @@ def test_cannot_create_two_services_with_same_email_from(notify_db_session): def test_cannot_create_service_with_no_user(notify_db_session): user = create_user() - assert Service.query.count() == 0 + assert _get_service_query_count() == 0 service = Service( name="service_name", email_from="email_from", @@ -258,7 +282,7 @@ def test_should_add_user_to_service(notify_db_session): created_by=user, ) dao_create_service(service, user) - assert user in Service.query.first().users + assert user in _get_first_service().users new_user = User( name="Test User", email_address="new_user@digital.fake.gov", @@ -267,7 +291,7 @@ def test_should_add_user_to_service(notify_db_session): ) save_model_user(new_user, validated_email_access=True) dao_add_user_to_service(service, new_user) - assert new_user in Service.query.first().users + assert new_user in _get_first_service().users def test_dao_add_user_to_service_sets_folder_permissions(sample_user, sample_service): @@ -314,7 +338,8 @@ def test_dao_add_user_to_service_raises_error_if_adding_folder_permissions_for_a other_service_folder = create_template_folder(other_service) folder_permissions = [str(other_service_folder.id)] - assert ServiceUser.query.count() == 2 + stmt = select(func.count(ServiceUser.service_id)) + assert db.session.execute(stmt).scalar() == 2 with pytest.raises(IntegrityError) as e: dao_add_user_to_service( @@ -326,7 +351,8 @@ def test_dao_add_user_to_service_raises_error_if_adding_folder_permissions_for_a 'insert or update on table "user_folder_permissions" violates foreign key constraint' in str(e.value) ) - assert ServiceUser.query.count() == 2 + stmt = select(func.count(ServiceUser.service_id)) + assert db.session.execute(stmt).scalar() == 2 def test_should_remove_user_from_service(notify_db_session): @@ -347,9 +373,9 @@ def test_should_remove_user_from_service(notify_db_session): ) save_model_user(new_user, validated_email_access=True) dao_add_user_to_service(service, new_user) - assert new_user in Service.query.first().users + assert new_user in _get_first_service().users dao_remove_user_from_service(service, new_user) - assert new_user not in Service.query.first().users + assert new_user not in _get_first_service().users def test_should_remove_user_from_service_exception(notify_db_session): @@ -382,11 +408,12 @@ def test_should_remove_user_from_service_exception(notify_db_session): def test_removing_a_user_from_a_service_deletes_their_permissions( sample_user, sample_service ): - assert len(Permission.query.all()) == 7 + stmt = select(Permission) + assert len(db.session.execute(stmt).all()) == 7 dao_remove_user_from_service(sample_service, sample_user) - assert Permission.query.all() == [] + assert db.session.execute(stmt).all() == [] def test_removing_a_user_from_a_service_deletes_their_folder_permissions_for_that_service( @@ -668,8 +695,8 @@ def test_removing_all_permission_returns_service_with_no_permissions(notify_db_s def test_create_service_creates_a_history_record_with_current_data(notify_db_session): user = create_user() - assert Service.query.count() == 0 - assert Service.get_history_model().query.count() == 0 + assert _get_service_query_count() == 0 + assert _get_service_history_query_count() == 0 service = Service( name="service_name", email_from="email_from", @@ -678,11 +705,12 @@ def test_create_service_creates_a_history_record_with_current_data(notify_db_ses created_by=user, ) dao_create_service(service, user) - assert Service.query.count() == 1 - assert Service.get_history_model().query.count() == 1 + assert _get_service_query_count() == 1 + assert _get_service_history_query_count() == 1 - service_from_db = Service.query.first() - service_history = Service.get_history_model().query.first() + service_from_db = _get_first_service() + stmt = select(Service.get_history_model()) + service_history = db.session.execute(stmt).scalars().first() assert service_from_db.id == service_history.id assert service_from_db.name == service_history.name @@ -694,8 +722,8 @@ def test_create_service_creates_a_history_record_with_current_data(notify_db_ses def test_update_service_creates_a_history_record_with_current_data(notify_db_session): user = create_user() - assert Service.query.count() == 0 - assert Service.get_history_model().query.count() == 0 + assert _get_service_query_count() == 0 + assert _get_service_history_query_count() == 0 service = Service( name="service_name", email_from="email_from", @@ -705,39 +733,31 @@ def test_update_service_creates_a_history_record_with_current_data(notify_db_ses ) dao_create_service(service, user) - assert Service.query.count() == 1 - assert Service.query.first().version == 1 - assert Service.get_history_model().query.count() == 1 + assert _get_service_query_count() == 1 + assert _get_first_service().version == 1 + assert _get_service_history_query_count() == 1 service.name = "updated_service_name" dao_update_service(service) - assert Service.query.count() == 1 - assert Service.get_history_model().query.count() == 2 + assert _get_service_query_count() == 1 + assert _get_service_history_query_count() == 2 - service_from_db = Service.query.first() + service_from_db = _get_first_service() assert service_from_db.version == 2 - - assert ( - Service.get_history_model().query.filter_by(name="service_name").one().version - == 1 - ) - assert ( - Service.get_history_model() - .query.filter_by(name="updated_service_name") - .one() - .version - == 2 - ) + stmt = select(Service.get_history_model()).filter_by(name="service_name") + assert db.session.execute(stmt).scalars().one().version == 1 + stmt = select(Service.get_history_model()).filter_by(name="updated_service_name") + assert db.session.execute(stmt).scalars().one().version == 2 def test_update_service_permission_creates_a_history_record_with_current_data( notify_db_session, ): user = create_user() - assert Service.query.count() == 0 - assert Service.get_history_model().query.count() == 0 + assert _get_service_query_count() == 0 + assert _get_service_history_query_count() == 0 service = Service( name="service_name", email_from="email_from", @@ -755,17 +775,17 @@ def test_update_service_permission_creates_a_history_record_with_current_data( ], ) - assert Service.query.count() == 1 + assert _get_service_query_count() == 1 service.permissions.append( ServicePermission(service_id=service.id, permission=ServicePermissionType.EMAIL) ) dao_update_service(service) - assert Service.query.count() == 1 - assert Service.get_history_model().query.count() == 2 + assert _get_service_query_count() == 1 + assert _get_service_history_query_count() == 2 - service_from_db = Service.query.first() + service_from_db = _get_first_service() assert service_from_db.version == 2 @@ -784,10 +804,10 @@ def test_update_service_permission_creates_a_history_record_with_current_data( service.permissions.remove(permission) dao_update_service(service) - assert Service.query.count() == 1 - assert Service.get_history_model().query.count() == 3 + assert _get_service_query_count() == 1 + assert _get_service_history_query_count() == 3 - service_from_db = Service.query.first() + service_from_db = _get_first_service() assert service_from_db.version == 3 _assert_service_permissions( service.permissions, @@ -797,21 +817,20 @@ def test_update_service_permission_creates_a_history_record_with_current_data( ), ) - history = ( - Service.get_history_model() - .query.filter_by(name="service_name") + stmt = ( + select(Service.get_history_model()) + .filter_by(name="service_name") .order_by("version") - .all() ) - + history = db.session.execute(stmt).scalars().all() assert len(history) == 3 assert history[2].version == 3 def test_create_service_and_history_is_transactional(notify_db_session): user = create_user() - assert Service.query.count() == 0 - assert Service.get_history_model().query.count() == 0 + assert _get_service_query_count() == 0 + assert _get_service_history_query_count() == 0 service = Service( name=None, email_from="email_from", @@ -828,8 +847,8 @@ def test_create_service_and_history_is_transactional(notify_db_session): in str(seeei) ) - assert Service.query.count() == 0 - assert Service.get_history_model().query.count() == 0 + assert _get_service_query_count() == 0 + assert _get_service_history_query_count() == 0 def test_delete_service_and_associated_objects(notify_db_session): @@ -845,8 +864,8 @@ def test_delete_service_and_associated_objects(notify_db_session): create_notification(template=template, api_key=api_key) create_invited_user(service=service) user.organizations = [organization] - - assert ServicePermission.query.count() == len( + stmt = select(func.count(ServicePermission.service_id)) + assert db.session.execute(stmt).scalar() == len( ( ServicePermissionType.SMS, ServicePermissionType.EMAIL, @@ -855,21 +874,35 @@ def test_delete_service_and_associated_objects(notify_db_session): ) delete_service_and_all_associated_db_objects(service) - assert VerifyCode.query.count() == 0 - assert ApiKey.query.count() == 0 - assert ApiKey.get_history_model().query.count() == 0 - assert Template.query.count() == 0 - assert TemplateHistory.query.count() == 0 - assert Job.query.count() == 0 - assert Notification.query.count() == 0 - assert Permission.query.count() == 0 - assert User.query.count() == 0 - assert InvitedUser.query.count() == 0 - assert Service.query.count() == 0 - assert Service.get_history_model().query.count() == 0 - assert ServicePermission.query.count() == 0 + stmt = select(VerifyCode) + assert db.session.execute(stmt).scalar() is None + stmt = select(ApiKey) + assert db.session.execute(stmt).scalar() is None + stmt = select(ApiKey.get_history_model()) + assert db.session.execute(stmt).scalar() is None + stmt = select(Template) + assert db.session.execute(stmt).scalar() is None + stmt = select(TemplateHistory) + assert db.session.execute(stmt).scalar() is None + stmt = select(Job) + assert db.session.execute(stmt).scalar() is None + stmt = select(Notification) + assert db.session.execute(stmt).scalar() is None + stmt = select(Permission) + assert db.session.execute(stmt).scalar() is None + stmt = select(User) + assert db.session.execute(stmt).scalar() is None + stmt = select(InvitedUser) + assert db.session.execute(stmt).scalar() is None + + assert _get_service_query_count() == 0 + assert _get_service_history_query_count() == 0 + stmt = select(ServicePermission) + assert db.session.execute(stmt).scalar() is None + # the organization hasn't been deleted - assert Organization.query.count() == 1 + stmt = select(func.count(Organization.id)) + assert db.session.execute(stmt).scalar() == 1 def test_add_existing_user_to_another_service_doesnot_change_old_permissions( @@ -887,9 +920,8 @@ def test_add_existing_user_to_another_service_doesnot_change_old_permissions( dao_create_service(service_one, user) assert user.id == service_one.users[0].id - test_user_permissions = Permission.query.filter_by( - service=service_one, user=user - ).all() + stmt = select(Permission).filter_by(service=service_one, user=user) + test_user_permissions = db.session.execute(stmt).all() assert len(test_user_permissions) == 7 other_user = User( @@ -909,14 +941,12 @@ def test_add_existing_user_to_another_service_doesnot_change_old_permissions( dao_create_service(service_two, other_user) assert other_user.id == service_two.users[0].id - other_user_permissions = Permission.query.filter_by( - service=service_two, user=other_user - ).all() + stmt = select(Permission).filter_by(service=service_two, user=other_user) + other_user_permissions = db.session.execute(stmt).all() assert len(other_user_permissions) == 7 + stmt = select(Permission).filter_by(service=service_one, user=other_user) + other_user_service_one_permissions = db.session.execute(stmt).all() - other_user_service_one_permissions = Permission.query.filter_by( - service=service_one, user=other_user - ).all() assert len(other_user_service_one_permissions) == 0 # adding the other_user to service_one should leave all other_user permissions on service_two intact @@ -925,15 +955,12 @@ def test_add_existing_user_to_another_service_doesnot_change_old_permissions( permissions.append(Permission(permission=p)) dao_add_user_to_service(service_one, other_user, permissions=permissions) - - other_user_service_one_permissions = Permission.query.filter_by( - service=service_one, user=other_user - ).all() + stmt = select(Permission).filter_by(service=service_one, user=other_user) + other_user_service_one_permissions = db.session.execute(stmt).all() assert len(other_user_service_one_permissions) == 2 - other_user_service_two_permissions = Permission.query.filter_by( - service=service_two, user=other_user - ).all() + stmt = select(Permission).filter_by(service=service_two, user=other_user) + other_user_service_two_permissions = db.session.execute(stmt).all() assert len(other_user_service_two_permissions) == 7 @@ -956,9 +983,10 @@ def test_fetch_stats_filters_on_service(notify_db_session): def test_fetch_stats_ignores_historical_notification_data(sample_template): create_notification_history(template=sample_template) - - assert Notification.query.count() == 0 - assert NotificationHistory.query.count() == 1 + stmt = select(func.count(Notification.id)) + assert db.session.execute(stmt).scalar() == 0 + stmt = select(func.count(NotificationHistory.id)) + assert db.session.execute(stmt).scalar() == 1 stats = dao_fetch_todays_stats_for_service(sample_template.service_id) assert len(stats) == 0 @@ -1316,7 +1344,7 @@ def test_dao_fetch_todays_stats_for_all_services_can_exclude_from_test_key( def test_dao_suspend_service_with_no_api_keys(notify_db_session): service = create_service() dao_suspend_service(service.id) - service = Service.query.get(service.id) + service = _get_service_by_id(service.id) assert not service.active assert service.name == service.name assert service.api_keys == [] @@ -1329,11 +1357,11 @@ def test_dao_suspend_service_marks_service_as_inactive_and_expires_api_keys( service = create_service() api_key = create_api_key(service=service) dao_suspend_service(service.id) - service = Service.query.get(service.id) + service = _get_service_by_id(service.id) assert not service.active assert service.name == service.name - api_key = ApiKey.query.get(api_key.id) + api_key = db.session.get(ApiKey, api_key.id) assert api_key.expiry_date == datetime(2001, 1, 1, 23, 59, 00) @@ -1344,13 +1372,13 @@ def test_dao_resume_service_marks_service_as_active_and_api_keys_are_still_revok service = create_service() api_key = create_api_key(service=service) dao_suspend_service(service.id) - service = Service.query.get(service.id) + service = _get_service_by_id(service.id) assert not service.active dao_resume_service(service.id) - assert Service.query.get(service.id).active + assert _get_service_by_id(service.id).active - api_key = ApiKey.query.get(api_key.id) + api_key = db.session.get(ApiKey, api_key.id) assert api_key.expiry_date == datetime(2001, 1, 1, 23, 59, 00) diff --git a/tests/app/dao/test_template_folder_dao.py b/tests/app/dao/test_template_folder_dao.py index 17b03e5df..2a872e775 100644 --- a/tests/app/dao/test_template_folder_dao.py +++ b/tests/app/dao/test_template_folder_dao.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from app import db from app.dao.service_user_dao import dao_get_service_user from app.dao.template_folder_dao import ( @@ -17,5 +19,5 @@ def test_dao_delete_template_folder_deletes_user_folder_permissions( dao_update_template_folder(folder) dao_delete_template_folder(folder) - - assert db.session.query(user_folder_permissions).all() == [] + stmt = select(user_folder_permissions) + assert db.session.execute(stmt).scalars().all() == [] diff --git a/tests/app/dao/test_templates_dao.py b/tests/app/dao/test_templates_dao.py index bfe0e59d1..734a29c0a 100644 --- a/tests/app/dao/test_templates_dao.py +++ b/tests/app/dao/test_templates_dao.py @@ -2,8 +2,10 @@ import pytest from freezegun import freeze_time +from sqlalchemy import func, select from sqlalchemy.orm.exc import NoResultFound +from app import db from app.dao.templates_dao import ( dao_create_template, dao_get_all_templates_for_service, @@ -17,6 +19,16 @@ from tests.app.db import create_template +def template_query_count(): + stmt = select(func.count()).select_from(Template) + return db.session.execute(stmt).scalar() or 0 + + +def template_history_query_count(): + stmt = select(func.count()).select_from(TemplateHistory) + return db.session.execute(stmt).scalar() or 0 + + @pytest.mark.parametrize( "template_type, subject", [ @@ -37,7 +49,7 @@ def test_create_template(sample_service, sample_user, template_type, subject): template = Template(**data) dao_create_template(template) - assert Template.query.count() == 1 + assert template_query_count() == 1 assert len(dao_get_all_templates_for_service(sample_service.id)) == 1 assert ( dao_get_all_templates_for_service(sample_service.id)[0].name @@ -50,11 +62,13 @@ def test_create_template(sample_service, sample_user, template_type, subject): def test_create_template_creates_redact_entry(sample_service): - assert TemplateRedacted.query.count() == 0 + stmt = select(func.count()).select_from(TemplateRedacted) + assert db.session.execute(stmt).scalar() == 0 template = create_template(sample_service) - redacted = TemplateRedacted.query.one() + stmt = select(TemplateRedacted) + redacted = db.session.execute(stmt).scalars().one() assert redacted.template_id == template.id assert redacted.redact_personalisation is False assert redacted.updated_by_id == sample_service.created_by_id @@ -79,7 +93,8 @@ def test_update_template(sample_service, sample_user): def test_redact_template(sample_template): - redacted = TemplateRedacted.query.one() + stmt = select(TemplateRedacted) + redacted = db.session.execute(stmt).scalars().one() assert redacted.template_id == sample_template.id assert redacted.redact_personalisation is False @@ -96,7 +111,7 @@ def test_get_all_templates_for_service(service_factory): service_1 = service_factory.get("service 1", email_from="service.1") service_2 = service_factory.get("service 2", email_from="service.2") - assert Template.query.count() == 2 + assert template_query_count() == 2 assert len(dao_get_all_templates_for_service(service_1.id)) == 1 assert len(dao_get_all_templates_for_service(service_2.id)) == 1 @@ -119,7 +134,7 @@ def test_get_all_templates_for_service(service_factory): content="Template content", ) - assert Template.query.count() == 5 + assert template_query_count() == 5 assert len(dao_get_all_templates_for_service(service_1.id)) == 3 assert len(dao_get_all_templates_for_service(service_2.id)) == 2 @@ -144,7 +159,7 @@ def test_get_all_templates_for_service_is_alphabetised(sample_service): service=sample_service, ) - assert Template.query.count() == 3 + assert template_query_count() == 3 assert ( dao_get_all_templates_for_service(sample_service.id)[0].name == "Sample Template 1" @@ -171,7 +186,7 @@ def test_get_all_templates_for_service_is_alphabetised(sample_service): def test_get_all_returns_empty_list_if_no_templates(sample_service): - assert Template.query.count() == 0 + assert template_query_count() == 0 assert len(dao_get_all_templates_for_service(sample_service.id)) == 0 @@ -257,8 +272,8 @@ def test_get_template_by_id_and_service_returns_none_if_no_template( def test_create_template_creates_a_history_record_with_current_data( sample_service, sample_user ): - assert Template.query.count() == 0 - assert TemplateHistory.query.count() == 0 + assert template_query_count() == 0 + assert template_history_query_count() == 0 data = { "name": "Sample Template", "template_type": TemplateType.EMAIL, @@ -270,10 +285,12 @@ def test_create_template_creates_a_history_record_with_current_data( template = Template(**data) dao_create_template(template) - assert Template.query.count() == 1 + assert template_query_count() == 1 - template_from_db = Template.query.first() - template_history = TemplateHistory.query.first() + stmt = select(Template) + template_from_db = db.session.execute(stmt).scalars().first() + stmt = select(TemplateHistory) + template_history = db.session.execute(stmt).scalars().first() assert template_from_db.id == template_history.id assert template_from_db.name == template_history.name @@ -286,8 +303,8 @@ def test_create_template_creates_a_history_record_with_current_data( def test_update_template_creates_a_history_record_with_current_data( sample_service, sample_user ): - assert Template.query.count() == 0 - assert TemplateHistory.query.count() == 0 + assert template_query_count() == 0 + assert template_history_query_count() == 0 data = { "name": "Sample Template", "template_type": TemplateType.EMAIL, @@ -301,22 +318,26 @@ def test_update_template_creates_a_history_record_with_current_data( created = dao_get_all_templates_for_service(sample_service.id)[0] assert created.name == "Sample Template" - assert Template.query.count() == 1 - assert Template.query.first().version == 1 - assert TemplateHistory.query.count() == 1 + assert template_query_count() == 1 + stmt = select(Template) + assert db.session.execute(stmt).scalars().first().version == 1 + assert template_history_query_count() == 1 created.name = "new name" dao_update_template(created) - assert Template.query.count() == 1 - assert TemplateHistory.query.count() == 2 + assert template_query_count() == 1 + assert template_history_query_count() == 2 - template_from_db = Template.query.first() + stmt = select(Template) + template_from_db = db.session.execute(stmt).scalars().first() assert template_from_db.version == 2 - assert TemplateHistory.query.filter_by(name="Sample Template").one().version == 1 - assert TemplateHistory.query.filter_by(name="new name").one().version == 2 + stmt = select(TemplateHistory).filter_by(name="Sample Template") + assert db.session.execute(stmt).scalars().one().version == 1 + stmt = select(TemplateHistory).filter_by(name="new name") + assert db.session.execute(stmt).scalars().one().version == 2 def test_get_template_history_version(sample_user, sample_service, sample_template): diff --git a/tests/app/dao/test_users_dao.py b/tests/app/dao/test_users_dao.py index 9c8770913..8f9f21fe3 100644 --- a/tests/app/dao/test_users_dao.py +++ b/tests/app/dao/test_users_dao.py @@ -3,6 +3,7 @@ import pytest from freezegun import freeze_time +from sqlalchemy import func, select from sqlalchemy.exc import DataError from sqlalchemy.orm.exc import NoResultFound @@ -37,6 +38,21 @@ ) +def _get_user_query_count(): + stmt = select(func.count(User.id)) + return db.session.execute(stmt).scalar() or 0 + + +def _get_user_query_first(): + stmt = select(User) + return db.session.execute(stmt).scalars().first() + + +def _get_verify_code_query_count(): + stmt = select(func.count(VerifyCode.id)) + return db.session.execute(stmt).scalar() or 0 + + @freeze_time("2020-01-28T12:00:00") @pytest.mark.parametrize( "phone_number, expected_phone_number", @@ -55,8 +71,10 @@ def test_create_user(notify_db_session, phone_number, expected_phone_number): } user = User(**data) save_model_user(user, password="password", validated_email_access=True) - assert User.query.count() == 1 - user_query = User.query.first() + stmt = select(func.count(User.id)) + assert db.session.execute(stmt).scalar() == 1 + stmt = select(User) + user_query = db.session.execute(stmt).scalars().first() assert user_query.email_address == email assert user_query.id == user.id assert user_query.mobile_number == expected_phone_number @@ -68,7 +86,8 @@ def test_get_all_users(notify_db_session): create_user(email="1@test.com") create_user(email="2@test.com") - assert User.query.count() == 2 + stmt = select(func.count(User.id)) + assert db.session.execute(stmt).scalar() == 2 assert len(get_user_by_id()) == 2 @@ -89,9 +108,10 @@ def test_get_user_invalid_id(notify_db_session): def test_delete_users(sample_user): - assert User.query.count() == 1 + stmt = select(func.count(User.id)) + assert db.session.execute(stmt).scalar() == 1 delete_model_user(sample_user) - assert User.query.count() == 0 + assert db.session.execute(stmt).scalar() == 0 def test_increment_failed_login_should_increment_failed_logins(sample_user): @@ -127,9 +147,10 @@ def test_get_user_by_email_is_case_insensitive(sample_user): def test_should_delete_all_verification_codes_more_than_one_day_old(sample_user): make_verify_code(sample_user, age=timedelta(hours=24), code="54321") make_verify_code(sample_user, age=timedelta(hours=24), code="54321") - assert VerifyCode.query.count() == 2 + stmt = select(func.count(VerifyCode.id)) + assert db.session.execute(stmt).scalar() == 2 delete_codes_older_created_more_than_a_day_ago() - assert VerifyCode.query.count() == 0 + assert db.session.execute(stmt).scalar() == 0 def test_should_not_delete_verification_codes_less_than_one_day_old(sample_user): @@ -137,10 +158,11 @@ def test_should_not_delete_verification_codes_less_than_one_day_old(sample_user) sample_user, age=timedelta(hours=23, minutes=59, seconds=59), code="12345" ) make_verify_code(sample_user, age=timedelta(hours=24), code="54321") - - assert VerifyCode.query.count() == 2 + stmt = select(func.count(VerifyCode.id)) + assert db.session.execute(stmt).scalar() == 2 delete_codes_older_created_more_than_a_day_ago() - assert VerifyCode.query.one()._code == "12345" + stmt = select(VerifyCode) + assert db.session.execute(stmt).scalars().one()._code == "12345" def make_verify_code(user, age=None, expiry_age=None, code="12335", code_used=False): diff --git a/tests/app/delivery/test_send_to_providers.py b/tests/app/delivery/test_send_to_providers.py index fbea9a2f7..20b0f7186 100644 --- a/tests/app/delivery/test_send_to_providers.py +++ b/tests/app/delivery/test_send_to_providers.py @@ -1,308 +1,1009 @@ import json -import os -from contextlib import suppress -from urllib import parse +from collections import namedtuple +from unittest.mock import ANY -from cachetools import TTLCache, cached +import pytest from flask import current_app - -from app import ( - aws_pinpoint_client, - create_uuid, - db, - notification_provider_clients, - redis_store, +from requests import HTTPError + +import app +from app import aws_sns_client, notification_provider_clients +from app.cloudfoundry_config import cloud_config +from app.dao import notifications_dao +from app.dao.provider_details_dao import get_provider_details_by_identifier +from app.delivery import send_to_providers +from app.delivery.send_to_providers import ( + _experimentally_validate_phone_numbers, + get_html_email_options, + get_logo_url, ) -from app.aws.s3 import get_personalisation_from_s3, get_phone_number_from_s3 -from app.celery.test_key_tasks import send_email_response, send_sms_response -from app.dao.email_branding_dao import dao_get_email_branding_by_id -from app.dao.notifications_dao import dao_update_notification -from app.dao.provider_details_dao import get_provider_details_by_notification_type -from app.dao.service_sms_sender_dao import dao_get_sms_senders_by_service_id from app.enums import BrandType, KeyType, NotificationStatus, NotificationType from app.exceptions import NotificationTechnicalFailureException -from app.serialised_models import SerialisedService, SerialisedTemplate -from app.utils import hilite, utc_now -from notifications_utils.template import ( - HTMLEmailTemplate, - PlainTextEmailTemplate, - SMSMessageTemplate, +from app.models import EmailBranding, Notification +from app.serialised_models import SerialisedService +from app.utils import utc_now +from tests.app.db import ( + create_email_branding, + create_notification, + create_reply_to_email, + create_service, + create_service_sms_sender, + create_service_with_defined_sms_sender, + create_template, ) -def send_sms_to_provider(notification): - """Final step in the message send flow. - - Get data for recipient, template, - notification and send it to sns. - """ - # we no longer store the personalisation in the db, - # need to retrieve from s3 before generating content - # However, we are still sending the initial verify code through personalisation - # so if there is some value there, don't overwrite it - if not notification.personalisation: - personalisation = get_personalisation_from_s3( - notification.service_id, - notification.job_id, - notification.job_row_number, - ) - notification.personalisation = personalisation - - service = SerialisedService.from_id(notification.service_id) - message_id = None - if not service.active: - technical_failure(notification=notification) - return - - if notification.status == NotificationStatus.CREATED: - # We get the provider here (which is only aws sns) - provider = provider_to_use(NotificationType.SMS, notification.international) - if not provider: - technical_failure(notification=notification) - return - - template_model = SerialisedTemplate.from_id_and_service_id( - template_id=notification.template_id, - service_id=service.id, - version=notification.template_version, - ) - - template = SMSMessageTemplate( - template_model.__dict__, - values=notification.personalisation, - prefix=service.name, - show_prefix=service.prefix_sms, - ) - if notification.key_type == KeyType.TEST: - update_notification_to_sending(notification, provider) - send_sms_response(provider.name, str(notification.id)) - - else: - try: - # End DB session here so that we don't have a connection stuck open waiting on the call - # to one of the SMS providers - # We don't want to tie our DB connections being open to the performance of our SMS - # providers as a slow down of our providers can cause us to run out of DB connections - # Therefore we pull all the data from our DB models into `send_sms_kwargs`now before - # closing the session (as otherwise it would be reopened immediately) - - # We start by trying to get the phone number from a job in s3. If we fail, we assume - # the phone number is for the verification code on login, which is not a job. - recipient = None - # It is our 2facode, maybe - recipient = _get_verify_code(notification) - - if recipient is None: - recipient = get_phone_number_from_s3( - notification.service_id, - notification.job_id, - notification.job_row_number, - ) - - # TODO This is temporary to test the capability of validating phone numbers - # The future home of the validation is TBD - if "+" not in recipient: - recipient_lookup = f"+{recipient}" - else: - recipient_lookup = recipient - if recipient_lookup in current_app.config[ - "SIMULATED_SMS_NUMBERS" - ] and os.getenv("NOTIFY_ENVIRONMENT") in ["development", "test"]: - current_app.logger.info(hilite("#validate-phone-number fired")) - aws_pinpoint_client.validate_phone_number("01", recipient) - else: - current_app.logger.info(hilite("#validate-phone-number not fired")) - - sender_numbers = get_sender_numbers(notification) - if notification.reply_to_text not in sender_numbers: - raise ValueError( - f"{notification.reply_to_text} not in {sender_numbers} #notify-admin-1701" - ) - - send_sms_kwargs = { - "to": recipient, - "content": str(template), - "reference": str(notification.id), - "sender": notification.reply_to_text, - "international": notification.international, - } - db.session.close() # no commit needed as no changes to objects have been made above - - message_id = provider.send_sms(**send_sms_kwargs) - current_app.logger.info(f"got message_id {message_id}") - except Exception as e: - n = notification - msg = f"FAILED send to sms, job_id: {n.job_id} row_number {n.job_row_number} message_id {message_id}" - current_app.logger.exception(hilite(msg)) - - notification.billable_units = template.fragment_count - dao_update_notification(notification) - raise e - else: - # Here we map the job_id and row number to the aws message_id - n = notification - msg = f"Send to aws for job_id {n.job_id} row_number {n.job_row_number} message_id {message_id}" - current_app.logger.info(hilite(msg)) - notification.billable_units = template.fragment_count - update_notification_to_sending(notification, provider) - return message_id - - -def _get_verify_code(notification): - key = f"2facode-{notification.id}".replace(" ", "") - recipient = redis_store.get(key) - with suppress(AttributeError): - recipient = recipient.decode("utf-8") - return recipient - - -def get_sender_numbers(notification): - possible_senders = dao_get_sms_senders_by_service_id(notification.service_id) - sender_numbers = [] - for possible_sender in possible_senders: - sender_numbers.append(possible_sender.sms_sender) - return sender_numbers - - -def send_email_to_provider(notification): - # Someone needs an email, possibly new registration - recipient = redis_store.get(f"email-address-{notification.id}") - recipient = recipient.decode("utf-8") - personalisation = redis_store.get(f"email-personalisation-{notification.id}") - if personalisation: - personalisation = personalisation.decode("utf-8") - notification.personalisation = json.loads(personalisation) - - service = SerialisedService.from_id(notification.service_id) - if not service.active: - technical_failure(notification=notification) - return - if notification.status == NotificationStatus.CREATED: - provider = provider_to_use(NotificationType.EMAIL, False) - template_dict = SerialisedTemplate.from_id_and_service_id( - template_id=notification.template_id, - service_id=service.id, - version=notification.template_version, - ).__dict__ - - html_email = HTMLEmailTemplate( - template_dict, - values=notification.personalisation, - **get_html_email_options(service), - ) - - plain_text_email = PlainTextEmailTemplate( - template_dict, values=notification.personalisation - ) - - if notification.key_type == KeyType.TEST: - notification.reference = str(create_uuid()) - update_notification_to_sending(notification, provider) - send_email_response(notification.reference, recipient) - else: - from_address = ( - f'"{service.name}" <{service.email_from}@' - f'{current_app.config["NOTIFY_EMAIL_DOMAIN"]}>' - ) - - reference = provider.send_email( - from_address, - recipient, - plain_text_email.subject, - body=str(plain_text_email), - html_body=str(html_email), - reply_to_address=notification.reply_to_text, - ) - notification.reference = reference - update_notification_to_sending(notification, provider) - - -def update_notification_to_sending(notification, provider): - notification.sent_at = utc_now() - notification.sent_by = provider.name - if notification.status not in NotificationStatus.completed_types(): - notification.status = NotificationStatus.SENDING - - dao_update_notification(notification) - - -provider_cache = TTLCache(maxsize=8, ttl=10) - - -@cached(cache=provider_cache) -def provider_to_use(notification_type, international=True): - active_providers = [ - p - for p in get_provider_details_by_notification_type( - notification_type, international - ) - if p.active - ] +def setup_function(_function): + # pytest will run this function before each test. It makes sure the + # state of the cache is not shared between tests. + send_to_providers.provider_cache.clear() + + +@pytest.mark.parametrize( + "international_provider_priority", + ( + # Since there’s only one international provider it should always + # be used, no matter what its priority is set to + 0, + 50, + 100, + ), +) +def test_provider_to_use_should_only_return_sns_for_international( + mocker, + notify_db_session, + international_provider_priority, +): + sns = get_provider_details_by_identifier("sns") + sns.priority = international_provider_priority + + ret = send_to_providers.provider_to_use(NotificationType.SMS, international=True) + + assert ret.name == "sns" + + +def test_provider_to_use_raises_if_no_active_providers( + mocker, restore_provider_details +): + sns = get_provider_details_by_identifier("sns") + sns.active = False + + # flake8 doesn't like raises with a generic exception + try: + send_to_providers.provider_to_use(NotificationType.SMS) + assert 1 == 0 + except Exception: + assert 1 == 1 + + +def test_should_send_personalised_template_to_correct_sms_provider_and_persist( + sample_sms_template_with_html, mocker +): + + mocker.patch("app.delivery.send_to_providers._get_verify_code", return_value=None) + db_notification = create_notification( + template=sample_sms_template_with_html, + personalisation={}, + status=NotificationStatus.CREATED, + reply_to_text=sample_sms_template_with_html.service.get_default_sms_sender(), + ) + + mocker.patch("app.aws_sns_client.send_sms") + + mock_s3 = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_s3.return_value = "2028675309" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"name": "Jo"} + + send_to_providers.send_sms_to_provider(db_notification) + + aws_sns_client.send_sms.assert_called_once_with( + to="2028675309", + content="Sample service: Hello Jo\nHere is some HTML & entities", + reference=str(db_notification.id), + sender=current_app.config["FROM_NUMBER"], + international=False, + ) + + notification = Notification.query.filter_by(id=db_notification.id).one() + + assert notification.status == NotificationStatus.SENDING + assert notification.sent_at <= utc_now() + assert notification.sent_by == "sns" + assert notification.billable_units == 1 + assert notification.personalisation == {"name": "Jo"} + + +def test_should_send_personalised_template_to_correct_email_provider_and_persist( + sample_email_template_with_html, mocker +): + + mock_redis = mocker.patch("app.delivery.send_to_providers.redis_store") + utf8_encoded_email = "jo.smith@example.com".encode("utf-8") + mock_redis.get.return_value = utf8_encoded_email + email = utf8_encoded_email + personalisation = { + "name": "Jo", + } + personalisation = json.dumps(personalisation) + personalisation = personalisation.encode("utf-8") + mock_redis.get.side_effect = [email, personalisation] + db_notification = create_notification( + template=sample_email_template_with_html, + ) + db_notification.personalisation = {"name": "Jo"} + mocker.patch("app.aws_ses_client.send_email", return_value="reference") + send_to_providers.send_email_to_provider(db_notification) + app.aws_ses_client.send_email.assert_called_once_with( + f'"Sample service" ', + "jo.smith@example.com", + "Jo some HTML", + body="Hello Jo\nThis is an email from GOV.\u200bUK with some HTML\n", + html_body=ANY, + reply_to_address=None, + ) + + assert " version_on_notification + + send_to_providers.send_sms_to_provider(db_notification) + + aws_sns_client.send_sms.assert_called_once_with( + to="2028675309", + content="Sample service: This is a template:\nwith a newline", + reference=str(db_notification.id), + sender=current_app.config["FROM_NUMBER"], + international=False, + ) + + t = dao_get_template_by_id(expected_template_id) + + persisted_notification = notifications_dao.get_notification_by_id( + db_notification.id + ) + assert persisted_notification.to == db_notification.to + assert persisted_notification.template_id == expected_template_id + assert persisted_notification.template_version == version_on_notification + assert persisted_notification.template_version != t.version + assert persisted_notification.status == NotificationStatus.SENDING + + +def test_should_have_sending_status_if_fake_callback_function_fails( + sample_notification, mocker +): + mocker.patch( + "app.delivery.send_to_providers.send_sms_response", + side_effect=HTTPError, + ) + + mock_s3 = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_s3.return_value = "2028675309" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"ignore": "ignore"} + + sample_notification.key_type = KeyType.TEST + with pytest.raises(HTTPError): + send_to_providers.send_sms_to_provider(sample_notification) + assert sample_notification.status == NotificationStatus.SENDING + assert sample_notification.sent_by == "sns" + + +def test_should_not_send_to_provider_when_status_is_not_created( + sample_template, mocker +): + notification = create_notification( + template=sample_template, + status=NotificationStatus.SENDING, + ) + mocker.patch("app.aws_sns_client.send_sms") + response_mock = mocker.patch("app.delivery.send_to_providers.send_sms_response") + + mock_s3 = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_s3.return_value = "2028675309" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"ignore": "ignore"} + + send_to_providers.send_sms_to_provider(notification) + + app.aws_sns_client.send_sms.assert_not_called() + response_mock.assert_not_called() -def get_html_email_options(service): - if service.email_branding is None: - return { - "govuk_banner": True, - "brand_banner": False, - } - if isinstance(service, SerialisedService): - branding = dao_get_email_branding_by_id(service.email_branding) +def test_should_send_sms_with_downgraded_content(notify_db_session, mocker): + # é, o, and u are in GSM. + # ī, grapes, tabs, zero width space and ellipsis are not + # ó isn't in GSM, but it is in the welsh alphabet so will still be sent + + mocker.patch("app.delivery.send_to_providers.redis_store", return_value=None) + mocker.patch( + "app.delivery.send_to_providers.get_sender_numbers", return_value=["testing"] + ) + msg = "a é ī o u 🍇 foo\tbar\u200bbaz((misc))…" + placeholder = "∆∆∆abc" + gsm_message = "?ódz Housing Service: a é i o u ? foo barbaz???abc..." + service = create_service(service_name="Łódź Housing Service") + template = create_template(service, content=msg) + db_notification = create_notification( + template=template, + ) + db_notification.personalisation = {"misc": placeholder} + db_notification.reply_to_text = "testing" + + mocker.patch("app.aws_sns_client.send_sms") + + mock_phone = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_phone.return_value = "15555555555" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"misc": placeholder} + + send_to_providers.send_sms_to_provider(db_notification) + + aws_sns_client.send_sms.assert_called_once_with( + to=ANY, content=gsm_message, reference=ANY, sender=ANY, international=False + ) + + +def test_send_sms_should_use_service_sms_sender( + sample_service, sample_template, mocker +): + + mocker.patch("app.delivery.send_to_providers.redis_store", return_value=None) + mocker.patch("app.aws_sns_client.send_sms") + + sms_sender = create_service_sms_sender( + service=sample_service, sms_sender="123456", is_default=False + ) + db_notification = create_notification( + template=sample_template, reply_to_text=sms_sender.sms_sender + ) + expected_sender_name = sms_sender.sms_sender + mock_phone = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_phone.return_value = "15555555555" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"ignore": "ignore"} + + send_to_providers.send_sms_to_provider( + db_notification, + ) + + app.aws_sns_client.send_sms.assert_called_once_with( + to=ANY, + content=ANY, + reference=ANY, + sender=expected_sender_name, + international=False, + ) + + +def test_send_email_to_provider_should_not_send_to_provider_when_status_is_not_created( + sample_email_template, mocker +): + mock_redis = mocker.patch("app.delivery.send_to_providers.redis_store") + mock_redis.get.return_value = "test@example.com".encode("utf-8") + + notification = create_notification( + template=sample_email_template, status=NotificationStatus.SENDING + ) + mocker.patch("app.aws_ses_client.send_email") + mocker.patch("app.delivery.send_to_providers.send_email_response") + mock_phone = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_phone.return_value = "15555555555" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"ignore": "ignore"} + send_to_providers.send_sms_to_provider(notification) + app.aws_ses_client.send_email.assert_not_called() + app.delivery.send_to_providers.send_email_response.assert_not_called() + + +def test_send_email_should_use_service_reply_to_email( + sample_service, sample_email_template, mocker +): + mocker.patch("app.aws_ses_client.send_email", return_value="reference") + + mock_redis = mocker.patch("app.delivery.send_to_providers.redis_store") + mock_redis.get.return_value = "test@example.com".encode("utf-8") + + mock_redis = mocker.patch("app.delivery.send_to_providers.redis_store") + email = "foo@bar.com".encode("utf-8") + personalisation = {} + + personalisation = json.dumps(personalisation) + personalisation = personalisation.encode("utf-8") + mock_redis.get.side_effect = [email, personalisation] + + db_notification = create_notification( + template=sample_email_template, reply_to_text="foo@bar.com" + ) + create_reply_to_email(service=sample_service, email_address="foo@bar.com") + + send_to_providers.send_email_to_provider(db_notification) + + app.aws_ses_client.send_email.assert_called_once_with( + ANY, + ANY, + ANY, + body=ANY, + html_body=ANY, + reply_to_address="foo@bar.com", + ) + + +def test_get_html_email_renderer_should_return_for_normal_service(sample_service): + options = send_to_providers.get_html_email_options(sample_service) + assert options["govuk_banner"] is True + assert "brand_colour" not in options.keys() + assert "brand_logo" not in options.keys() + assert "brand_text" not in options.keys() + assert "brand_name" not in options.keys() + + +@pytest.mark.parametrize( + "branding_type, govuk_banner", + [(BrandType.ORG, False), (BrandType.BOTH, True), (BrandType.ORG_BANNER, False)], +) +def test_get_html_email_renderer_with_branding_details( + branding_type, govuk_banner, notify_db_session, sample_service +): + email_branding = EmailBranding( + brand_type=branding_type, + colour="#000000", + logo="justice-league.png", + name="Justice League", + text="League of Justice", + ) + sample_service.email_branding = email_branding + notify_db_session.add_all([sample_service, email_branding]) + notify_db_session.commit() + + options = send_to_providers.get_html_email_options(sample_service) + + assert options["govuk_banner"] == govuk_banner + assert options["brand_colour"] == "#000000" + assert options["brand_text"] == "League of Justice" + assert options["brand_name"] == "Justice League" + + if branding_type == BrandType.ORG_BANNER: + assert options["brand_banner"] is True else: - branding = service.email_branding + assert options["brand_banner"] is False + + +def test_get_html_email_renderer_with_branding_details_and_render_govuk_banner_only( + notify_db_session, sample_service +): + sample_service.email_branding = None + notify_db_session.add_all([sample_service]) + notify_db_session.commit() + + options = send_to_providers.get_html_email_options(sample_service) - logo_url = ( - get_logo_url(current_app.config["ADMIN_BASE_URL"], branding.logo) - if branding.logo - else None + assert options == {"govuk_banner": True, "brand_banner": False} + + +def test_get_html_email_renderer_prepends_logo_path(notify_api): + Service = namedtuple("Service", ["email_branding"]) + EmailBranding = namedtuple( + "EmailBranding", + ["brand_type", "colour", "name", "logo", "text"], ) - return { + email_branding = EmailBranding( + brand_type=BrandType.ORG, + colour="#000000", + logo="justice-league.png", + name="Justice League", + text="League of Justice", + ) + service = Service( + email_branding=email_branding, + ) + + renderer = send_to_providers.get_html_email_options(service) + + assert ( + renderer["brand_logo"] == "http://static-logos.notify.tools/justice-league.png" + ) + + +def test_get_html_email_renderer_handles_email_branding_without_logo(notify_api): + Service = namedtuple("Service", ["email_branding"]) + EmailBranding = namedtuple( + "EmailBranding", + ["brand_type", "colour", "name", "logo", "text"], + ) + + email_branding = EmailBranding( + brand_type=BrandType.ORG_BANNER, + colour="#000000", + logo=None, + name="Justice League", + text="League of Justice", + ) + service = Service( + email_branding=email_branding, + ) + + renderer = send_to_providers.get_html_email_options(service) + + assert renderer["govuk_banner"] is False + assert renderer["brand_banner"] is True + assert renderer["brand_logo"] is None + assert renderer["brand_text"] == "League of Justice" + assert renderer["brand_colour"] == "#000000" + assert renderer["brand_name"] == "Justice League" + + +@pytest.mark.parametrize( + "base_url, expected_url", + [ + # don't change localhost to prevent errors when testing locally + ("http://localhost:6012", "http://static-logos.notify.tools/filename.png"), + ( + "https://www.notifications.service.gov.uk", + "https://static-logos.notifications.service.gov.uk/filename.png", + ), + ("https://notify.works", "https://static-logos.notify.works/filename.png"), + ( + "https://staging-notify.works", + "https://static-logos.staging-notify.works/filename.png", + ), + ("https://www.notify.works", "https://static-logos.notify.works/filename.png"), + ( + "https://www.staging-notify.works", + "https://static-logos.staging-notify.works/filename.png", + ), + ], +) +def test_get_logo_url_works_for_different_environments(base_url, expected_url): + logo_file = "filename.png" + + logo_url = send_to_providers.get_logo_url(base_url, logo_file) + + assert logo_url == expected_url + + +@pytest.mark.parametrize( + "starting_status, expected_status", + [ + (NotificationStatus.DELIVERED, NotificationStatus.DELIVERED), + (NotificationStatus.CREATED, NotificationStatus.SENDING), + (NotificationStatus.TECHNICAL_FAILURE, NotificationStatus.TECHNICAL_FAILURE), + ], +) +def test_update_notification_to_sending_does_not_update_status_from_a_final_status( + sample_service, notify_db_session, starting_status, expected_status +): + template = create_template(sample_service) + notification = create_notification(template=template, status=starting_status) + send_to_providers.update_notification_to_sending( + notification, + notification_provider_clients.get_client_by_name_and_type( + "sns", NotificationType.SMS + ), + ) + assert notification.status == expected_status + + +def __update_notification(notification_to_update, research_mode, expected_status): + if research_mode or notification_to_update.key_type == KeyType.TEST: + notification_to_update.status = expected_status + + +@pytest.mark.parametrize( + "research_mode,key_type, billable_units, expected_status", + [ + (True, KeyType.NORMAL, 0, NotificationStatus.DELIVERED), + (False, KeyType.NORMAL, 1, NotificationStatus.SENDING), + (False, KeyType.TEST, 0, NotificationStatus.SENDING), + (True, KeyType.TEST, 0, NotificationStatus.SENDING), + (True, KeyType.TEAM, 0, NotificationStatus.DELIVERED), + (False, KeyType.TEAM, 1, NotificationStatus.SENDING), + ], +) +def test_should_update_billable_units_and_status_according_to_research_mode_and_key_type( + sample_template, mocker, research_mode, key_type, billable_units, expected_status +): + + mocker.patch("app.delivery.send_to_providers.redis_store", return_value=None) + mocker.patch( + "app.delivery.send_to_providers.get_sender_numbers", return_value=["testing"] + ) + notification = create_notification( + template=sample_template, + billable_units=0, + status=NotificationStatus.CREATED, + key_type=key_type, + reply_to_text="testing", + ) + mocker.patch("app.aws_sns_client.send_sms") + mocker.patch( + "app.delivery.send_to_providers.send_sms_response", + side_effect=__update_notification(notification, research_mode, expected_status), + ) + + if research_mode: + sample_template.service.research_mode = True + + mock_phone = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_phone.return_value = "15555555555" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + # So we don't treat it as a one off and have to mock other things + mock_personalisation.return_value = {"ignore": "ignore"} + + send_to_providers.send_sms_to_provider(notification) + assert notification.billable_units == billable_units + assert notification.status == expected_status + + +def test_should_set_notification_billable_units_and_reduces_provider_priority_if_sending_to_provider_fails( + sample_notification, + mocker, +): + mocker.patch("app.aws_sns_client.send_sms", side_effect=Exception()) + + sample_notification.billable_units = 0 + assert sample_notification.sent_by is None + + mock_phone = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_phone.return_value = "15555555555" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"ignore": "ignore"} + + # flake8 no longer likes raises with a generic exception + try: + send_to_providers.send_sms_to_provider(sample_notification) + assert 1 == 0 + except Exception: + assert 1 == 1 + + assert sample_notification.billable_units == 1 + + +def test_should_send_sms_to_international_providers( + sample_template, sample_user, mocker +): + + mocker.patch("app.delivery.send_to_providers._get_verify_code", return_value=None) + mocker.patch("app.aws_sns_client.send_sms") + + notification_international = create_notification( + template=sample_template, + to_field="+6011-17224412", + personalisation={"name": "Jo"}, + status=NotificationStatus.CREATED, + international=True, + reply_to_text=sample_template.service.get_default_sms_sender(), + normalised_to="601117224412", + ) + + mock_s3 = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_s3.return_value = "601117224412" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"ignore": "ignore"} + + send_to_providers.send_sms_to_provider(notification_international) + + aws_sns_client.send_sms.assert_called_once_with( + to="601117224412", + content=ANY, + reference=str(notification_international.id), + sender=current_app.config["FROM_NUMBER"], + international=True, + ) + + assert notification_international.status == NotificationStatus.SENDING + assert notification_international.sent_by == "sns" + + +@pytest.mark.parametrize( + "sms_sender, expected_sender, prefix_sms, expected_content", + [ + ("foo", "foo", False, "bar"), + ("foo", "foo", True, "Sample service: bar"), + # if 40604 is actually in DB then treat that as if entered manually + ("40604", "40604", False, "bar"), + # 'testing' is the FROM_NUMBER during unit tests + ("testing", "testing", True, "Sample service: bar"), + ("testing", "testing", False, "bar"), + ], +) +def test_should_handle_sms_sender_and_prefix_message( + mocker, sms_sender, prefix_sms, expected_sender, expected_content, notify_db_session +): + + mocker.patch("app.delivery.send_to_providers.redis_store", return_value=None) + mocker.patch("app.aws_sns_client.send_sms") + service = create_service_with_defined_sms_sender( + sms_sender_value=sms_sender, prefix_sms=prefix_sms + ) + template = create_template(service, content="bar") + notification = create_notification(template, reply_to_text=sms_sender) + + mock_phone = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_phone.return_value = "15555555555" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"ignore": "ignore"} + + send_to_providers.send_sms_to_provider(notification) + + aws_sns_client.send_sms.assert_called_once_with( + content=expected_content, + sender=expected_sender, + to=ANY, + reference=ANY, + international=False, + ) + + +def test_send_email_to_provider_uses_reply_to_from_notification( + sample_email_template, mocker +): + mock_redis = mocker.patch("app.delivery.send_to_providers.redis_store") + mock_redis.get.side_effect = [ + "test@example.com".encode("utf-8"), + json.dumps({}).encode("utf-8"), + ] + + mocker.patch("app.aws_ses_client.send_email", return_value="reference") + + db_notification = create_notification( + template=sample_email_template, + reply_to_text="test@test.com", + ) + + send_to_providers.send_email_to_provider(db_notification) + + app.aws_ses_client.send_email.assert_called_once_with( + ANY, + ANY, + ANY, + body=ANY, + html_body=ANY, + reply_to_address="test@test.com", + ) + + +def test_send_sms_to_provider_should_use_normalised_to(mocker, client, sample_template): + + mocker.patch("app.delivery.send_to_providers._get_verify_code", return_value=None) + mocker.patch( + "app.delivery.send_to_providers.get_sender_numbers", return_value=["testing"] + ) + send_mock = mocker.patch("app.aws_sns_client.send_sms") + notification = create_notification( + template=sample_template, + to_field="+12028675309", + normalised_to="2028675309", + reply_to_text="testing", + ) + + mock_s3 = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_s3.return_value = "12028675309" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"ignore": "ignore"} + send_to_providers.send_sms_to_provider(notification) + send_mock.assert_called_once_with( + to="12028675309", + content=ANY, + reference=str(notification.id), + sender=notification.reply_to_text, + international=False, + ) + + +def test_send_email_to_provider_should_user_normalised_to( + mocker, client, sample_email_template +): + send_mock = mocker.patch("app.aws_ses_client.send_email", return_value="reference") + notification = create_notification( + template=sample_email_template, + ) + mock_redis = mocker.patch("app.delivery.send_to_providers.redis_store") + mock_redis.get.return_value = "test@example.com".encode("utf-8") + + mock_redis = mocker.patch("app.delivery.send_to_providers.redis_store") + mock_redis.get.return_value = "jo.smith@example.com".encode("utf-8") + email = "test@example.com".encode("utf-8") + personalisation = {} + + personalisation = json.dumps(personalisation) + personalisation = personalisation.encode("utf-8") + mock_redis.get.side_effect = [email, personalisation] + + send_to_providers.send_email_to_provider(notification) + send_mock.assert_called_once_with( + ANY, + "test@example.com", + ANY, + body=ANY, + html_body=ANY, + reply_to_address=notification.reply_to_text, + ) + + +def test_send_sms_to_provider_should_return_template_if_found_in_redis( + mocker, client, sample_template +): + + mocker.patch("app.delivery.send_to_providers._get_verify_code", return_value=None) + mocker.patch( + "app.delivery.send_to_providers.get_sender_numbers", return_value=["testing"] + ) + from app.schemas import service_schema, template_schema + + service_dict = service_schema.dump(sample_template.service) + template_dict = template_schema.dump(sample_template) + + mocker.patch( + "app.redis_store.get", + side_effect=[ + json.dumps({"data": service_dict}).encode("utf-8"), + json.dumps({"data": template_dict}).encode("utf-8"), + ], + ) + mock_get_template = mocker.patch( + "app.dao.templates_dao.dao_get_template_by_id_and_service_id" + ) + mock_get_service = mocker.patch("app.dao.services_dao.dao_fetch_service_by_id") + + send_mock = mocker.patch("app.aws_sns_client.send_sms") + notification = create_notification( + template=sample_template, + to_field="+447700900855", + normalised_to="447700900855", + reply_to_text="testing", + ) + + mock_s3 = mocker.patch("app.delivery.send_to_providers.get_phone_number_from_s3") + mock_s3.return_value = "447700900855" + + mock_personalisation = mocker.patch( + "app.delivery.send_to_providers.get_personalisation_from_s3" + ) + mock_personalisation.return_value = {"ignore": "ignore"} + + send_to_providers.send_sms_to_provider(notification) + assert mock_get_template.called is False + assert mock_get_service.called is False + send_mock.assert_called_once_with( + to="447700900855", + content=ANY, + reference=str(notification.id), + sender=notification.reply_to_text, + international=False, + ) + + +def test_send_email_to_provider_should_return_template_if_found_in_redis( + mocker, client, sample_email_template +): + from app.schemas import service_schema, template_schema + + # mock_redis = mocker.patch("app.delivery.send_to_providers.redis_store") + # mock_redis.get.return_value = "jo.smith@example.com".encode("utf-8") + email = "test@example.com".encode("utf-8") + personalisation = { + "name": "Jo", + } + + personalisation = json.dumps(personalisation) + personalisation = personalisation.encode("utf-8") + # mock_redis.get.side_effect = [email, personalisation] + + service_dict = service_schema.dump(sample_email_template.service) + template_dict = template_schema.dump(sample_email_template) + + mocker.patch( + "app.redis_store.get", + side_effect=[ + email, + personalisation, + json.dumps({"data": service_dict}).encode("utf-8"), + json.dumps({"data": template_dict}).encode("utf-8"), + ], + ) + mock_get_template = mocker.patch( + "app.dao.templates_dao.dao_get_template_by_id_and_service_id" + ) + mock_get_service = mocker.patch("app.dao.services_dao.dao_fetch_service_by_id") + send_mock = mocker.patch("app.aws_ses_client.send_email", return_value="reference") + notification = create_notification( + template=sample_email_template, + ) + + send_to_providers.send_email_to_provider(notification) + assert mock_get_template.called is False + assert mock_get_service.called is False + send_mock.assert_called_once_with( + ANY, + "test@example.com", + ANY, + body=ANY, + html_body=ANY, + reply_to_address=notification.reply_to_text, + ) + + +def test_get_html_email_options_return_email_branding_from_serialised_service( + sample_service, +): + branding = create_email_branding() + sample_service.email_branding = branding + service = SerialisedService.from_id(sample_service.id) + email_options = get_html_email_options(service) + assert email_options is not None + assert email_options == { "govuk_banner": branding.brand_type == BrandType.BOTH, "brand_banner": branding.brand_type == BrandType.ORG_BANNER, "brand_colour": branding.colour, - "brand_logo": logo_url, + "brand_logo": get_logo_url(current_app.config["ADMIN_BASE_URL"], branding.logo), "brand_text": branding.text, "brand_name": branding.name, } -def technical_failure(notification): - notification.status = NotificationStatus.TECHNICAL_FAILURE - dao_update_notification(notification) - raise NotificationTechnicalFailureException( - f"Send {notification.notification_type} for notification id {notification.id} " - f"to provider is not allowed: service {notification.service_id} is inactive" - ) +def test_get_html_email_options_add_email_branding_from_service(sample_service): + branding = create_email_branding() + sample_service.email_branding = branding + email_options = get_html_email_options(sample_service) + assert email_options is not None + assert email_options == { + "govuk_banner": branding.brand_type == BrandType.BOTH, + "brand_banner": branding.brand_type == BrandType.ORG_BANNER, + "brand_colour": branding.colour, + "brand_logo": get_logo_url(current_app.config["ADMIN_BASE_URL"], branding.logo), + "brand_text": branding.text, + "brand_name": branding.name, + } + + +@pytest.mark.parametrize( + ("recipient", "expected_invoke"), + [ + ("15555555555", False), + ], +) +def test_experimentally_validate_phone_numbers(recipient, expected_invoke, mocker): + mock_pinpoint = mocker.patch("app.delivery.send_to_providers.aws_pinpoint_client") + _experimentally_validate_phone_numbers(recipient) + if expected_invoke: + mock_pinpoint.phone_number_validate.assert_called_once_with("foo") + else: + mock_pinpoint.phone_number_validate.assert_not_called() diff --git a/tests/app/service/test_rest.py b/tests/app/service/test_rest.py index fec71cf82..ecec87ec1 100644 --- a/tests/app/service/test_rest.py +++ b/tests/app/service/test_rest.py @@ -1959,6 +1959,84 @@ def test_get_all_notifications_for_service_including_ones_made_by_jobs( assert response.status_code == 200 +def test_get_monthly_notification_stats_by_user( + client, + sample_service, + sample_user, + mocker, +): + mock_s3 = mocker.patch("app.service.rest.get_phone_number_from_s3") + mock_s3.return_value = "" + + mock_s3 = mocker.patch("app.service.rest.get_personalisation_from_s3") + mock_s3.return_value = {} + + auth_header = create_admin_authorization_header() + + response = client.get( + path=( + f"/service/{sample_service.id}/notifications/{sample_user.id}/monthly?year=2024" + ), + headers=[auth_header], + ) + + resp = json.loads(response.get_data(as_text=True)) + print(f"RESP is {resp}") + # TODO This test could be a little more complete + assert response.status_code == 200 + + +def test_get_single_month_notification_stats_by_user( + client, + sample_service, + sample_user, + mocker, +): + mock_s3 = mocker.patch("app.service.rest.get_phone_number_from_s3") + mock_s3.return_value = "" + + mock_s3 = mocker.patch("app.service.rest.get_personalisation_from_s3") + mock_s3.return_value = {} + + auth_header = create_admin_authorization_header() + + response = client.get( + path=( + f"/service/{sample_service.id}/notifications/{sample_user.id}/month?year=2024&month=07" + ), + headers=[auth_header], + ) + + resp = json.loads(response.get_data(as_text=True)) + print(f"RESP is {resp}") + # TODO This test could be a little more complete + assert response.status_code == 200 + + +def test_get_single_month_notification_stats_for_service( + client, + sample_service, + mocker, +): + mock_s3 = mocker.patch("app.service.rest.get_phone_number_from_s3") + mock_s3.return_value = "" + + mock_s3 = mocker.patch("app.service.rest.get_personalisation_from_s3") + mock_s3.return_value = {} + + auth_header = create_admin_authorization_header() + + response = client.get( + path=(f"/service/{sample_service.id}/notifications/month?year=2024&month=07"), + headers=[auth_header], + ) + + resp = json.loads(response.get_data(as_text=True)) + print(f"RESP is {resp}") + # TODO This test could be a little more complete + assert response.status_code == 200 + + def test_get_only_api_created_notifications_for_service( admin_request, sample_job, diff --git a/tests/app/service_invite/test_service_invite_rest.py b/tests/app/service_invite/test_service_invite_rest.py index 07d0b4c23..5cea786f5 100644 --- a/tests/app/service_invite/test_service_invite_rest.py +++ b/tests/app/service_invite/test_service_invite_rest.py @@ -45,6 +45,7 @@ def test_create_invited_user( permissions="send_messages,manage_service,manage_api_keys", auth_type=AuthType.EMAIL, folder_permissions=["folder_1", "folder_2", "folder_3"], + nonce="FakeNonce", **extra_args, ) @@ -108,6 +109,7 @@ def test_create_invited_user_without_auth_type( "from_user": str(invite_from.id), "permissions": "send_messages,manage_service,manage_api_keys", "folder_permissions": [], + "nonce": "FakeNonce", } json_resp = admin_request.post( @@ -131,6 +133,7 @@ def test_create_invited_user_invalid_email(client, sample_service, mocker, fake_ "from_user": str(invite_from.id), "permissions": "send_messages,manage_service,manage_api_keys", "folder_permissions": [fake_uuid, fake_uuid], + "nonce": "FakeNonce", } data = json.dumps(data)