Skip to content

Commit

Permalink
Merge pull request #154 from QuanMPhm/migrate_fos
Browse files Browse the repository at this point in the history
Added management command to migrate field of sciences
  • Loading branch information
knikolla authored Apr 12, 2024
2 parents b4464cf + 5b07de8 commit f2db32e
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import csv

from django.core.management.base import BaseCommand
from coldfront.core.project.models import Project
from coldfront.core.field_of_science.models import FieldOfScience

import logging


logger = logging.getLogger(__name__)


class Command(BaseCommand):
help = """Migrates Coldfront's list of fields of sciences (FOS), changing what
FOS a project can be assigned to, and updating the FOS in all existing projects.
Requires a csv, tab-seperated, containing two columns, the first containing the list of old FOS,
the second containing the new FOS that the old FOS will map onto.
I.e to map 'Quantum Mechanics' and 'Photonics' to 'Physics', provide this csv:
Quantum Mechanics Physics
Photonics Physics
"""

def add_arguments(self, parser):
parser.add_argument(
"-m",
"--mapping",
required=True,
help="required tab-seperated csv file to provide mapping for migration",
)

def handle(self, *args, **options):
mapping_csv = options["mapping"]
mapping_dict, new_fos_set = self._load_fos_map(mapping_csv)
self._validate_old_fos(mapping_dict)
self._create_new_fos(new_fos_set)
self._migrate_fos(mapping_dict)

logger.info("Field of science migration completed!")

@staticmethod
def _load_fos_map(mapping_csv):
mapping_dict = dict()
new_fos_set = set()
with open(mapping_csv, "r") as f:
rd = csv.reader(f, delimiter="\t")
for row in rd:
old_fos, new_fos = row
mapping_dict[old_fos] = new_fos
new_fos_set.add(new_fos)

return (mapping_dict, new_fos_set)

@staticmethod
def _validate_old_fos(mapping_dict):
for old_fos_name in list(mapping_dict.keys()):
if not FieldOfScience.objects.filter(description=old_fos_name):
logger.warn(f"Old field of science {old_fos_name} does not exist")

@staticmethod
def _create_new_fos(new_fos_set):
for new_fos_name in new_fos_set:
FieldOfScience.objects.get_or_create(
is_selectable=True,
description=new_fos_name,
)

def _migrate_fos(self, mapping_dict):
for project in Project.objects.all():
cur_fos_name = project.field_of_science.description
if cur_fos_name in mapping_dict.keys():
new_fos_name = mapping_dict[cur_fos_name]
new_fos = FieldOfScience.objects.get(description=new_fos_name)
project.field_of_science = new_fos
project.save()
logger.info(
f"Migrated field of science for project {project.pk} from {cur_fos_name} to {new_fos_name}"
)

for old_fos_name in list(mapping_dict.keys()):
FieldOfScience.objects.get(description=old_fos_name).delete()
8 changes: 8 additions & 0 deletions src/coldfront_plugin_cloud/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ResourceType,
ResourceAttribute,
ResourceAttributeType)
from coldfront.core.field_of_science.models import FieldOfScience
from django.core.management import call_command

from coldfront_plugin_cloud import attributes
Expand Down Expand Up @@ -105,3 +106,10 @@ def new_allocation_user(self, allocation, user):
status=AllocationUserStatusChoice.objects.get(name='Active')
)
return au

def new_field_of_science(self, description=None):
description = description or uuid.uuid4().hex
fos, _ = FieldOfScience.objects.get_or_create(
is_selectable=True, description=description
)
return fos
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import uuid
import tempfile

from django.core.management import call_command
from coldfront.core.project.models import Project
from coldfront.core.field_of_science.models import FieldOfScience

from coldfront_plugin_cloud.tests import base


class TestFixAllocation(base.TestBase):
def test_command_output(self):
old_fos_1 = self.new_field_of_science()
old_fos_2 = self.new_field_of_science()
old_fos_3 = self.new_field_of_science()
old_fos_4 = self.new_field_of_science()

new_fos_1_des = uuid.uuid4().hex # Migrate to new fos
new_fos_2_des = old_fos_4.description # Migrate to existing fos

fake_project_1 = self.new_project()
fake_project_2 = self.new_project()
fake_project_3 = self.new_project()
fake_project_1.field_of_science = old_fos_1
fake_project_2.field_of_science = old_fos_2
fake_project_3.field_of_science = old_fos_3
fake_project_1.save()
fake_project_2.save()
fake_project_3.save()

temp_csv = tempfile.NamedTemporaryFile(mode="w+")
temp_csv.write(f"{old_fos_1.description}\t{new_fos_1_des}\n")
temp_csv.write(f"{old_fos_2.description}\t{new_fos_2_des}\n")
temp_csv.write(f"{old_fos_3.description}\t{new_fos_2_des}\n")
temp_csv.seek(0)

n_fos = FieldOfScience.objects.all().count()
call_command("migrate_fields_of_science", "-m", temp_csv.name)

self.assertEqual(n_fos - 2, FieldOfScience.objects.all().count())

# Assert project fos name replaced
fake_project_1 = Project.objects.get(pk=fake_project_1.pk)
fake_project_2 = Project.objects.get(pk=fake_project_2.pk)
fake_project_3 = Project.objects.get(pk=fake_project_3.pk)
self.assertEqual(fake_project_1.field_of_science.description, new_fos_1_des)
self.assertEqual(fake_project_2.field_of_science.description, new_fos_2_des)
self.assertEqual(fake_project_3.field_of_science.description, new_fos_2_des)

# Assert old fos no longer exists
self.assertFalse(
FieldOfScience.objects.filter(description=old_fos_1.description)
)
self.assertFalse(
FieldOfScience.objects.filter(description=old_fos_2.description)
)
self.assertFalse(
FieldOfScience.objects.filter(description=old_fos_3.description)
)

0 comments on commit f2db32e

Please sign in to comment.