Skip to content

Commit

Permalink
Merge pull request #337 from smart-on-fhir/mikix/gpt
Browse files Browse the repository at this point in the history
Add two new covid_symptom GPT tasks
  • Loading branch information
mikix authored Aug 9, 2024
2 parents d5662b9 + 8978f1c commit 6f18af2
Show file tree
Hide file tree
Showing 17 changed files with 658 additions and 488 deletions.
437 changes: 0 additions & 437 deletions .pylintrc

This file was deleted.

8 changes: 6 additions & 2 deletions compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ services:
context: .
target: cumulus-etl
environment:
# Environment variobles to pull in from the host
- AWS_ACCESS_KEY_ID
- AWS_DEFAULT_PROFILE
- AWS_PROFILE
- AWS_SECRET_ACCESS_KEY
- AWS_SESSION_TOKEN
- AWS_PROFILE
- AWS_DEFAULT_PROFILE
- AZURE_OPENAI_API_KEY
- AZURE_OPENAI_ENDPOINT
# Internal environment variobles
- CUMULUS_HUGGING_FACE_URL=http://llama2:8086/
- URL_CTAKES_REST=http://ctakes-covid:8080/ctakes-web-rest/service/analyze
- URL_CNLP_NEGATION=http://cnlpt-negation:8000/negation/process
Expand Down
7 changes: 6 additions & 1 deletion cumulus_etl/etl/studies/covid_symptom/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""The covid_symptom study"""

from .covid_tasks import CovidSymptomNlpResultsTask, CovidSymptomNlpResultsTermExistsTask
from .covid_tasks import (
CovidSymptomNlpResultsGpt4Task,
CovidSymptomNlpResultsGpt35Task,
CovidSymptomNlpResultsTask,
CovidSymptomNlpResultsTermExistsTask,
)
15 changes: 6 additions & 9 deletions cumulus_etl/etl/studies/covid_symptom/covid_ctakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import httpx
from ctakesclient.transformer import TransformerModel

from cumulus_etl import common, fhir, nlp, store
from cumulus_etl import common, nlp, store


async def covid_symptoms_extract(
Expand All @@ -31,14 +31,11 @@ async def covid_symptoms_extract(
:param cnlp_http_client: HTTPX client to use for the cNLP transformer server
:return: list of NLP results encoded as FHIR observations
"""
docref_id = docref["id"]
_, subject_id = fhir.unref_resource(docref["subject"])

encounters = docref.get("context", {}).get("encounter", [])
if not encounters:
logging.warning("No encounters for docref %s", docref_id)
try:
docref_id, encounter_id, subject_id = nlp.get_docref_info(docref)
except KeyError as exc:
logging.warning(exc)
return None
_, encounter_id = fhir.unref_resource(encounters[0])

# cTAKES cache namespace history (and thus, cache invalidation history):
# v1: original cTAKES processing
Expand All @@ -54,7 +51,7 @@ async def covid_symptoms_extract(
case TransformerModel.TERM_EXISTS:
cnlp_namespace = f"{ctakes_namespace}-cnlp_term_exists_v1"
case _:
logging.warning("Unknown polarity method: %s", polarity_model.value)
logging.warning("Unknown polarity method: %s", polarity_model)
return None

timestamp = common.datetime_now().isoformat()
Expand Down
225 changes: 215 additions & 10 deletions cumulus_etl/etl/studies/covid_symptom/covid_tasks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
"""Define tasks for the covid_symptom study"""

import itertools
import json
import logging
import os
from typing import ClassVar

import ctakesclient
import openai
import pyarrow
import rich.progress
from ctakesclient.transformer import TransformerModel
from openai.types import chat

from cumulus_etl import nlp, store
from cumulus_etl import common, nlp, store
from cumulus_etl.etl import tasks
from cumulus_etl.etl.studies.covid_symptom import covid_ctakes

Expand Down Expand Up @@ -74,9 +79,11 @@ def is_ed_docref(docref):
return any(is_ed_coding(x) for x in codings)


class BaseCovidSymptomNlpResultsTask(tasks.BaseNlpTask):
class BaseCovidCtakesTask(tasks.BaseNlpTask):
"""Covid Symptom study task, to generate symptom lists from ED notes using cTAKES + a polarity check"""

tags: ClassVar = {"covid_symptom", "gpu"}

# Subclasses: set name, tags, and polarity_model yourself
polarity_model = None

Expand Down Expand Up @@ -117,7 +124,7 @@ async def prepare_task(self) -> bool:
bsv_path = ctakesclient.filesystem.covid_symptoms_path()
success = nlp.restart_ctakes_with_bsv(self.task_config.ctakes_overrides, bsv_path)
if not success:
print(f"Skipping {self.name}.")
print(" Skipping.")
self.summaries[0].had_errors = True
return success

Expand Down Expand Up @@ -197,11 +204,10 @@ def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Sche
)


class CovidSymptomNlpResultsTask(BaseCovidSymptomNlpResultsTask):
class CovidSymptomNlpResultsTask(BaseCovidCtakesTask):
"""Covid Symptom study task, to generate symptom lists from ED notes using cTAKES and cnlpt negation"""

name: ClassVar = "covid_symptom__nlp_results"
tags: ClassVar = {"covid_symptom", "gpu"}
polarity_model: ClassVar = TransformerModel.NEGATION

@classmethod
Expand All @@ -210,17 +216,216 @@ async def init_check(cls) -> None:
nlp.check_negation_cnlpt()


class CovidSymptomNlpResultsTermExistsTask(BaseCovidSymptomNlpResultsTask):
class CovidSymptomNlpResultsTermExistsTask(BaseCovidCtakesTask):
"""Covid Symptom study task, to generate symptom lists from ED notes using cTAKES and cnlpt termexists"""

name: ClassVar = "covid_symptom__nlp_results_term_exists"
polarity_model: ClassVar = TransformerModel.TERM_EXISTS

# Explicitly don't use any tags because this is really a "hidden" task that is mostly for comparing
# polarity model performance more than running a study. So we don't want it to be accidentally run.
tags: ClassVar = {}

@classmethod
async def init_check(cls) -> None:
nlp.check_ctakes()
nlp.check_term_exists_cnlpt()


class BaseCovidGptTask(tasks.BaseNlpTask):
"""Covid Symptom study task, using GPT"""

tags: ClassVar = {"covid_symptom", "cpu"}
outputs: ClassVar = [tasks.OutputTable(resource_type=None)]

# Overridden by child classes
model_id: ClassVar = None

async def prepare_task(self) -> bool:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
if not api_key or not endpoint:
if not api_key:
print(" The AZURE_OPENAI_API_KEY environment variable is not set.")
if not endpoint:
print(" The AZURE_OPENAI_ENDPOINT environment variable is not set.")
print(" Skipping.")
self.summaries[0].had_errors = True
return False
return True

async def read_entries(self, *, progress: rich.progress.Progress = None) -> tasks.EntryIterator:
"""Passes clinical notes through NLP and returns any symptoms found"""
async for orig_docref, docref, clinical_note in self.read_notes(
progress=progress, doc_check=is_ed_docref
):
try:
docref_id, encounter_id, subject_id = nlp.get_docref_info(docref)
except KeyError as exc:
logging.warning(exc)
self.add_error(orig_docref)
continue

client = openai.AsyncAzureOpenAI(api_version="2024-06-01")
try:
response = await nlp.cache_wrapper(
self.task_config.dir_phi,
f"{self.name}_v{self.task_version}",
clinical_note,
lambda x: chat.ChatCompletion.model_validate_json(x),
lambda x: x.model_dump_json(
indent=None, round_trip=True, exclude_unset=True, by_alias=True
),
client.chat.completions.create,
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": self.get_prompt(clinical_note)},
],
model=self.model_id,
seed=12345, # arbitrary, only specified to improve reproducibility
response_format={"type": "json_object"},
)
except openai.APIError as exc:
logging.warning(f"Could not connect to GPT for DocRef {docref['id']}: {exc}")
self.add_error(orig_docref)
continue

if response.choices[0].finish_reason != "stop":
logging.warning(
f"GPT response didn't complete for DocRef {docref['id']}: "
f"{response.choices[0].finish_reason}"
)
self.add_error(orig_docref)
continue

try:
symptoms = json.loads(response.choices[0].message.content)
except json.JSONDecodeError as exc:
logging.warning(f"Could not parse GPT results for DocRef {docref['id']}: {exc}")
self.add_error(orig_docref)
continue

yield {
"id": docref_id, # keep one results entry per docref
"docref_id": docref_id,
"encounter_id": encounter_id,
"subject_id": subject_id,
"generated_on": common.datetime_now().isoformat(),
"task_version": self.task_version,
"system_fingerprint": response.system_fingerprint,
"symptoms": {
"congestion_or_runny_nose": bool(symptoms.get("Congestion or runny nose")),
"cough": bool(symptoms.get("Cough")),
"diarrhea": bool(symptoms.get("Diarrhea")),
"dyspnea": bool(symptoms.get("Dyspnea")),
"fatigue": bool(symptoms.get("Fatigue")),
"fever_or_chills": bool(symptoms.get("Fever or chills")),
"headache": bool(symptoms.get("Headache")),
"loss_of_taste_or_smell": bool(symptoms.get("Loss of taste or smell")),
"muscle_or_body_aches": bool(symptoms.get("Muscle or body aches")),
"nausea_or_vomiting": bool(symptoms.get("Nausea or vomiting")),
"sore_throat": bool(symptoms.get("Sore throat")),
},
}

@staticmethod
def get_prompt(clinical_note: str) -> str:
instructions = (
"You are a helpful assistant identifying symptoms from emergency "
"department notes that could relate to infectious respiratory diseases.\n"
"Output positively documented symptoms, looking out specifically for the "
"following: Congestion or runny nose, Cough, Diarrhea, Dyspnea, Fatigue, "
"Fever or chills, Headache, Loss of taste or smell, Muscle or body aches, "
"Nausea or vomiting, Sore throat.\nSymptoms only need to be positively "
"mentioned once to be included.\nDo not mention symptoms that are not "
"present in the note.\n\nFollow these rules:\nRule (1): Symptoms must be "
"positively documented and relevant to the presenting illness or reason "
"for visit.\nRule (2): Medical section headings must be specific to the "
"present emergency department encounter.\nInclude positive symptoms from "
'these medical section headings: "Chief Complaint", "History of '
'Present Illness", "HPI", "Review of Systems", "Physical Exam", '
'"Vital Signs", "Assessment and Plan", "Medical Decision Making".\n'
"Rule (3): Positive symptom mentions must be a definite medical synonym.\n"
'Include positive mentions of: "anosmia", "loss of taste", "loss of '
'smell", "rhinorrhea", "congestion", "discharge", "nose is '
'dripping", "runny nose", "stuffy nose", "cough", "tussive or '
'post-tussive", "cough is unproductive", "productive cough", "dry '
'cough", "wet cough", "producing sputum", "diarrhea", "watery '
'stool", "fatigue", "tired", "exhausted", "weary", "malaise", '
'"feeling generally unwell", "fever", "pyrexia", "chills", '
'"temperature greater than or equal 100.4 Fahrenheit or 38 celsius", '
'"Temperature >= 100.4F", "Temperature >= 38C", "headache", "HA", '
'"migraine", "cephalgia", "head pain", "muscle or body aches", '
'"muscle aches", "generalized aches and pains", "body aches", '
'"myalgias", "myoneuralgia", "soreness", "generalized aches and '
'pains", "nausea or vomiting", "Nausea", "vomiting", "emesis", '
'"throwing up", "queasy", "regurgitated", "shortness of breath", '
'"difficulty breathing", "SOB", "Dyspnea", "breathing is short", '
'"increased breathing", "labored breathing", "distressed '
'breathing", "sore throat", "throat pain", "pharyngeal pain", '
'"pharyngitis", "odynophagia".\nYour reply must be parsable as JSON.\n'
'Format your response using only the following JSON schema: {"Congestion '
'or runny nose": boolean, "Cough": boolean, "Diarrhea": boolean, '
'"Dyspnea": boolean, "Fatigue": boolean, "Fever or chills": '
'boolean, "Headache": boolean, "Loss of taste or smell": boolean, '
'"Muscle or body aches": boolean, "Nausea or vomiting": boolean, '
'"Sore throat": boolean}. Each JSON key should correspond to a symptom, '
"and each value should be true if that symptom is indicated in the "
"clinical note; false otherwise.\nNever explain yourself, and only reply "
"with JSON."
)
return f"### Instructions ###\n{instructions}\n### Text ###\n{clinical_note}"

@classmethod
def get_schema(cls, resource_type: str | None, rows: list[dict]) -> pyarrow.Schema:
return pyarrow.schema(
[
pyarrow.field("id", pyarrow.string()),
pyarrow.field("docref_id", pyarrow.string()),
pyarrow.field("encounter_id", pyarrow.string()),
pyarrow.field("subject_id", pyarrow.string()),
pyarrow.field("generated_on", pyarrow.string()),
pyarrow.field("task_version", pyarrow.int32()),
pyarrow.field("system_fingerprint", pyarrow.string()),
pyarrow.field(
"symptoms",
pyarrow.struct(
[
pyarrow.field("congestion_or_runny_nose", pyarrow.bool_()),
pyarrow.field("cough", pyarrow.bool_()),
pyarrow.field("diarrhea", pyarrow.bool_()),
pyarrow.field("dyspnea", pyarrow.bool_()),
pyarrow.field("fatigue", pyarrow.bool_()),
pyarrow.field("fever_or_chills", pyarrow.bool_()),
pyarrow.field("headache", pyarrow.bool_()),
pyarrow.field("loss_of_taste_or_smell", pyarrow.bool_()),
pyarrow.field("muscle_or_body_aches", pyarrow.bool_()),
pyarrow.field("nausea_or_vomiting", pyarrow.bool_()),
pyarrow.field("sore_throat", pyarrow.bool_()),
],
),
),
]
)


class CovidSymptomNlpResultsGpt35Task(BaseCovidGptTask):
"""Covid Symptom study task, using GPT3.5"""

name: ClassVar = "covid_symptom__nlp_results_gpt35"
model_id: ClassVar = "gpt-35-turbo-0125"

task_version: ClassVar = 1
# Task Version History:
# ** 1 (2024-08): Initial version **
# model: gpt-35-turbo-0125
# seed: 12345


class CovidSymptomNlpResultsGpt4Task(BaseCovidGptTask):
"""Covid Symptom study task, using GPT4"""

name: ClassVar = "covid_symptom__nlp_results_gpt4"
model_id: ClassVar = "gpt-4"

task_version: ClassVar = 1
# Task Version History:
# ** 1 (2024-08): Initial version **
# model: gpt-4
# seed: 12345
4 changes: 3 additions & 1 deletion cumulus_etl/etl/studies/hftest/hf_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ async def read_entries(self, *, progress: rich.progress.Progress = None) -> task
summary = await nlp.cache_wrapper(
self.task_config.dir_phi,
f"{self.name}_v{self.task_version}",
user_prompt,
clinical_note,
lambda x: x, # from file: just store the string
lambda x: x, # to file: just read it back
nlp.llama2_prompt,
system_prompt,
user_prompt,
Expand Down
2 changes: 2 additions & 0 deletions cumulus_etl/etl/tasks/task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def get_all_tasks() -> list[type[AnyTask]]:
# Note: tasks will be run in the order listed here.
return [
*get_default_tasks(),
covid_symptom.CovidSymptomNlpResultsGpt35Task,
covid_symptom.CovidSymptomNlpResultsGpt4Task,
covid_symptom.CovidSymptomNlpResultsTask,
covid_symptom.CovidSymptomNlpResultsTermExistsTask,
hftest.HuggingFaceTestTask,
Expand Down
2 changes: 1 addition & 1 deletion cumulus_etl/nlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .extract import TransformerModel, ctakes_extract, ctakes_httpx_client, list_polarity
from .huggingface import hf_info, hf_prompt, llama2_prompt
from .utils import cache_wrapper, is_docref_valid
from .utils import cache_wrapper, get_docref_info, is_docref_valid
from .watcher import (
check_ctakes,
check_negation_cnlpt,
Expand Down
Loading

0 comments on commit 6f18af2

Please sign in to comment.