Skip to content

Commit

Permalink
Add subsumption command (this should be moved to some module I guess)
Browse files Browse the repository at this point in the history
  • Loading branch information
Justin Reese committed Feb 8, 2024
1 parent bebe6bb commit 7ada397
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions src/curate_gpt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import sys
from pathlib import Path
import random
from typing import Any, Dict, List, Union

import click
Expand All @@ -17,6 +18,7 @@
from llm.cli import load_conversation
from oaklib import get_adapter
from pydantic import BaseModel
from tqdm import tqdm

from curate_gpt import ChromaDBAdapter, __version__
from curate_gpt.agents.chat_agent import ChatAgent, ChatResponse
Expand All @@ -37,6 +39,8 @@
from curate_gpt.wrappers.literature.pubmed_wrapper import PubmedWrapper
from curate_gpt.wrappers.ontology import OntologyWrapper

from oaklib.datamodels.vocabulary import IS_A, PART_OF

__all__ = [
"main",
]
Expand Down Expand Up @@ -1595,6 +1599,64 @@ def _text_lookup(obj: Dict):
db.update_collection_metadata(collection, object_type="OntologyClass")


@ontology.command(name="subsumption")
@path_option
@collection_option
@model_option
@click.option("--prefix", required=False, default=None, help="Prefix of terms to use, e.g. 'HP:'")
@click.option('--predicates', multiple=True, help='Predicates of interest (e.g., is_a, part_of)')
@click.option("--seed", required=False, default=42, help="Seed for random number generator")
@click.option('--num_terms', required=False, default=1000, help='Number of term pairs to compare')
@click.argument("ont")
def subsumption_command(ont, path, collection, prefix, predicates, seed, num_terms, model, **kwargs):
    """
    Compare pairs of ontology terms where one subsumes the other, or one does NOT
    subsume the other, to determine whether LLM embeddings reflect subsumption
    relationships.

    Example:
    -------
        curategpt ontology subsumption -c obo_hp $db/hp.db
    """
    # NOTE(review): this command is still a skeleton. It samples terms and
    # collects ancestor statistics, but nothing is compared against embeddings,
    # printed, or returned yet; ``model`` is accepted but not used.
    #
    # Parameters (all supplied by click):
    #   ont        -- handle passed to ``get_adapter`` to open the ontology
    #   path       -- location handed to ``ChromaDBAdapter``
    #   collection -- name of the collection looked up in the DB client
    #   prefix     -- if given, only CURIEs starting with it are considered
    #   predicates -- relationship predicates used when walking ancestors
    #   seed       -- seed for the pseudo-random term selection
    #   num_terms  -- number of terms to sample (capped at what is available)
    #
    # Raises ValueError when no candidate terms are available.

    # The help text advertises the short names "is_a" / "part_of"; translate
    # them to the constants the defaults use. Anything else passes through.
    predicate_aliases = {"is_a": IS_A, "part_of": PART_OF}
    if predicates:
        predicates = [predicate_aliases.get(p.lower(), p) for p in predicates]
    else:
        predicates = [IS_A, PART_OF]

    oak_adapter = get_adapter(ont)
    wrapper = OntologyWrapper(oak_adapter=oak_adapter)
    db = ChromaDBAdapter(path, **kwargs)
    db.text_lookup = wrapper.text_field

    # Not consumed yet; the lookup is kept so the call still happens exactly as
    # before (presumably it fails early for an unknown collection -- confirm).
    c = db.client.get_collection(collection)  # noqa: F841

    # get all terms
    terms = list(wrapper.oak_adapter.all_entity_curies())
    if prefix is not None:
        terms = [t for t in terms if t.startswith(prefix)]
        if not terms:
            raise ValueError(f"No terms found with prefix {prefix}")
    elif not terms:
        # Previously this surfaced as an opaque error from random.sample().
        raise ValueError(f"No terms found in ontology {ont}")

    # A private generator yields the same draws as seeding the module-level
    # functions, without reseeding the process-wide RNG as a side effect.
    rng = random.Random(seed)

    # random.sample() raises if asked for more items than exist, so cap the
    # request at the population size and tell the user when that happens.
    sample_size = min(num_terms, len(terms))
    if sample_size < num_terms:
        logging.getLogger(__name__).warning(
            "Requested %d terms but only %d are available; using all of them",
            num_terms,
            len(terms),
        )

    # Choose ``num_terms`` pseudo-random terms; for each one, fetch its
    # ancestors, pair it with a randomly chosen term, and record the fraction
    # of its ancestors that the random partner shares.
    ancs = []
    random_pairs = []
    for term in tqdm(rng.sample(terms, sample_size), desc="Choosing terms to compare"):
        anc = list(wrapper.oak_adapter.ancestors(term, predicates=predicates, reflexive=False))
        ancs.append((term, anc))

        # choose random term to pair with
        random_other_term = rng.choice(terms)
        random_term_ancs = set(
            wrapper.oak_adapter.ancestors(
                random_other_term, predicates=predicates, reflexive=False
            )
        )

        # A term without ancestors (e.g. a root) shares nothing by definition;
        # report 0.0 instead of dividing by zero.
        anc_set = set(anc)
        if anc_set:
            pair_shared_anc = len(anc_set & random_term_ancs) / len(anc_set)
        else:
            pair_shared_anc = 0.0
        random_pairs.append((term, random_other_term, pair_shared_anc))

@main.group()
def view():
"Virtual store/wrapper"
Expand Down

0 comments on commit 7ada397

Please sign in to comment.