diff --git a/src/curate_gpt/cli.py b/src/curate_gpt/cli.py index 43bbbae..f66b089 100644 --- a/src/curate_gpt/cli.py +++ b/src/curate_gpt/cli.py @@ -5,6 +5,7 @@ import logging import sys from pathlib import Path +import random from typing import Any, Dict, List, Union import click @@ -17,6 +18,7 @@ from llm.cli import load_conversation from oaklib import get_adapter from pydantic import BaseModel +from tqdm import tqdm from curate_gpt import ChromaDBAdapter, __version__ from curate_gpt.agents.chat_agent import ChatAgent, ChatResponse @@ -37,6 +39,8 @@ from curate_gpt.wrappers.literature.pubmed_wrapper import PubmedWrapper from curate_gpt.wrappers.ontology import OntologyWrapper +from oaklib.datamodels.vocabulary import IS_A, PART_OF + __all__ = [ "main", ] @@ -1595,6 +1599,64 @@ def _text_lookup(obj: Dict): db.update_collection_metadata(collection, object_type="OntologyClass") +@ontology.command(name="subsumption") +@path_option +@collection_option +@model_option +@click.option("--prefix", required=False, default=None, help="Prefix of terms to use, e.g. 'HP:'") +@click.option('--predicates', multiple=True, help='Predicates of interest (e.g., is_a, part_of)') +@click.option("--seed", required=False, default=42, help="Seed for random number generator") +@click.option('--num_terms', required=False, default=1000, help='Number of term pairs to compare') +@click.argument("ont") +def subsumption_command(ont, path, collection, prefix, predicates, seed, num_terms, model, **kwargs): + """ + Compare pairs of ontology terms where one subsumes the other, or one does NOT + subsume the other, to determine whether LLM embeddings reflect subsumption + relationships. + + Example: + ------- + curategpt subsumption -c obo_hp $db/hp.db + + """ + if not predicates: + predicates = [IS_A, PART_OF] + + oak_adapter = get_adapter(ont) + view = OntologyWrapper(oak_adapter=oak_adapter) + db = ChromaDBAdapter(path, **kwargs) + db.text_lookup = view.text_field + + c = db.client.get_collection(collection) + + # get all terms + terms = list(view.oak_adapter.all_entity_curies()) + if prefix is not None: + terms = [t for t in terms if t.startswith(prefix)] + if not terms: + raise ValueError(f"No terms found with prefix {prefix}") + + # choose 1000 pseudo-random terms, get ancestor info, choose a random subsuming + # and non-subsuming term, calculate fraction of ancestors in common while we are + # at it + random.seed(seed) + ancs = [] + random_pairs = [] + for term in tqdm(random.sample(terms, num_terms), desc="Choosing terms to compare"): + anc = list(view.oak_adapter.ancestors(term, predicates=predicates, reflexive=False)) + ancs.append((term, anc)) + + # choose random term to pair with + random_other_term = random.choice(terms) + random_term_ancs = list(view.oak_adapter.ancestors(random_other_term, + predicates=predicates, + reflexive=False)) + pair_shared_anc = len(set(anc).intersection( + set(random_term_ancs))) / len(anc) # fraction of ancestors in common + random_pairs.append((term, random_other_term, pair_shared_anc)) + + pass + @main.group() def view(): "Virtual store/wrapper"