diff --git a/app/dao/annual_billing_dao.py b/app/dao/annual_billing_dao.py index 0e4d3b96b..306a2dd86 100644 --- a/app/dao/annual_billing_dao.py +++ b/app/dao/annual_billing_dao.py @@ -1,4 +1,5 @@ from flask import current_app +from sqlalchemy import select, update from app import db from app.dao.dao_utils import autocommit @@ -26,42 +27,51 @@ def dao_create_or_update_annual_billing_for_year( def dao_get_annual_billing(service_id): - return ( - AnnualBilling.query.filter_by( + stmt = ( + select(AnnualBilling) + .filter_by( service_id=service_id, ) .order_by(AnnualBilling.financial_year_start) - .all() ) + return db.session.execute(stmt).scalars().all() @autocommit def dao_update_annual_billing_for_future_years( service_id, free_sms_fragment_limit, financial_year_start ): - AnnualBilling.query.filter( - AnnualBilling.service_id == service_id, - AnnualBilling.financial_year_start > financial_year_start, - ).update({"free_sms_fragment_limit": free_sms_fragment_limit}) + stmt = ( + update(AnnualBilling) + .filter( + AnnualBilling.service_id == service_id, + AnnualBilling.financial_year_start > financial_year_start, + ) + .values({"free_sms_fragment_limit": free_sms_fragment_limit}) + ) + db.session.execute(stmt) + db.session.commit() def dao_get_free_sms_fragment_limit_for_year(service_id, financial_year_start=None): if not financial_year_start: financial_year_start = get_current_calendar_year_start_year() - return AnnualBilling.query.filter_by( + stmt = select(AnnualBilling).filter_by( service_id=service_id, financial_year_start=financial_year_start - ).first() + ) + return db.session.execute(stmt).scalars().first() def dao_get_all_free_sms_fragment_limit(service_id): - return ( - AnnualBilling.query.filter_by( + stmt = ( + select(AnnualBilling) + .filter_by( service_id=service_id, ) .order_by(AnnualBilling.financial_year_start) - .all() ) + return db.session.execute(stmt).scalars().all() def set_default_free_allowance_for_service(service, year_start=None): diff --git a/app/dao/fact_billing_dao.py b/app/dao/fact_billing_dao.py index 14d82835b..132f62bf2 100644 --- a/app/dao/fact_billing_dao.py +++ b/app/dao/fact_billing_dao.py @@ -1,7 +1,7 @@ from datetime import date, timedelta from flask import current_app -from sqlalchemy import Date, Integer, and_, desc, func, union +from sqlalchemy import Date, Integer, and_, delete, desc, func, select, union from sqlalchemy.dialects.postgresql import insert from sqlalchemy.sql.expression import case, literal @@ -31,7 +31,7 @@ def fetch_sms_free_allowance_remainder_until_date(end_date): ) query = ( - db.session.query( + select( AnnualBilling.service_id.label("service_id"), AnnualBilling.free_sms_fragment_limit, billable_units.label("billable_units"), @@ -40,6 +40,7 @@ def fetch_sms_free_allowance_remainder_until_date(end_date): 0, ).label("sms_remainder"), ) + .select_from(AnnualBilling) .outerjoin( # if there are no ft_billing rows for a service we still want to return the annual billing so we can use the # free_sms_fragment_limit) @@ -87,7 +88,7 @@ def fetch_sms_billing_for_all_services(start_date, end_date): sms_cost = chargeable_sms * FactBilling.rate query = ( - db.session.query( + select( Organization.name.label("organization_name"), Organization.id.label("organization_id"), Service.name.label("service_name"), @@ -126,7 +127,7 @@ def fetch_sms_billing_for_all_services(start_date, end_date): .order_by(Organization.name, Service.name) ) - return query.all() + return db.session.execute(query).all() def fetch_billing_totals_for_year(service_id, year): @@ -146,36 +147,29 @@ def fetch_billing_totals_for_year(service_id, year): a rate multiplier. Each subquery returns the same set of columns, which we pick from here before the big union. """ - return ( - db.session.query( - union( - *[ - db.session.query( - query.c.notification_type.label("notification_type"), - query.c.rate.label("rate"), - func.sum(query.c.notifications_sent).label( - "notifications_sent" - ), - func.sum(query.c.chargeable_units).label("chargeable_units"), - func.sum(query.c.cost).label("cost"), - func.sum(query.c.free_allowance_used).label( - "free_allowance_used" - ), - func.sum(query.c.charged_units).label("charged_units"), - ).group_by(query.c.rate, query.c.notification_type) - for query in [ - query_service_sms_usage_for_year(service_id, year).subquery(), - query_service_email_usage_for_year(service_id, year).subquery(), - ] + stmt = select( + union( + *[ + select( + query.c.notification_type.label("notification_type"), + query.c.rate.label("rate"), + func.sum(query.c.notifications_sent).label("notifications_sent"), + func.sum(query.c.chargeable_units).label("chargeable_units"), + func.sum(query.c.cost).label("cost"), + func.sum(query.c.free_allowance_used).label("free_allowance_used"), + func.sum(query.c.charged_units).label("charged_units"), + ).group_by(query.c.rate, query.c.notification_type) + for query in [ + query_service_sms_usage_for_year(service_id, year).subquery(), + query_service_email_usage_for_year(service_id, year).subquery(), ] - ).subquery() - ) - .order_by( - "notification_type", - "rate", - ) - .all() + ] + ).subquery() + ).order_by( + "notification_type", + "rate", ) + return db.session.execute(stmt).all() def fetch_monthly_billing_for_year(service_id, year): @@ -208,63 +202,60 @@ def fetch_monthly_billing_for_year(service_id, year): for d in data: update_fact_billing(data=d, process_day=today) - return ( - db.session.query( - union( - *[ - db.session.query( - query.c.rate.label("rate"), - query.c.notification_type.label("notification_type"), - func.date_trunc("month", query.c.local_date) - .cast(Date) - .label("month"), - func.sum(query.c.notifications_sent).label( - "notifications_sent" - ), - func.sum(query.c.chargeable_units).label("chargeable_units"), - func.sum(query.c.cost).label("cost"), - func.sum(query.c.free_allowance_used).label( - "free_allowance_used" - ), - func.sum(query.c.charged_units).label("charged_units"), - ).group_by( - query.c.rate, - query.c.notification_type, - "month", - ) - for query in [ - query_service_sms_usage_for_year(service_id, year).subquery(), - query_service_email_usage_for_year(service_id, year).subquery(), - ] + stmt = select( + union( + *[ + select( + query.c.rate.label("rate"), + query.c.notification_type.label("notification_type"), + func.date_trunc("month", query.c.local_date) + .cast(Date) + .label("month"), + func.sum(query.c.notifications_sent).label("notifications_sent"), + func.sum(query.c.chargeable_units).label("chargeable_units"), + func.sum(query.c.cost).label("cost"), + func.sum(query.c.free_allowance_used).label("free_allowance_used"), + func.sum(query.c.charged_units).label("charged_units"), + ).group_by( + query.c.rate, + query.c.notification_type, + "month", + ) + for query in [ + query_service_sms_usage_for_year(service_id, year).subquery(), + query_service_email_usage_for_year(service_id, year).subquery(), ] - ).subquery() - ) - .order_by( - "month", - "notification_type", - "rate", - ) - .all() + ] + ).subquery() + ).order_by( + "month", + "notification_type", + "rate", ) + return db.session.execute(stmt).all() def query_service_email_usage_for_year(service_id, year): year_start, year_end = get_calendar_year_dates(year) - return db.session.query( - FactBilling.local_date, - FactBilling.notifications_sent, - FactBilling.billable_units.label("chargeable_units"), - FactBilling.rate, - FactBilling.notification_type, - literal(0).label("cost"), - literal(0).label("free_allowance_used"), - FactBilling.billable_units.label("charged_units"), - ).filter( - FactBilling.service_id == service_id, - FactBilling.local_date >= year_start, - FactBilling.local_date <= year_end, - FactBilling.notification_type == NotificationType.EMAIL, + return ( + select( + FactBilling.local_date, + FactBilling.notifications_sent, + FactBilling.billable_units.label("chargeable_units"), + FactBilling.rate, + FactBilling.notification_type, + literal(0).label("cost"), + literal(0).label("free_allowance_used"), + FactBilling.billable_units.label("charged_units"), + ) + .select_from(FactBilling) + .filter( + FactBilling.service_id == service_id, + FactBilling.local_date >= year_start, + FactBilling.local_date <= year_end, + FactBilling.notification_type == NotificationType.EMAIL, + ) ) @@ -334,9 +325,8 @@ def query_service_sms_usage_for_year(service_id, year): free_allowance_used = func.least( remaining_free_allowance_before_this_row, this_rows_chargeable_units ) - - return ( - db.session.query( + stmt = ( + select( FactBilling.local_date, FactBilling.notifications_sent, this_rows_chargeable_units.label("chargeable_units"), @@ -346,6 +336,7 @@ def query_service_sms_usage_for_year(service_id, year): free_allowance_used.label("free_allowance_used"), charged_units.label("charged_units"), ) + .select_from(FactBilling) .join(AnnualBilling, AnnualBilling.service_id == service_id) .filter( FactBilling.service_id == service_id, @@ -355,6 +346,7 @@ def query_service_sms_usage_for_year(service_id, year): AnnualBilling.financial_year_start == year, ) ) + return stmt def delete_billing_data_for_service_for_day(process_day, service_id): @@ -363,9 +355,12 @@ def delete_billing_data_for_service_for_day(process_day, service_id): Returns how many rows were deleted """ - return FactBilling.query.filter( + stmt = delete(FactBilling).filter( FactBilling.local_date == process_day, FactBilling.service_id == service_id - ).delete() + ) + result = db.session.execute(stmt) + db.session.commit() + return result.rowcount def fetch_billing_data_for_day(process_day, service_id=None, check_permissions=False): @@ -397,7 +392,7 @@ def fetch_billing_data_for_day(process_day, service_id=None, check_permissions=F def _query_for_billing_data(notification_type, start_date, end_date, service): def _email_query(): return ( - db.session.query( + select( NotificationAllTimeView.template_id, literal(service.id).label("service_id"), literal(notification_type).label("notification_type"), @@ -407,6 +402,7 @@ def _email_query(): literal(0).label("billable_units"), func.count().label("notifications_sent"), ) + .select_from(NotificationAllTimeView) .filter( NotificationAllTimeView.status.in_( NotificationStatus.sent_email_types() @@ -429,7 +425,7 @@ def _sms_query(): ).cast(Integer) international = func.coalesce(NotificationAllTimeView.international, False) return ( - db.session.query( + select( NotificationAllTimeView.template_id, literal(service.id).label("service_id"), literal(notification_type).label("notification_type"), @@ -441,6 +437,7 @@ def _sms_query(): ), func.count().label("notifications_sent"), ) + .select_from(NotificationAllTimeView) .filter( NotificationAllTimeView.status.in_( NotificationStatus.billable_sms_types() @@ -465,17 +462,18 @@ def _sms_query(): } query = query_funcs[notification_type]() - return query.all() + return db.session.execute(query).all() def get_rates_for_billing(): - rates = Rate.query.order_by(desc(Rate.valid_from)).all() - return rates + stmt = select(Rate).order_by(desc(Rate.valid_from)) + return db.session.execute(stmt).scalars().all() def get_service_ids_that_need_billing_populated(start_date, end_date): - return ( - db.session.query(NotificationHistory.service_id) + stmt = ( + select(NotificationHistory.service_id) + .select_from(NotificationHistory) .filter( NotificationHistory.created_at >= start_date, NotificationHistory.created_at <= end_date, @@ -485,8 +483,8 @@ def get_service_ids_that_need_billing_populated(start_date, end_date): NotificationHistory.billable_units != 0, ) .distinct() - .all() ) + return db.session.execute(stmt).all() def get_rate(rates, notification_type, date): @@ -560,7 +558,7 @@ def create_billing_record(data, rate, process_day): def fetch_email_usage_for_organization(organization_id, start_date, end_date): query = ( - db.session.query( + select( Service.name.label("service_name"), Service.id.label("service_id"), func.sum(FactBilling.notifications_sent).label("emails_sent"), @@ -583,7 +581,7 @@ def fetch_email_usage_for_organization(organization_id, start_date, end_date): ) .order_by(Service.name) ) - return query.all() + return db.session.execute(query).all() def fetch_sms_billing_for_organization(organization_id, financial_year): @@ -606,7 +604,7 @@ def fetch_sms_billing_for_organization(organization_id, financial_year): sms_cost = func.sum(ft_billing_subquery.c.cost) query = ( - db.session.query( + select( Service.name.label("service_name"), Service.id.label("service_id"), AnnualBilling.free_sms_fragment_limit, @@ -632,7 +630,7 @@ def fetch_sms_billing_for_organization(organization_id, financial_year): .order_by(Service.name) ) - return query.all() + return db.session.execute(query).all() def query_organization_sms_usage_for_year(organization_id, year): @@ -673,7 +671,7 @@ def query_organization_sms_usage_for_year(organization_id, year): ) return ( - db.session.query( + select( Service.id.label("service_id"), FactBilling.local_date, this_rows_chargeable_units.label("chargeable_units"), @@ -748,7 +746,7 @@ def fetch_usage_year_for_organization(organization_id, year): def fetch_billing_details_for_all_services(): billing_details = ( - db.session.query( + select( Service.id.label("service_id"), func.coalesce( Service.purchase_order_number, Organization.purchase_order_number @@ -764,18 +762,18 @@ def fetch_billing_details_for_all_services(): Service.billing_reference, Organization.billing_reference ).label("billing_reference"), ) + .select_from(Service) .outerjoin(Service.organization) - .all() ) - return billing_details + return db.session.execute(billing_details).all() def fetch_daily_volumes_for_platform(start_date, end_date): # query to return the total notifications sent per day for each channel. NB start and end dates are inclusive daily_volume_stats = ( - db.session.query( + select( FactBilling.local_date, func.sum( case( @@ -822,7 +820,7 @@ def fetch_daily_volumes_for_platform(start_date, end_date): ) aggregated_totals = ( - db.session.query( + select( daily_volume_stats.c.local_date.cast(db.Text).label("local_date"), func.sum(daily_volume_stats.c.sms_totals).label("sms_totals"), func.sum(daily_volume_stats.c.sms_fragment_totals).label( @@ -835,17 +833,16 @@ def fetch_daily_volumes_for_platform(start_date, end_date): ) .group_by(daily_volume_stats.c.local_date) .order_by(daily_volume_stats.c.local_date) - .all() ) - return aggregated_totals + return db.session.execute(aggregated_totals).all() def fetch_daily_sms_provider_volumes_for_platform(start_date, end_date): # query to return the total notifications sent per day for each channel. NB start and end dates are inclusive - daily_volume_stats = ( - db.session.query( + stmt = ( + select( FactBilling.local_date, FactBilling.provider, func.sum(FactBilling.notifications_sent).label("sms_totals"), @@ -859,6 +856,7 @@ def fetch_daily_sms_provider_volumes_for_platform(start_date, end_date): * FactBilling.rate ).label("sms_cost"), ) + .select_from(FactBilling) .filter( FactBilling.notification_type == NotificationType.SMS, FactBilling.local_date >= start_date, @@ -872,10 +870,8 @@ def fetch_daily_sms_provider_volumes_for_platform(start_date, end_date): FactBilling.local_date, FactBilling.provider, ) - .all() ) - - return daily_volume_stats + return db.session.execute(stmt).all() def fetch_volumes_by_service(start_date, end_date): @@ -884,7 +880,7 @@ def fetch_volumes_by_service(start_date, end_date): year_end_date = int(end_date.strftime("%Y")) volume_stats = ( - db.session.query( + select( FactBilling.local_date, FactBilling.service_id, func.sum( @@ -915,6 +911,7 @@ def fetch_volumes_by_service(start_date, end_date): ) ).label("email_totals"), ) + .select_from(FactBilling) .filter( FactBilling.local_date >= start_date, FactBilling.local_date <= end_date ) @@ -927,18 +924,18 @@ def fetch_volumes_by_service(start_date, end_date): ) annual_billing = ( - db.session.query( + select( func.max(AnnualBilling.financial_year_start).label("financial_year_start"), AnnualBilling.service_id, AnnualBilling.free_sms_fragment_limit, ) + .select_from(AnnualBilling) .filter(AnnualBilling.financial_year_start <= year_end_date) .group_by(AnnualBilling.service_id, AnnualBilling.free_sms_fragment_limit) .subquery() ) - - results = ( - db.session.query( + stmt = ( + select( Service.name.label("service_name"), Service.id.label("service_id"), Service.organization_id.label("organization_id"), @@ -976,7 +973,7 @@ def fetch_volumes_by_service(start_date, end_date): Organization.name, Service.name, ) - .all() ) + results = db.session.execute(stmt).all() return results diff --git a/app/dao/fact_notification_status_dao.py b/app/dao/fact_notification_status_dao.py index df8e653ee..4b238642e 100644 --- a/app/dao/fact_notification_status_dao.py +++ b/app/dao/fact_notification_status_dao.py @@ -1,6 +1,6 @@ from datetime import timedelta -from sqlalchemy import Date, case, cast, func, select, union_all +from sqlalchemy import Date, case, cast, delete, func, select, union_all from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import aliased from sqlalchemy.sql.expression import extract, literal @@ -33,14 +33,16 @@ def update_fact_notification_status(process_day, notification_type, service_id): end_date = get_midnight_in_utc(process_day + timedelta(days=1)) # delete any existing rows in case some no longer exist e.g. if all messages are sent - FactNotificationStatus.query.filter( + stmt = delete(FactNotificationStatus).filter( FactNotificationStatus.local_date == process_day, FactNotificationStatus.notification_type == notification_type, FactNotificationStatus.service_id == service_id, - ).delete() + ) + db.session.execute(stmt) + db.session.commit() query = ( - db.session.query( + select( literal(process_day).label("process_day"), NotificationAllTimeView.template_id, literal(service_id).label("service_id"), @@ -52,6 +54,7 @@ def update_fact_notification_status(process_day, notification_type, service_id): NotificationAllTimeView.status, func.count().label("notification_count"), ) + .select_from(NotificationAllTimeView) .filter( NotificationAllTimeView.created_at >= start_date, NotificationAllTimeView.created_at < end_date, @@ -86,13 +89,14 @@ def update_fact_notification_status(process_day, notification_type, service_id): def fetch_notification_status_for_service_by_month(start_date, end_date, service_id): - return ( - db.session.query( + stmt = ( + select( func.date_trunc("month", NotificationAllTimeView.created_at).label("month"), NotificationAllTimeView.notification_type, NotificationAllTimeView.status.label("notification_status"), func.count(NotificationAllTimeView.id).label("count"), ) + .select_from(NotificationAllTimeView) .filter( NotificationAllTimeView.service_id == service_id, NotificationAllTimeView.created_at >= start_date, @@ -104,19 +108,20 @@ def fetch_notification_status_for_service_by_month(start_date, end_date, service NotificationAllTimeView.notification_type, NotificationAllTimeView.status, ) - .all() ) + return db.session.execute(stmt).all() def fetch_notification_status_for_service_for_day(fetch_day, service_id): - return ( - db.session.query( + stmt = ( + select( # return current month as a datetime so the data has the same shape as the ft_notification_status query literal(fetch_day.replace(day=1), type_=DateTime).label("month"), Notification.notification_type, Notification.status.label("notification_status"), func.count().label("count"), ) + .select_from(Notification) .filter( Notification.created_at >= get_midnight_in_utc(fetch_day), Notification.created_at @@ -125,8 +130,8 @@ def fetch_notification_status_for_service_for_day(fetch_day, service_id): Notification.key_type != KeyType.TEST, ) .group_by(Notification.notification_type, Notification.status) - .all() ) + return db.session.execute(stmt).all() def fetch_notification_status_for_service_for_today_and_7_previous_days( @@ -246,7 +251,7 @@ def fetch_notification_status_for_service_for_today_and_7_previous_days( def fetch_notification_status_totals_for_all_services(start_date, end_date): stats = ( - db.session.query( + select( FactNotificationStatus.notification_type.cast(db.Text).label( "notification_type" ), @@ -254,6 +259,7 @@ def fetch_notification_status_totals_for_all_services(start_date, end_date): FactNotificationStatus.key_type.cast(db.Text).label("key_type"), func.sum(FactNotificationStatus.notification_count).label("count"), ) + .select_from(FactNotificationStatus) .filter( FactNotificationStatus.local_date >= start_date, FactNotificationStatus.local_date <= end_date, @@ -267,7 +273,7 @@ def fetch_notification_status_totals_for_all_services(start_date, end_date): today = get_midnight_in_utc(utc_now()) if start_date <= utc_now().date() <= end_date: stats_for_today = ( - db.session.query( + select( Notification.notification_type.cast(db.Text).label("notification_type"), Notification.status.cast(db.Text), Notification.key_type.cast(db.Text), @@ -282,7 +288,7 @@ def fetch_notification_status_totals_for_all_services(start_date, end_date): ) all_stats_table = stats.union_all(stats_for_today).subquery() query = ( - db.session.query( + select( all_stats_table.c.notification_type, all_stats_table.c.status, all_stats_table.c.key_type, @@ -297,28 +303,29 @@ def fetch_notification_status_totals_for_all_services(start_date, end_date): ) else: query = stats.order_by(FactNotificationStatus.notification_type) - return query.all() + return db.session.execute(query).all() def fetch_notification_statuses_for_job(job_id): - return ( - db.session.query( + stmt = ( + select( FactNotificationStatus.notification_status.label("status"), func.sum(FactNotificationStatus.notification_count).label("count"), ) + .select_from(FactNotificationStatus) .filter( FactNotificationStatus.job_id == job_id, ) .group_by(FactNotificationStatus.notification_status) - .all() ) + return db.session.execute(stmt).all() def fetch_stats_for_all_services_by_date_range( start_date, end_date, include_from_test_key=True ): stats = ( - db.session.query( + select( FactNotificationStatus.service_id.label("service_id"), Service.name.label("name"), Service.restricted.label("restricted"), @@ -330,6 +337,7 @@ def fetch_stats_for_all_services_by_date_range( FactNotificationStatus.notification_status.cast(db.Text).label("status"), func.sum(FactNotificationStatus.notification_count).label("count"), ) + .select_from(FactNotificationStatus) .filter( FactNotificationStatus.local_date >= start_date, FactNotificationStatus.local_date <= end_date, @@ -354,12 +362,13 @@ def fetch_stats_for_all_services_by_date_range( if start_date <= utc_now().date() <= end_date: today = get_midnight_in_utc(utc_now()) subquery = ( - db.session.query( + select( Notification.notification_type.label("notification_type"), Notification.status.label("status"), Notification.service_id.label("service_id"), func.count(Notification.id).label("count"), ) + .select_from(Notification) .filter(Notification.created_at >= today) .group_by( Notification.notification_type, @@ -371,7 +380,7 @@ def fetch_stats_for_all_services_by_date_range( subquery = subquery.filter(Notification.key_type != KeyType.TEST) subquery = subquery.subquery() - stats_for_today = db.session.query( + stats_for_today = select( Service.id.label("service_id"), Service.name.label("name"), Service.restricted.label("restricted"), @@ -384,7 +393,7 @@ def fetch_stats_for_all_services_by_date_range( all_stats_table = stats.union_all(stats_for_today).subquery() query = ( - db.session.query( + select( all_stats_table.c.service_id, all_stats_table.c.name, all_stats_table.c.restricted, @@ -411,13 +420,13 @@ def fetch_stats_for_all_services_by_date_range( ) else: query = stats - return query.all() + return db.session.execute(query).all() def fetch_monthly_template_usage_for_service(start_date, end_date, service_id): # services_dao.replaces dao_fetch_monthly_historical_usage_by_template_for_service stats = ( - db.session.query( + select( FactNotificationStatus.template_id.label("template_id"), Template.name.label("name"), Template.template_type.label("template_type"), @@ -452,7 +461,7 @@ def fetch_monthly_template_usage_for_service(start_date, end_date, service_id): month = get_month_from_utc_column(Notification.created_at) stats_for_today = ( - db.session.query( + select( Notification.template_id.label("template_id"), Template.name.label("name"), Template.template_type.label("template_type"), @@ -481,7 +490,7 @@ def fetch_monthly_template_usage_for_service(start_date, end_date, service_id): all_stats_table = stats.union_all(stats_for_today).subquery() query = ( - db.session.query( + select( all_stats_table.c.template_id, all_stats_table.c.name, all_stats_table.c.template_type, @@ -502,12 +511,12 @@ def fetch_monthly_template_usage_for_service(start_date, end_date, service_id): ) else: query = stats - return query.all() + return db.session.execute(query).all() def get_total_notifications_for_date_range(start_date, end_date): query = ( - db.session.query( + select( FactNotificationStatus.local_date.label("local_date"), func.sum( case( @@ -541,12 +550,12 @@ def get_total_notifications_for_date_range(start_date, end_date): FactNotificationStatus.local_date >= start_date, FactNotificationStatus.local_date <= end_date, ) - return query.all() + return db.session.execute(query).all() def fetch_monthly_notification_statuses_per_service(start_date, end_date): - return ( - db.session.query( + stmt = ( + select( func.date_trunc("month", FactNotificationStatus.local_date) .cast(Date) .label("date_created"), @@ -639,5 +648,5 @@ def fetch_monthly_notification_statuses_per_service(start_date, end_date): Service.id, FactNotificationStatus.notification_type, ) - .all() ) + return db.session.execute(stmt).all() diff --git a/app/dao/jobs_dao.py b/app/dao/jobs_dao.py index f4914e423..ddec26956 100644 --- a/app/dao/jobs_dao.py +++ b/app/dao/jobs_dao.py @@ -3,9 +3,10 @@ from datetime import timedelta from flask import current_app -from sqlalchemy import and_, asc, desc, func +from sqlalchemy import and_, asc, desc, func, select from app import db +from app.dao.pagination import Pagination from app.enums import JobStatus from app.models import ( FactNotificationStatus, @@ -18,36 +19,33 @@ def dao_get_notification_outcomes_for_job(service_id, job_id): - notification_statuses = ( - db.session.query( - func.count(Notification.status).label("count"), Notification.status - ) + stmt = ( + select(func.count(Notification.status).label("count"), Notification.status) .filter(Notification.service_id == service_id, Notification.job_id == job_id) .group_by(Notification.status) - .all() ) + notification_statuses = db.session.execute(stmt).all() if not notification_statuses: - notification_statuses = ( - db.session.query( - FactNotificationStatus.notification_count.label("count"), - FactNotificationStatus.notification_status.label("status"), - ) - .filter( - FactNotificationStatus.service_id == service_id, - FactNotificationStatus.job_id == job_id, - ) - .all() + stmt = select( + FactNotificationStatus.notification_count.label("count"), + FactNotificationStatus.notification_status.label("status"), + ).filter( + FactNotificationStatus.service_id == service_id, + FactNotificationStatus.job_id == job_id, ) + notification_statuses = db.session.execute(stmt).all() return notification_statuses def dao_get_job_by_service_id_and_job_id(service_id, job_id): - return Job.query.filter_by(service_id=service_id, id=job_id).one() + stmt = select(Job).filter_by(service_id=service_id, id=job_id) + return db.session.execute(stmt).scalars().one() def dao_get_unfinished_jobs(): - return Job.query.filter(Job.processing_finished.is_(None)).all() + stmt = select(Job).filter(Job.processing_finished.is_(None)) + return db.session.execute(stmt).all() def dao_get_jobs_by_service_id( @@ -67,31 +65,40 @@ def dao_get_jobs_by_service_id( query_filter.append(Job.created_at >= midnight_n_days_ago(limit_days)) if statuses is not None and statuses != [""]: query_filter.append(Job.job_status.in_(statuses)) - return ( - Job.query.filter(*query_filter) + + total_items = db.session.execute( + select(func.count()).select_from(Job).filter(*query_filter) + ).scalar_one() + + offset = (page - 1) * page_size + stmt = ( + select(Job) + .filter(*query_filter) .order_by(Job.processing_started.desc(), Job.created_at.desc()) - .paginate(page=page, per_page=page_size) + .limit(page_size) + .offset(offset) ) + items = db.session.execute(stmt).scalars().all() + return Pagination(items, page, page_size, total_items) def dao_get_scheduled_job_stats( service_id, ): - return ( - db.session.query( - func.count(Job.id), - func.min(Job.scheduled_for), - ) - .filter( - Job.service_id == service_id, - Job.job_status == JobStatus.SCHEDULED, - ) - .one() + + stmt = select( + func.count(Job.id), + func.min(Job.scheduled_for), + ).filter( + Job.service_id == service_id, + Job.job_status == JobStatus.SCHEDULED, ) + return db.session.execute(stmt).one() def dao_get_job_by_id(job_id): - return Job.query.filter_by(id=job_id).one() + stmt = select(Job).filter_by(id=job_id) + return db.session.execute(stmt).scalars().one() def dao_archive_job(job): @@ -108,15 +115,16 @@ def dao_set_scheduled_jobs_to_pending(): the transaction so that if the task is run more than once concurrently, one task will block the other select from completing until it commits. """ - jobs = ( - Job.query.filter( + stmt = ( + select(Job) + .filter( Job.job_status == JobStatus.SCHEDULED, Job.scheduled_for < utc_now(), ) .order_by(asc(Job.scheduled_for)) .with_for_update() - .all() ) + jobs = db.session.execute(stmt).scalars().all() for job in jobs: job.job_status = JobStatus.PENDING @@ -128,12 +136,13 @@ def dao_set_scheduled_jobs_to_pending(): def dao_get_future_scheduled_job_by_id_and_service_id(job_id, service_id): - return Job.query.filter( + stmt = select(Job).filter( Job.service_id == service_id, Job.id == job_id, Job.job_status == JobStatus.SCHEDULED, Job.scheduled_for > utc_now(), - ).one() + ) + return db.session.execute(stmt).scalars().one() def dao_create_job(job): @@ -168,16 +177,17 @@ def dao_update_job(job): def dao_get_jobs_older_than_data_retention(notification_types): - flexible_data_retention = ServiceDataRetention.query.filter( + stmt = select(ServiceDataRetention).filter( ServiceDataRetention.notification_type.in_(notification_types) - ).all() + ) + flexible_data_retention = db.session.execute(stmt).scalars().all() jobs = [] today = utc_now().date() for f in flexible_data_retention: end_date = today - timedelta(days=f.days_of_retention) - - jobs.extend( - Job.query.join(Template) + stmt = ( + select(Job) + .join(Template) .filter( func.coalesce(Job.scheduled_for, Job.created_at) < end_date, Job.archived == False, # noqa @@ -185,8 +195,8 @@ def dao_get_jobs_older_than_data_retention(notification_types): Job.service_id == f.service_id, ) .order_by(desc(Job.created_at)) - .all() ) + jobs.extend(db.session.execute(stmt).scalars().all()) # notify-api-1287, make default data retention 7 days, 23 hours end_date = today - timedelta(days=7, hours=23) @@ -196,8 +206,9 @@ def dao_get_jobs_older_than_data_retention(notification_types): for x in flexible_data_retention if x.notification_type == notification_type ] - jobs.extend( - Job.query.join(Template) + stmt = ( + select(Job) + .join(Template) .filter( func.coalesce(Job.scheduled_for, Job.created_at) < end_date, Job.archived == False, # noqa @@ -205,8 +216,8 @@ def dao_get_jobs_older_than_data_retention(notification_types): Job.service_id.notin_(services_with_data_retention), ) .order_by(desc(Job.created_at)) - .all() ) + jobs.extend(db.session.execute(stmt).scalars().all()) return jobs @@ -217,7 +228,7 @@ def find_jobs_with_missing_rows(): ten_minutes_ago = utc_now() - timedelta(minutes=20) yesterday = utc_now() - timedelta(days=1) jobs_with_rows_missing = ( - db.session.query(Job) + select(Job) .filter( Job.job_status == JobStatus.FINISHED, Job.processing_finished < ten_minutes_ago, @@ -228,16 +239,16 @@ def find_jobs_with_missing_rows(): .having(func.count(Notification.id) != Job.notification_count) ) - return jobs_with_rows_missing.all() + return db.session.execute(jobs_with_rows_missing).scalars().all() def find_missing_row_for_job(job_id, job_size): - expected_row_numbers = db.session.query( + expected_row_numbers = select( func.generate_series(0, job_size - 1).label("row") ).subquery() query = ( - db.session.query( + select( Notification.job_row_number, expected_row_numbers.c.row.label("missing_row") ) .outerjoin( @@ -249,4 +260,4 @@ def find_missing_row_for_job(job_id, job_size): ) .filter(Notification.job_row_number == None) # noqa ) - return query.all() + return db.session.execute(query).all() diff --git a/app/dao/pagination.py b/app/dao/pagination.py new file mode 100644 index 000000000..cf6d8d4bd --- /dev/null +++ b/app/dao/pagination.py @@ -0,0 +1,15 @@ +class Pagination: + def __init__(self, items, page, per_page, total): + self.items = items + self.page = page + self.per_page = per_page + self.total = total + self.pages = (total + per_page - 1) // per_page + self.prev_num = page - 1 if page > 1 else None + self.next_num = page + 1 if page < self.pages else None + + def has_next(self): + return self.page < self.pages + + def has_prev(self): + return self.page > 1 diff --git a/tests/app/dao/test_fact_billing_dao.py b/tests/app/dao/test_fact_billing_dao.py index 30f2cd1c3..e1331dfe5 100644 --- a/tests/app/dao/test_fact_billing_dao.py +++ b/tests/app/dao/test_fact_billing_dao.py @@ -3,6 +3,7 @@ import pytest from freezegun import freeze_time +from sqlalchemy import func, select from app import db from app.dao.fact_billing_dao import ( @@ -614,7 +615,8 @@ def test_delete_billing_data(notify_db_session): delete_billing_data_for_service_for_day("2018-01-01", service_1.id) - current_rows = FactBilling.query.all() + stmt = select(FactBilling) + current_rows = db.session.execute(stmt).scalars().all() assert sorted(x.billable_units for x in current_rows) == sorted( [other_day.billable_units, other_service.billable_units] ) @@ -671,7 +673,8 @@ def test_fetch_sms_free_allowance_remainder_until_date_with_two_services( rate=0.11, ) - results = fetch_sms_free_allowance_remainder_until_date(datetime(2016, 5, 1)).all() + stmt = fetch_sms_free_allowance_remainder_until_date(datetime(2016, 5, 1)) + results = db.session.execute(stmt).all() assert len(results) == 2 service_result = [row for row in results if row[0] == service.id] assert service_result[0] == (service.id, 10, 2, 8) @@ -973,8 +976,8 @@ def test_fetch_usage_year_for_organization_populates_ft_billing_for_today( free_sms_fragment_limit=10, financial_year_start=current_year, ) - - assert FactBilling.query.count() == 0 + stmt = select(func.count()).select_from(FactBilling) + assert db.session.execute(stmt).scalar() == 0 create_notification(template=template, status=NotificationStatus.DELIVERED) @@ -982,7 +985,7 @@ def test_fetch_usage_year_for_organization_populates_ft_billing_for_today( organization_id=new_org.id, year=current_year ) assert len(results) == 1 - assert FactBilling.query.count() == 1 + assert db.session.execute(stmt).scalar() == 1 @freeze_time("2022-05-01 13:30") @@ -1224,8 +1227,8 @@ def test_query_organization_sms_usage_for_year_handles_multiple_services( ) # ---------- - - result = query_organization_sms_usage_for_year(org.id, 2022).all() + stmt = query_organization_sms_usage_for_year(org.id, 2022) + result = db.session.execute(stmt).all() service_1_rows = [row._asdict() for row in result if row.service_id == service_1.id] service_2_rows = [row._asdict() for row in result if row.service_id == service_2.id] @@ -1295,10 +1298,9 @@ def test_query_organization_sms_usage_for_year_handles_multiple_rates( financial_year_start=current_year, ) - result = [ - row._asdict() - for row in query_organization_sms_usage_for_year(org.id, 2022).all() - ] + stmt = query_organization_sms_usage_for_year(org.id, 2022) + rows = db.session.execute(stmt).all() + result = [row._asdict() for row in rows] # al lthe free allowance is used on the first day assert result[0]["local_date"] == date(2022, 4, 29) diff --git a/tests/app/dao/test_fact_notification_status_dao.py b/tests/app/dao/test_fact_notification_status_dao.py index 586c1c3ec..2c0de9014 100644 --- a/tests/app/dao/test_fact_notification_status_dao.py +++ b/tests/app/dao/test_fact_notification_status_dao.py @@ -3,7 +3,9 @@ import pytest from freezegun import freeze_time +from sqlalchemy import func, select +from app import db from app.dao.fact_notification_status_dao import ( fetch_monthly_notification_statuses_per_service, fetch_monthly_template_usage_for_service, @@ -1126,9 +1128,10 @@ def test_update_fact_notification_status_respects_gmt_bst( process_day, NotificationType.SMS, sample_service.id ) - assert ( - FactNotificationStatus.query.filter_by( - service_id=sample_service.id, local_date=process_day - ).count() - == expected_count + stmt = ( + select(func.count()) + .select_from(FactNotificationStatus) + .filter_by(service_id=sample_service.id, local_date=process_day) ) + result = db.session.execute(stmt) + assert result.rowcount == expected_count diff --git a/tests/app/dao/test_jobs_dao.py b/tests/app/dao/test_jobs_dao.py index ca98257e5..b499faefa 100644 --- a/tests/app/dao/test_jobs_dao.py +++ b/tests/app/dao/test_jobs_dao.py @@ -4,8 +4,10 @@ import pytest from freezegun import freeze_time +from sqlalchemy import func, select from sqlalchemy.exc import IntegrityError +from app import db from app.dao.jobs_dao import ( dao_create_job, dao_get_future_scheduled_job_by_id_and_service_id, @@ -108,7 +110,8 @@ def test_should_return_notifications_only_for_this_service( def test_create_sample_job(sample_template): - assert Job.query.count() == 0 + stmt = select(func.count()).select_from(Job) + assert db.session.execute(stmt).scalar() == 0 job_id = uuid.uuid4() data = { @@ -123,9 +126,9 @@ def test_create_sample_job(sample_template): job = Job(**data) dao_create_job(job) - - assert Job.query.count() == 1 - job_from_db = Job.query.get(job_id) + stmt = select(func.count()).select_from(Job) + assert db.session.execute(stmt).scalar() == 1 + job_from_db = db.session.get(Job, job_id) assert job == job_from_db assert job_from_db.notifications_delivered == 0 assert job_from_db.notifications_failed == 0 @@ -221,7 +224,7 @@ def test_update_job(sample_job): dao_update_job(sample_job) - job_from_db = Job.query.get(sample_job.id) + job_from_db = db.session.get(Job, sample_job.id) assert job_from_db.job_status == JobStatus.IN_PROGRESS diff --git a/tests/app/job/test_rest.py b/tests/app/job/test_rest.py index 6d4112058..8d40a045a 100644 --- a/tests/app/job/test_rest.py +++ b/tests/app/job/test_rest.py @@ -837,7 +837,7 @@ def test_get_jobs_should_paginate(admin_request, sample_template): assert resp_json["page_size"] == 2 assert resp_json["total"] == 10 assert "links" in resp_json - assert set(resp_json["links"].keys()) == {"next", "last"} + assert set(resp_json["links"].keys()) == {"next", "last", "prev"} def test_get_jobs_accepts_page_parameter(admin_request, sample_template):