Skip to content

Commit

Permalink
Merge pull request #23 from adam-kral/return-output
Browse files Browse the repository at this point in the history
Return output in `main` function in electrostatic and hydrophobic script
  • Loading branch information
vhoer authored Mar 11, 2024
2 parents d7e9bb1 + 551b792 commit c7f0fc7
Show file tree
Hide file tree
Showing 5 changed files with 280 additions and 117 deletions.
181 changes: 118 additions & 63 deletions surface_analyses/commandline_electrostatic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import argparse
import csv
from datetime import datetime
import datetime
from collections import namedtuple
import os
import pathlib
Expand All @@ -16,6 +16,7 @@
from scipy.spatial import cKDTree
from gisttools.gist import load_dx
from mdtraj.core.element import carbon, nitrogen, oxygen, sulfur
import mdtraj as md

from .patches import assign_patches
from .surface import Surface
Expand All @@ -32,10 +33,23 @@
sulfur: 1.8,
}

def main(argv=None):
if argv is None:
argv = sys.argv[1:]
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

def main(args=None):
    """Command-line entry point for pep_patch_electrostatic.

    Prints a start banner and the argument strings in use, parses them,
    loads the trajectory and forwards the remaining options to
    ``run_electrostatics``. The value returned by ``run_electrostatics`` is
    not returned from here.

    Args:
        args: Sequence of argument strings to parse. If None, the arguments
            are taken from the process command line by argparse.
    """
    print(f'pep_patch_electrostatic starting at {datetime.datetime.now()}')
    print('Command line arguments:')
    # Compare against None instead of relying on truthiness: an explicitly
    # passed empty list is parsed as "no arguments", so it must not be
    # echoed as the process-wide sys.argv.
    shown_args = sys.argv if args is None else args
    print(' '.join(shown_args))
    args = parse_args(args)
    traj = load_trajectory_using_commandline_args(args)
    # trajectory-related arguments are not passed to run_electrostatics
    del args.parm, args.trajs, args.stride, args.ref, args.protein_ref
    run_electrostatics(traj, **vars(args))


def parse_args(argv=None):
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
fromfile_prefix_chars='@',
)
add_trajectory_options_to_parser(parser)
parser.add_argument('--dx', type=str, default=None, nargs='?', help="Optional dx file with the electrostatic potential. If this is omitted, you must specify --apbs_dir")
parser.add_argument('--apbs_dir', help="Directory in which intermediate files are stored when running APBS. Will be created if it does not exist.", type=str, default=None)
Expand All @@ -45,14 +59,14 @@ def main(argv=None):
'-c', '--patch_cutoff',
type=float,
nargs=2,
default=[2, -2],
default=(2., -2.),
help='Cutoff for positive and negative patches.'
)
parser.add_argument(
'-ic', '--integral_cutoff',
type=float,
nargs=2,
default=[0.3, -0.3],
default=(0.3, -0.3),
help='Cutoffs for "high" and "low" integrals.'
)
parser.add_argument(
Expand Down Expand Up @@ -106,7 +120,7 @@ def main(argv=None):
parser.add_argument(
'-s','--size_cutoff',
type=float,
default=0,
default=0.,
help='Restrict output to patches with an area of over s A^2. If s = 0, no cutoff is applied (default).',
)
parser.add_argument('--gauss_shift', type=float, default=0.1)
Expand All @@ -125,39 +139,62 @@ def main(argv=None):
help="Specify ion species and their properties (charge, concentration, and radius). "
"Provide values for multiple ion species as charge1, conc1, radius1, charge2, conc2, radius2, etc."
)

args = parser.parse_args(argv)

print(f'pep_patch_electrostatic, {datetime.now().strftime("%m/%d/%Y, %H:%M:%S")}\n')
print('Command line arguments:')
print(' '.join(sys.argv))
if args.out is None:
return parser.parse_args(argv)


def run_electrostatics(
traj: md.Trajectory,
dx: str = None,
apbs_dir: str = None,
probe_radius: float = 1.4,
out: str = None,
patch_cutoff: tuple = (2., -2.),
integral_cutoff: tuple = (0.3, -0.3),
surface_type: str = "sas",
ply_out: str = None,
pos_patch_cmap: str = 'tab20c',
neg_patch_cmap: str = 'tab20c',
ply_cmap: str = 'coolwarm_r',
ply_clim: tuple = None,
check_cdrs: bool = False,
n_patches: int = 0,
size_cutoff: float = 0.,
gauss_shift: float = 0.1,
gauss_scale: float = 1.0,
pH: float = None,
ion_species: tuple = None,
):
f"""Python interface for the functionality of pep_patch_electrostatic
The first argument is a single-frame mdtraj Trajectory.
The other arguments are identical to those of the commandline interface.
"""
if out is None:
csv_outfile = sys.stdout
else:
csv_outfile = open(args.out, "w")
csv_outfile = open(out, "w")

ion_species = get_ion_species(args)
traj = load_trajectory_using_commandline_args(args)
ion_species = get_ion_species(ion_species)
# Renumber residues, takes care of insertion codes in PDB residue numbers
for i, res in enumerate(traj.top.residues,start=1):
res.resSeq = i
if args.dx is None and args.apbs_dir is None:

if dx is None and apbs_dir is None:
raise ValueError("Either DX or APBS_DIR must be specified.")

if args.dx is not None and args.apbs_dir is not None:
if dx is not None and apbs_dir is not None:
warnings.warn("Warning: both DX and APBS_DIR are specified. Will not run APBS "
"and use the dx file instead.")

if traj.n_frames != 1:
raise ValueError("The electrostatics script only works with a single-frame trajectory.")

if args.dx is not None:
griddata = load_dx(args.dx, colname='DX')
if dx is not None:
griddata = load_dx(dx, colname='DX')
griddata.struct = traj[0]
else:
griddata = get_apbs_potential_from_mdtraj(traj, args.apbs_dir, args.pH, ion_species)
griddata = get_apbs_potential_from_mdtraj(traj, apbs_dir, pH, ion_species)

# *10 because mdtraj calculates stuff in nanometers instead of Angstrom.
radii = 10. * np.array([atom.element.radius for atom in traj.top.atoms])
columns = ['DX']
Expand All @@ -166,21 +203,21 @@ def main(argv=None):
pprint.pprint({
'#Atoms': traj.n_atoms,
'Grid dimensions': griddata.grid.shape,
**vars(args),
# **kwargs,
})

print('Calculating triangulated SASA')

if args.surface_type == 'sas':
surf = compute_sas(griddata.grid, griddata.coord, radii, args.probe_radius)
elif args.surface_type == 'gauss':
surf = compute_gauss_surf(griddata.grid, griddata.coord, radii, args.gauss_shift, args.gauss_scale)
elif args.surface_type == 'ses':
surf = compute_ses(griddata.grid, griddata.coord, radii, args.probe_radius)
if surface_type == 'sas':
surf = compute_sas(griddata.grid, griddata.coord, radii, probe_radius)
elif surface_type == 'gauss':
surf = compute_gauss_surf(griddata.grid, griddata.coord, radii, gauss_shift, gauss_scale)
elif surface_type == 'ses':
surf = compute_ses(griddata.grid, griddata.coord, radii, probe_radius)
else:
raise ValueError("Unknown surface type: " + str(args.surface_type))
raise ValueError("Unknown surface type: " + str(surface_type))

if args.check_cdrs:
if check_cdrs:
try:
from .anarci_wrapper.annotation import Annotation
cdrs = [
Expand Down Expand Up @@ -210,86 +247,104 @@ def main(argv=None):
print('Finding patches')
values = griddata.interpolate(columns, surf.vertices)[columns[0]]
patches = pd.DataFrame({
'positive': assign_patches(surf.faces, values > args.patch_cutoff[0]),
'negative': assign_patches(surf.faces, values < args.patch_cutoff[1]),
'positive': assign_patches(surf.faces, values > patch_cutoff[0]),
'negative': assign_patches(surf.faces, values < patch_cutoff[1]),
'area': vert_areas,
'atom': closest_atom,
'residue': residues[closest_atom],
'cdr': np.isin(residues, cdrs)[closest_atom]
'cdr': np.isin(residues, cdrs)[closest_atom],
'value': values,
})


# save values and atom in surf for consistency with commandline_hydrophobic
surf['values'] = values
surf['atom'] = closest_atom

#keep args.n_patches largest patches (n > 0) or smallest patches (n < 0) or patches with an area over the size cutoff
if args.n_patches != 0 or args.size_cutoff != 0:
if n_patches != 0 or size_cutoff != 0:
#interesting interaction: setting a -n cutoff and size cutoff should yield the n smallest patches with an area over the size cutoff
replace_vals = {}
for patch_type in ('negative', 'positive'):
# sort patches by area and filter top n patches (or bottom n patches for n < 0)
# also we apply the size cutoff here. It defaults to 0, so should not do anything if not explicitly set as all areas should be > 0.
area = patches.query(f'{patch_type} != -1').groupby(f'{patch_type}').sum(numeric_only=True)['area']
order = (area[area > args.size_cutoff] # discard patches with an area under size cutoff
order = (area[area > size_cutoff] # discard patches with an area under size cutoff
.sort_values(ascending=False).index) # ... and sort them
filtered = order[:args.n_patches] if args.n_patches > 0 else order[args.n_patches:]

filtered = order[:n_patches] if n_patches > 0 else order[n_patches:]
# set patches not in filtered to -1
patches.loc[~patches[patch_type].isin(filtered), patch_type] = -1
patches.loc[~patches[patch_type].isin(filtered), patch_type] = -1

# build replacement dict to renumber patches in df according to size
order_map = {elem : i for i, elem in enumerate(order[:args.n_patches]) }
replace_vals[patch_type] = order_map
order_map = {elem: i for i, elem in enumerate(order[:n_patches])}
replace_vals[patch_type] = order_map
patches.replace(replace_vals, inplace=True)

output = csv.writer(csv_outfile)
output.writerow(['type', 'npoints', 'area', 'cdr', 'main_residue'])
write_patches(patches, output, 'positive')
write_patches(patches, output, 'negative')

# Compute the total solvent-accessible potential.
within_range, closest_atom, distance = griddata.distance_to_spheres(rmax=10, atomic_radii=radii)
not_protein = distance > args.probe_radius
not_protein = distance > probe_radius
accessible = within_range[not_protein]
voxel_volume = griddata.grid.voxel_volume
accessible_data = griddata[columns[0]].values[accessible]
integral = np.sum(accessible_data) * voxel_volume
integral_high = np.sum(np.maximum(accessible_data - args.integral_cutoff[0], 0)) * voxel_volume
integral_high = np.sum(np.maximum(accessible_data - integral_cutoff[0], 0)) * voxel_volume
integral_pos = np.sum(np.maximum(accessible_data, 0)) * voxel_volume
integral_neg = np.sum(np.minimum(accessible_data, 0)) * voxel_volume
integral_low = np.sum(np.minimum(accessible_data - args.integral_cutoff[1], 0)) * voxel_volume
integral_low = np.sum(np.minimum(accessible_data - integral_cutoff[1], 0)) * voxel_volume
print('Integrals (total, ++, +, -, --):')
print(f'{integral} {integral_high} {integral_pos} {integral_neg} {integral_low}')
if args.ply_out:

if ply_out:
pos_surf = Surface(surf.vertices, surf.faces)
pos_area = patches.query('positive != -1').groupby('positive').sum(numeric_only=True)['area']
pos_order = pos_area.sort_values(ascending=False).index
color_surface_by_group(pos_surf, patches['positive'].values, order=pos_order, cmap=args.pos_patch_cmap)
pos_surf.write_ply(args.ply_out + '-pos.ply')
color_surface_by_group(pos_surf, patches['positive'].values, order=pos_order, cmap=pos_patch_cmap)
pos_surf.write_ply(ply_out + '-pos.ply')

neg_surf = Surface(surf.vertices, surf.faces)
neg_area = patches.query('negative != -1').groupby('negative').sum(numeric_only=True)['area']
neg_order = neg_area.sort_values(ascending=False).index
color_surface_by_group(neg_surf, patches['negative'].values, order=neg_order, cmap=args.neg_patch_cmap)
neg_surf.write_ply(args.ply_out + '-neg.ply')
color_surface_by_group(neg_surf, patches['negative'].values, order=neg_order, cmap=neg_patch_cmap)
neg_surf.write_ply(ply_out + '-neg.ply')

potential_surf = Surface(surf.vertices, surf.faces)
potential_surf['values'] = values
color_surface(potential_surf, 'values', cmap=args.ply_cmap, clim=args.ply_clim)
potential_surf.write_ply(args.ply_out + '-potential.ply')
color_surface(potential_surf, 'values', cmap=ply_cmap, clim=ply_clim)
potential_surf.write_ply(ply_out + '-potential.ply')

# close user output file, but not stdout
if args.out is not None:
if out is not None:
csv_outfile.close()

return {
'surface': surf,
'integrals': {
'integral': integral,
'integral_high': integral_high,
'integral_pos': integral_pos,
'integral_neg': integral_neg,
'integral_low': integral_low,
},
'patch_vertices': patches,
}


IonSpecies = namedtuple("IonSpecies", "charge concentration radius")

DEFAULT_ION_SPECIES = [IonSpecies(1.0, 0.1, 2.0), IonSpecies(-1.0, 0.1, 2.0)]

def get_ion_species(commandline_arguments):
args = commandline_arguments.ion_species
if args is None:
if commandline_arguments is None:
return DEFAULT_ION_SPECIES
if len(args) % 3 != 0:
if len(commandline_arguments) % 3 != 0:
raise ValueError("Number of arguments for --ion_species must be divisible by 3.")
# important to keep this an iterator
args_it = (float(arg) for arg in args)
args_it = (float(arg) for arg in commandline_arguments)
species = []
for charge, conc, radius in zip(args_it, args_it, args_it):
species.append(IonSpecies(charge, conc, radius))
Expand Down
Loading

0 comments on commit c7f0fc7

Please sign in to comment.