Merge branch 'main' into 319-remove-obsolete-re-iding-code
mbthornton-lbl committed Jan 9, 2025
2 parents b32d0ba + 533932c commit bfc17b2
Showing 13 changed files with 775 additions and 498 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/blt.yml
@@ -45,5 +45,5 @@ jobs:
- name: Test with pytest
run: |
poetry run pytest ./tests --junit-xml=pytest.xml --cov-report=term \
poetry run pytest -m "not integration" ./tests --junit-xml=pytest.xml --cov-report=term \
--cov-report=xml --cov=nmdc_automation --local-badge-output-dir badges/
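
The CI change above deselects tests marked as integration tests from the pytest run. A minimal sketch of how such a marker is typically declared and applied, assuming the repository registers an "integration" marker in its pytest configuration:

import pytest

# Assumed marker registration, e.g. in pyproject.toml:
#   [tool.pytest.ini_options]
#   markers = ["integration: tests that need live external services"]

@pytest.mark.integration
def test_runtime_round_trip():
    # Hypothetical test against a live NMDC Runtime instance;
    # deselected in CI by `pytest -m "not integration"`.
    ...

def test_sha256_helper():
    # Unmarked unit test; still collected and run by the CI command.
    ...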
101 changes: 83 additions & 18 deletions nmdc_automation/api/nmdcapi.py
@@ -14,7 +14,15 @@
from datetime import datetime, timedelta, timezone
from nmdc_automation.config import SiteConfig, UserConfig
import logging
from tenacity import retry, wait_exponential, stop_after_attempt

logging_level = os.getenv("NMDC_LOG_LEVEL", logging.DEBUG)
logging.basicConfig(
level=logging_level, format="%(asctime)s %(levelname)s: %(message)s"
)
logger = logging.getLogger(__name__)

SECONDS_IN_DAY = 86400

def _get_sha256(fn: Union[str, Path]) -> str:
"""
@@ -45,7 +53,7 @@ def expiry_dt_from_now(days=0, hours=0, minutes=0, seconds=0):

class NmdcRuntimeApi:
token = None
expires = 0
expires_at = 0
_base_url = None
client_id = None
client_secret = None
@@ -63,15 +71,17 @@ def __init__(self, site_configuration: Union[str, Path, SiteConfig]):
def refresh_token(func):
def _get_token(self, *args, **kwargs):
# If it expires in 60 seconds, refresh
if not self.token or self.expires + 60 > time():
if not self.token or self.expires_at < time() + 60:
self.get_token()
return func(self, *args, **kwargs)

return _get_token
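
The rewritten check above refreshes whenever the token is missing or expires_at is less than 60 seconds away; the old comparison (self.expires + 60 > time()) effectively forced a refresh on every call while a token was still valid. A standalone sketch of the decorator pattern under those assumptions (class and method names hypothetical):

import time

def refresh_token(func):
    # Re-authenticate if there is no token or it expires within 60 seconds.
    def _wrapper(self, *args, **kwargs):
        if not self.token or self.expires_at < time.time() + 60:
            self.get_token()
        return func(self, *args, **kwargs)
    return _wrapper

class ClientSketch:
    token = None
    expires_at = 0.0

    def get_token(self):
        # Stand-in for the real token request; pretend it is valid for an hour.
        self.token = "example-token"
        self.expires_at = time.time() + 3600

    @refresh_token
    def list_jobs(self):
        return []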

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
def get_token(self):
"""
Get a token using a client id/secret.
Retries up to 6 times with exponential backoff.
"""
h = {
"accept": "application/json",
@@ -83,21 +93,39 @@ def get_token(self):
"client_secret": self.client_secret,
}
url = self._base_url + "token"
resp = requests.post(url, headers=h, data=data).json()
expt = resp["expires"]
self.expires = time() + expt["minutes"] * 60

self.token = resp["access_token"]
resp = requests.post(url, headers=h, data=data)
if not resp.ok:
logging.error(f"Failed to get token: {resp.text}")
resp.raise_for_status()
response_body = resp.json()

# Expires can be in days, hours, minutes, seconds - sum them up and convert to seconds
expires = 0
if "days" in response_body["expires"]:
expires += int(response_body["expires"]["days"]) * SECONDS_IN_DAY
if "hours" in response_body["expires"]:
expires += int(response_body["expires"]["hours"]) * 3600
if "minutes" in response_body["expires"]:
expires += int(response_body["expires"]["minutes"]) * 60
if "seconds" in response_body["expires"]:
expires += int(response_body["expires"]["seconds"])

self.expires_at = time() + expires

self.token = response_body["access_token"]
self.header = {
"accept": "application/json",
"Content-Type": "application/json",
"Authorization": "Bearer %s" % (self.token),
}
return resp
logging.debug(f"New token expires at {self.expires_at}")
return response_body
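
The expiry handling above sums whatever unit fields the token endpoint returns before computing expires_at. A small worked illustration with a hypothetical response body:

SECONDS_IN_DAY = 86400

# Hypothetical token response fragment.
response_body = {"expires": {"hours": 1, "minutes": 30}}

expiry = response_body["expires"]
expires = 0
expires += int(expiry.get("days", 0)) * SECONDS_IN_DAY
expires += int(expiry.get("hours", 0)) * 3600
expires += int(expiry.get("minutes", 0)) * 60
expires += int(expiry.get("seconds", 0))
assert expires == 5400  # 1 hour 30 minutes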

def get_header(self):
return self.header

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def minter(self, id_type, informed_by=None):
url = f"{self._base_url}pids/mint"
@@ -115,6 +143,7 @@ def minter(self, id_type, informed_by=None):
raise ValueError("Failed to bind metadata to pid")
return id

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def mint(self, ns, typ, ct):
"""
@@ -127,15 +156,20 @@ def mint(self, ns, typ, ct):
url = self._base_url + "ids/mint"
d = {"populator": "", "naa": ns, "shoulder": typ, "number": ct}
resp = requests.post(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def get_object(self, obj, decode=False):
"""
Helper function to get object info
"""
url = "%sobjects/%s" % (self._base_url, obj)
resp = requests.get(url, headers=self.header)
if not resp.ok:
resp.raise_for_status()
data = resp.json()
if decode and "description" in data:
try:
@@ -145,6 +179,8 @@ def get_object(self, obj, decode=False):

return data


@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def create_object(self, fn, description, dataurl):
"""
@@ -186,8 +222,11 @@ def create_object(self, fn, description, dataurl):
"self_uri": "todo",
}
resp = requests.post(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def post_objects(self, obj_data):
url = self._base_url + "workflows/workflow_executions"
@@ -197,23 +236,29 @@ def post_objects(self, obj_data):
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def set_type(self, obj, typ):
url = "%sobjects/%s/types" % (self._base_url, obj)
d = [typ]
resp = requests.put(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def bump_time(self, obj):
url = "%sobjects/%s" % (self._base_url, obj)
now = datetime.today().isoformat()

d = {"created_time": now}
resp = requests.patch(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

# TODO test that this concatenates multi-page results
@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def list_jobs(self, filt=None, max=100) -> List[dict]:
url = "%sjobs?max_page_size=%s" % (self._base_url, max)
@@ -223,25 +268,36 @@ def list_jobs(self, filt=None, max=100) -> List[dict]:
orig_url = url
results = []
while True:
resp = requests.get(url, data=json.dumps(d), headers=self.header).json()
if "resources" not in resp:
logging.warning(str(resp))
resp = requests.get(url, data=json.dumps(d), headers=self.header)
if resp.status_code != 200:
resp.raise_for_status()
try:
response_json = resp.json()
except Exception as e:
logging.error(f"Failed to parse response: {resp.text}")
raise e
if "resources" not in response_json:
logging.warning(str(response_json))
break
results.extend(resp["resources"])
if "next_page_token" not in resp or not resp["next_page_token"]:
results.extend(response_json["resources"])
if "next_page_token" not in response_json or not response_json["next_page_token"]:
break
url = orig_url + "&page_token=%s" % (resp["next_page_token"])
url = orig_url + "&page_token=%s" % (response_json["next_page_token"])
return results
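
With the changes above, list_jobs raises on non-200 responses, logs unparseable bodies, and follows next_page_token until the last page. A hedged usage sketch, mirroring the filter built by get_unclaimed_jobs in watch_nmdc.py (the api instance and allowed_workflows list are assumed):

# `api` is an authenticated NmdcRuntimeApi; `allowed_workflows` is a list of
# workflow id strings taken from the site configuration.
filt = {"workflow.id": {"$in": allowed_workflows}, "claims": {"$size": 0}}
for job_record in api.list_jobs(filt=filt, max=100):
    print(job_record.get("id"), job_record.get("workflow", {}).get("id"))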

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def get_job(self, job):
url = "%sjobs/%s" % (self._base_url, job)
def get_job(self, job_id: str):
url = "%sjobs/%s" % (self._base_url, job_id)
resp = requests.get(url, headers=self.header)
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def claim_job(self, job):
url = "%sjobs/%s:claim" % (self._base_url, job)
def claim_job(self, job_id: str):
url = "%sjobs/%s:claim" % (self._base_url, job_id)
resp = requests.post(url, headers=self.header)
if resp.status_code == 409:
claimed = True
@@ -265,6 +321,7 @@ def _page_query(self, url):
url = orig_url + "&page_token=%s" % (resp["next_page_token"])
return results

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def list_objs(self, filt=None, max_page_size=40):
url = "%sobjects?max_page_size=%d" % (self._base_url, max_page_size)
@@ -273,6 +330,7 @@ def list_objs(self, filt=None, max_page_size=40):
results = self._page_query(url)
return results

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def list_ops(self, filt=None, max_page_size=40):
url = "%soperations?max_page_size=%d" % (self._base_url, max_page_size)
@@ -292,12 +350,16 @@ def list_ops(self, filt=None, max_page_size=40):
url = orig_url + "&page_token=%s" % (resp["next_page_token"])
return results

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def get_op(self, opid):
url = "%soperations/%s" % (self._base_url, opid)
resp = requests.get(url, headers=self.header)
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def update_op(self, opid, done=None, results=None, meta=None):
"""
@@ -320,8 +382,11 @@ def update_op(self, opid, done=None, results=None, meta=None):
d["metadata"] = cur["metadata"]
d["metadata"]["extra"] = meta
resp = requests.patch(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def run_query(self, query):
url = "%squeries:run" % self._base_url
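Every API call in this module is now wrapped in a tenacity retry with exponential backoff: waits grow exponentially from 8 seconds up to a 120-second cap, the call is abandoned after six attempts, and reraise=True surfaces the original exception to the caller. A minimal standalone sketch of the same policy applied to a generic GET (the URL and function are placeholders):

import requests
from tenacity import retry, stop_after_attempt, wait_exponential

@retry(
    wait=wait_exponential(multiplier=4, min=8, max=120),
    stop=stop_after_attempt(6),
    reraise=True,
)
def fetch_json(url: str) -> dict:
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()  # an HTTPError here triggers another attempt
    return resp.json()
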
8 changes: 7 additions & 1 deletion nmdc_automation/run_process/run_import.py
@@ -1,5 +1,6 @@
import click
import csv
import gc
import importlib.resources
from functools import lru_cache
import logging
@@ -12,7 +13,7 @@
from nmdc_automation.api import NmdcRuntimeApi


logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


@@ -70,6 +71,8 @@ def import_projects(import_file, import_yaml, site_configuration, iteration):
# validate the database
logger.info("Validating imported data")
db_dict = yaml.safe_load(yaml_dumper.dumps(db))
del db # free up memory
del do_mapping # free up memory
validation_report = linkml.validator.validate(db_dict, nmdc_materialized)
if validation_report.results:
logger.error(f"Validation Failed")
@@ -93,10 +96,13 @@ def import_projects(import_file, import_yaml, site_configuration, iteration):
logger.info("Posting data to the API")
try:
runtime.post_objects(db_dict)
del db_dict # free up memory
except Exception as e:
logger.error(f"Error posting data to the API: {e}")
raise e

gc.collect()
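
The import flow above drops its large intermediates (db, do_mapping, db_dict) as soon as they have been consumed and then forces a collection. A small self-contained sketch of the same pattern (the runtime stand-in is hypothetical):

import gc

class _FakeRuntime:
    # Stand-in for NmdcRuntimeApi so the sketch runs on its own.
    def post_objects(self, payload):
        return {"posted": sum(len(v) for v in payload.values())}

runtime = _FakeRuntime()
db_dict = {"data_object_set": [], "workflow_execution_set": []}  # placeholder payload
runtime.post_objects(db_dict)
del db_dict   # drop the only reference to the large dict
gc.collect()  # reclaim it promptly before the next iteration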




25 changes: 15 additions & 10 deletions nmdc_automation/workflow_automation/watch_nmdc.py
@@ -12,6 +12,7 @@
import importlib.resources
from functools import lru_cache
import traceback
import os

from nmdc_schema.nmdc import Database
from nmdc_automation.api import NmdcRuntimeApi
@@ -23,8 +24,12 @@
DEFAULT_STATE_DIR = Path(__file__).parent / "_state"
DEFAULT_STATE_FILE = DEFAULT_STATE_DIR / "state.json"
INITIAL_STATE = {"jobs": []}

logging_level = os.getenv("NMDC_LOG_LEVEL", logging.DEBUG)
logging.basicConfig(
level=logging_level, format="%(asctime)s %(levelname)s: %(message)s"
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class FileHandler:
@@ -236,14 +241,13 @@ def process_successful_job(self, job: WorkflowJob) -> Database:

database = Database()

# get job runner metadata if needed
if not job.job.metadata:
logger.info(f"Getting job runner metadata for job {job.workflow.job_runner_id}")
job.job.job_id = job.workflow.job_runner_id
metadata = job.job.get_job_metadata()
m_dict = yaml.safe_load(yaml_dumper.dumps(metadata))
logger.debug(f"Job runner metadata: {m_dict}")
job.job.metadata = metadata
# Update the job metadata
logger.info(f"Getting job runner metadata for job {job.workflow.job_runner_id}")
job.job.job_id = job.workflow.job_runner_id
metadata = job.job.get_job_metadata()
m_dict = yaml.safe_load(yaml_dumper.dumps(metadata))
logger.debug(f"Job runner metadata: {m_dict}")
job.job.metadata = metadata

data_objects = job.make_data_objects(output_dir=output_path)
if not data_objects:
@@ -264,6 +268,7 @@
logger.info(f"Created workflow execution record for job {job.opid}")

job.done = True
job.workflow.state["end"] = workflow_execution.ended_at_time
self.file_handler.write_metadata_if_not_exists(job)
self.save_checkpoint()
return database
@@ -300,7 +305,7 @@ def get_unclaimed_jobs(self, allowed_workflows) -> List[WorkflowJob]:
"workflow.id": {"$in": allowed_workflows},
"claims": {"$size": 0}
}
job_records = self.runtime_api.list_jobs(filt=filt)
job_records = self.runtime_api.list_jobs(filt=filt)

for job in job_records:
jobs.append(WorkflowJob(self.config, workflow_state=job))
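Both nmdcapi.py and watch_nmdc.py now take their log level from the NMDC_LOG_LEVEL environment variable, defaulting to DEBUG. logging.basicConfig accepts either a level name string (what os.getenv returns, e.g. "INFO") or the integer constant used as the fallback, so both paths configure the root logger. A minimal sketch of the same setup:

import logging
import os

# NMDC_LOG_LEVEL=INFO (for example) selects the level by name; the integer
# logging.DEBUG constant is used when the variable is unset.
logging_level = os.getenv("NMDC_LOG_LEVEL", logging.DEBUG)
logging.basicConfig(
    level=logging_level, format="%(asctime)s %(levelname)s: %(message)s"
)
logger = logging.getLogger(__name__)
logger.info("log level resolved to %s", logging.getLogger().getEffectiveLevel())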