Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update docstrings in CreationCC wfn. #17

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 95 additions & 20 deletions fanpy/wfn/creation_cc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,70 @@
"""Creation Coupled-Cluster wavefunction."""
import numpy as np
from itertools import combinations

from fanpy.wfn.base import BaseWavefunction
from fanpy.tools import math_tools, slater
from fanpy.tools import slater


class CreationCC(BaseWavefunction):
r"""Creation Coupled-Cluster wavefunction.

The creation CC wavefunction is given by

.. math::
|\Psi \rangle = \exp(\sum_{ij}c_{ij}a^{\dagger}_i a^{\dagger}_j) | 0 \rangle

where :math: | 0 \rangle is the vacuum state, :math: a^{\dagger}_i is the creation operator
for the i-th spin orbital, and :math: c_{ij} are the parameters of the wavefunction.

Attributes
----------
nelec : int
Number of electrons.
nspin : int
Number of spin orbitals.
memory : {float, int, str, None}
Memory available for the wavefunction.
dict_orbpair_ind : dict of 2-tuple of int: int
Dictionary that maps orbital pairs to column indices.
dict_ind_orbpair : dict of int: 2-tuple of int
Dictionary that maps column indices to orbital pairs.
params : np.ndarray
Parameters of the wavefunction.
permutations : list of list of 2-tuple of int
Permutations of the orbital pairs.
signs : list of int
Signs of the permutations.

Methods
-------
__init__(nelec, nspin, memory=None, orbpairs=None, params=None)
Initialize the wavefunction.
assign_nelec(nelec)
Assign the number of electrons.
assign_params(params=None, add_noise=False)
Assign the parameters of the wavefunction.
assign_orbpairs(orbpairs=None)
Assign the orbital pairs used to construct the wavefunction.
get_col_ind(orbpair)
Get the column index that corresponds to the given orbital pair.
get_permutations()
Get the permutations of the given indices.
get_sign(indices)
Get the sign of the permutation of the given indices.
_olp(sd)
Calculate overlap with Slater determinant.
_olp_deriv(sd)
Calculate the derivative of the overlap.
get_overlap(sd, deriv=None) : {float, np.ndarray}
Return the (derivative) overlap of the wavefunction with a Slater determinant.
calculate_product(occ_indices, permutation, sign)
Calculate the product of the parameters of the given permutation.
"""

def __init__(self, nelec, nspin, memory=None, orbpairs=None, params=None):
""" Initialize the wavefunction
""" Initialize the wavefunction.

Parameters
----------
nelec : int
Expand All @@ -30,7 +87,7 @@ def __init__(self, nelec, nspin, memory=None, orbpairs=None, params=None):
self.assign_params(params=params)
self.permutations, self.signs = self.get_permutations()

def assign_nelec(self, nelec: int):
def assign_nelec(self, nelec):
"""Assign the number of electrons.

Parameters
Expand All @@ -55,6 +112,7 @@ def assign_nelec(self, nelec: int):

def assign_params(self, params=None, add_noise=False):
"""Assign the parameters of the creation cc wfn.

Parameters
----------
params : {np.ndarray, None}
Expand All @@ -80,7 +138,7 @@ def assign_params(self, params=None, add_noise=False):
super().assign_params(params=params, add_noise=add_noise)

def assign_orbpairs(self, orbpairs=None):
"""Assign the orbital pairs.
"""Assign the orbital pairs used to construct the creation CC wavefunction.

Parameters
----------
Expand Down Expand Up @@ -132,7 +190,7 @@ def assign_orbpairs(self, orbpairs=None):
self.dict_orbpair_ind = dict_orbpair_ind
self.dict_ind_orbpair = {i: orbpair for orbpair, i in dict_orbpair_ind.items()}

def get_col_ind(self, orbpair: tuple[int]):
def get_col_ind(self, orbpair):
"""Get the column index that corresponds to the given orbital pair.

Parameters
Expand Down Expand Up @@ -161,19 +219,16 @@ def get_col_ind(self, orbpair: tuple[int]):
)

def get_permutations(self):
"""Get the permutations of the given indices.

Parameters
----------
indices : list of int
Indices of the orbitals.
""" Calculate the permutations of indices 0 to nelec.

Returns
-------
perms : list of list of int
Permutations of the given indices.

Permutations of the indices (0 to nelec).
signs : list of int
Signs of the permutations.
"""

indices = np.arange(self.nelec, dtype=int)
perm_list = list(combinations(indices, r=2))

Expand All @@ -188,7 +243,7 @@ def get_permutations(self):
signs.append(self.get_sign(element))
return perms, signs

def get_sign(self, indices: list[int]):
def get_sign(self, indices):
"""Get the sign of the permutation of the given indices.

Parameters
Expand All @@ -210,7 +265,7 @@ def get_sign(self, indices: list[int]):
sign *= -1
return sign

def _olp(self, sd: int):
def _olp(self, sd):
""" Calculate overlap with Slater determinant.

Parameters
Expand All @@ -230,12 +285,29 @@ def _olp(self, sd: int):
return olp

def calculate_product(self, occ_indices, permutation, sign):
"""Calculate the product of the parameters of the given permutation.

Parameters
----------
occ_indices : list of int
Occupation indices of the Slater determinant.
permutation : list of 2-tuple of int
Permutation of the orbital pairs.
sign : int
Sign of the permutation.

Returns
-------
prod : float
Product of the parameters of the given permutation
"""

col_inds = list(map(self.get_col_ind, occ_indices.take(permutation)))
prod = sign*np.prod(self.params[col_inds])
return prod

def _olp_deriv(self, sd: int):
""" Calculate the derivative of the overlap
def _olp_deriv(self, sd):
""" Calculate the derivative of the overlap with a Slater determinant.

Parameters
----------
Expand All @@ -261,9 +333,11 @@ def _olp_deriv(self, sd: int):
return output


def get_overlap(self, sd: int, deriv=None):
"""Return the overlap of the wavefunction with a Slater determinant.
Inlcude math later.
def get_overlap(self, sd, deriv=None):
r"""Return the (derivative) overlap of the wavefunction with a Slater determinant.

.. math::
| \Psi \rangle = \sum_{\textbf{m} \in S} \sum_{\{i_1 j_1, ..., i_{n_m} j_{n_m} \} = \textbf{m}} sgn (\sigma(\{i_1 j_1, ..., i_{n_m} j_{n_m} \}))\prod_{k}^{n_m} c_{i_k j_k} | \textbf{m} \rangle

Parameters
----------
Expand All @@ -278,6 +352,7 @@ def get_overlap(self, sd: int, deriv=None):
overlap : {float, np.ndarray}
Overlap (or derivative of the overlap) of the wavefunction with the given Slater
determinant.

"""

if deriv is None:
Expand Down