Develop a direct basis construction for permutation and lattice translation by sekocha · Pull Request #125 · symfc/symfc · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Develop a direct basis construction for permutation and lattice translation #125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/symfc/basis_sets/basis_sets_O2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
compressed_projector_sum_rules_O2,
projector_permutation_lat_trans_O2,
)
from symfc.utils.permutation_tools_O2 import compr_permutation_lat_trans_O2
from symfc.utils.rotation_tools_O2 import complementary_compr_projector_rot_sum_rules_O2
from symfc.utils.utils import SymfcAtoms
from symfc.utils.utils_O2 import (
Expand Down Expand Up @@ -132,15 +133,25 @@ def atomic_decompr_idx(self) -> np.ndarray:
def run(self, rotational_sum_rules: bool = False) -> FCBasisSetO2:
"""Compute compressed force constants basis set."""
trans_perms = self._spg_reps.translation_permutations
proj_pt = projector_permutation_lat_trans_O2(
trans_perms,
atomic_decompr_idx=self._atomic_decompr_idx,
fc_cutoff=self._fc_cutoff,
use_mkl=self._use_mkl,
verbose=self._log_level > 0,
)

c_pt = eigsh_projector(proj_pt, verbose=self._log_level > 0)
direct_permutation = True
if direct_permutation:
c_pt = compr_permutation_lat_trans_O2(
trans_perms,
atomic_decompr_idx=self._atomic_decompr_idx,
fc_cutoff=self._fc_cutoff,
verbose=self._log_level > 0,
)
else:
proj_pt = projector_permutation_lat_trans_O2(
trans_perms,
atomic_decompr_idx=self._atomic_decompr_idx,
fc_cutoff=self._fc_cutoff,
use_mkl=self._use_mkl,
verbose=self._log_level > 0,
)
c_pt = eigsh_projector(proj_pt, verbose=self._log_level > 0)

proj_rpt = get_compr_coset_projector_O2(
self._spg_reps,
fc_cutoff=self._fc_cutoff,
Expand Down
38 changes: 20 additions & 18 deletions src/symfc/basis_sets/basis_sets_O3.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,22 +139,16 @@ def run(self) -> FCBasisSetO3:
"""Compute compressed force constants basis set."""
trans_perms = self._spg_reps.translation_permutations

if self._fc_cutoff is not None:
direct_permutation = False
else:
direct_permutation = True

direct_permutation = True
tt0 = time.time()
if direct_permutation:
tt0 = time.time()
tt1 = time.time()
c_pt = compr_permutation_lat_trans_O3(
trans_perms,
atomic_decompr_idx=self._atomic_decompr_idx,
fc_cutoff=self._fc_cutoff,
verbose=self._log_level > 0,
)
else:
tt0 = time.time()
proj_pt = projector_permutation_lat_trans_O3(
trans_perms,
atomic_decompr_idx=self._atomic_decompr_idx,
Expand Down Expand Up @@ -203,16 +197,24 @@ def run(self) -> FCBasisSetO3:
tt7 = time.time()

if self._log_level:
print(
"Time (proj(perm @ lattice trans.) :",
"{:.3f}".format(tt1 - tt0),
flush=True,
)
print(
"Time (eigh(perm @ ltrans)) :",
"{:.3f}".format(tt2 - tt1),
flush=True,
)
if direct_permutation:
print(
"Time (perm @ ltrans) :",
"{:.3f}".format(tt2 - tt0),
flush=True,
)

else:
print(
"Time (proj(perm @ lattice trans.) :",
"{:.3f}".format(tt1 - tt0),
flush=True,
)
print(
"Time (eigh(perm @ ltrans)) :",
"{:.3f}".format(tt2 - tt1),
flush=True,
)
print(
"Time (coset) :",
"{:.3f}".format(tt3 - tt2),
Expand Down
14 changes: 13 additions & 1 deletion src/symfc/utils/matrix_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ def get_entire_combinations(n, r):
return combs.T


def get_combinations(natom: int, order: int, fc_cutoff: Optional[FCCutoff] = None):
def get_combinations(
natom: int,
order: int,
fc_cutoff: Optional[FCCutoff] = None,
indep_atoms: Optional[np.ndarray] = None,
):
"""Return numpy array of FC index combinations."""
if fc_cutoff is None:
combinations = get_entire_combinations(3 * natom, order)
Expand All @@ -43,6 +48,13 @@ def get_combinations(natom: int, order: int, fc_cutoff: Optional[FCCutoff] = Non
raise NotImplementedError(
"Combinations are implemented only for 2 <= order <= 4."
)

if indep_atoms is not None:
nonzero = np.zeros(combinations.shape[0], dtype=bool)
atom_indices = combinations[:, 0] // 3
for i in indep_atoms:
nonzero[atom_indices == i] = True
combinations = combinations[nonzero]
return combinations


Expand Down
35 changes: 35 additions & 0 deletions src/symfc/utils/permutation_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Permutation utility functions."""

import numpy as np
import scipy
from scipy.sparse import csr_array


def construct_basis_from_orbits(orbits: np.ndarray):
    """Transform orbits into basis matrix.

    Each entry ``orbits[i]`` is either -1 (element ``i`` takes no part in the
    basis) or the index of another element that ``i`` is linked to. The links
    are treated as edges of a graph; every connected component becomes one
    column of the returned matrix. Each such column has the value
    ``1 / sqrt(component size)`` on the rows belonging to the component, so
    the columns are orthonormal.

    Parameters
    ----------
    orbits : ndarray
        Integer array of length ``size_full``. ``-1`` marks unused elements;
        any other value is an index into ``orbits``.

    Returns
    -------
    csr_array
        Basis matrix of shape ``(size_full, n_components)`` with dtype
        ``double``. Rows whose orbit entry is -1 are entirely zero.
    """
    # Import the submodule explicitly: ``import scipy`` alone does not
    # guarantee that ``scipy.sparse.csgraph`` has been loaded.
    from scipy.sparse.csgraph import connected_components

    size_full = len(orbits)
    nonzero = orbits != -1
    rows = np.flatnonzero(nonzero)
    if rows.size != size_full:
        # Re-index the retained elements to 0..n_retained-1 so that the graph
        # only contains nodes that take part in the basis.
        nonzero_map = np.full(size_full, -1, dtype="int")
        nonzero_map[rows] = np.arange(rows.size)
        orbits = nonzero_map[orbits[nonzero]]

    size1 = len(orbits)
    graph = csr_array(
        (np.ones(size1, dtype=bool), (np.arange(size1), orbits)),
        shape=(size1, size1),
        dtype=bool,
    )

    # Component labels are contiguous 0..n_col-1, so bincount gives the size
    # of each component directly.
    n_col, cols = connected_components(graph)
    cnt = np.bincount(cols, minlength=n_col)
    values = np.reciprocal(np.sqrt(cnt))

    return csr_array(
        (values[cols], (rows, cols)),
        shape=(size_full, n_col),
        dtype="double",
    )
120 changes: 120 additions & 0 deletions src/symfc/utils/permutation_tools_O2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Permutation utility functions for 3rd order force constants."""

from typing import Optional

import numpy as np
from scipy.sparse import csr_array

from symfc.utils.cutoff_tools import FCCutoff
from symfc.utils.matrix_tools import get_combinations
from symfc.utils.permutation_tools import construct_basis_from_orbits
from symfc.utils.solver_funcs import get_batch_slice
from symfc.utils.utils import get_indep_atoms_by_lat_trans
from symfc.utils.utils_O2 import _get_atomic_lat_trans_decompr_indices


def _N3N3_to_NNand33(combs: np.ndarray, N: int) -> np.ndarray:
"""Transform index order."""
vecNN, vec33 = np.divmod(combs[:, 0], 3)
vecNN *= N
vec33 *= 3
div, mod = np.divmod(combs[:, 1], 3)
vecNN += div
vec33 += mod
return vecNN, vec33


def compr_permutation_lat_trans_O2(
    trans_perms: np.ndarray,
    atomic_decompr_idx: Optional[np.ndarray] = None,
    fc_cutoff: Optional[FCCutoff] = None,
    n_batch: Optional[int] = None,
    verbose: bool = False,
) -> csr_array:
    """Build a compression matrix for permutation rules compressed by C_trans.

    This calculates C_(trans,perm) without allocating C_trans and C_perm.
    Batch calculations are used to reduce memory allocation.

    Parameters
    ----------
    trans_perms : ndarray
        Permutation of atomic indices by lattice translational symmetry.
        dtype='intc'.
        shape=(n_l, N), where n_l and N are the numbers of lattice points and
        atoms in supercell.
    atomic_decompr_idx : ndarray, optional
        Indexed by the flat atom-pair index ``atom_i * N + atom_j``; its value
        times 9 plus the component-pair index gives a position in the
        compressed space of size 9 * N**2 / n_l. Computed from
        ``trans_perms`` when None.
    fc_cutoff : FCCutoff class object. Default is None.
        Passed to ``get_combinations`` to select the off-diagonal index pairs.
    n_batch : int, optional
        Number of batches used for the off-diagonal index pairs. None (the
        default) processes them in a single batch. The value is clamped to
        the range [1, number of pairs].
    verbose : bool
        Print progress messages.

    Returns
    -------
    csr_array
        Compressed basis matrix for permutation, equivalent to
        C_pt = eigh(C_trans.T @ C_perm @ C_perm.T @ C_trans),
        with 9 * N**2 / n_l rows.
    """
    n_lp, natom = trans_perms.shape
    NN9 = natom**2 * 9
    if atomic_decompr_idx is None:
        atomic_decompr_idx = _get_atomic_lat_trans_decompr_indices(trans_perms)
    if n_batch is None:
        n_batch = 1

    # -1 marks compressed elements that have not been reached by any index
    # combination yet.
    orbits = np.full(NN9 // n_lp, -1, dtype="int")
    indep_atoms = get_indep_atoms_by_lat_trans(trans_perms)

    # Diagonal elements (i, i): swapping the two indices leaves them
    # unchanged, so each compressed diagonal element is linked to itself.
    combinations = np.array([[i, i] for i in range(3 * natom)], dtype=int)
    perms = [[0, 0]]
    orbits = _update_orbits_from_combinations(
        combinations,
        perms,
        atomic_decompr_idx,
        trans_perms,
        orbits,
        n_perms_group=1,
        n_batch=1,
        verbose=verbose,
    )

    # Off-diagonal pairs, with the first atom restricted to the
    # translationally independent atoms; (i, j) and (j, i) are linked.
    combinations = get_combinations(
        natom, order=2, fc_cutoff=fc_cutoff, indep_atoms=indep_atoms
    )
    # More batches than pairs would produce a zero batch size downstream.
    n_batch_pairs = max(1, min(n_batch, len(combinations)))
    perms = [[0, 1], [1, 0]]
    orbits = _update_orbits_from_combinations(
        combinations,
        perms,
        atomic_decompr_idx,
        trans_perms,
        orbits,
        n_perms_group=1,
        n_batch=n_batch_pairs,
        verbose=verbose,
    )
    if verbose:
        print("Construct basis matrix for permutations", flush=True)
    c_pt = construct_basis_from_orbits(orbits)
    return c_pt


def _update_orbits_from_combinations(
    combinations: np.ndarray,
    permutations: np.ndarray,
    atomic_decompr_idx: np.ndarray,
    trans_perms: np.ndarray,
    orbits: np.ndarray,
    n_perms_group: int = 1,
    n_batch: int = 1,
    verbose: bool = False,
) -> np.ndarray:
    """Link compressed elements related by index permutation into orbits.

    Every row of ``combinations`` is permuted by each entry of
    ``permutations``. Each permuted index pair is mapped to its position in
    the compressed space via ``atomic_decompr_idx``. The resulting positions
    are split into consecutive groups of ``len(permutations) //
    n_perms_group``. Every position in a group is linked to the group's first
    position by writing that position into ``orbits``.

    Parameters
    ----------
    combinations : ndarray
        Integer array of shape (n_comb, 2) of flat ``3 * atom + component``
        indices.
    permutations : array_like
        Column permutations applied to ``combinations``, shape (n_perms, 2).
    atomic_decompr_idx : ndarray
        Indexed by the flat atom-pair index ``atom_i * N + atom_j``; its
        value times 9 plus the component-pair index gives the position in
        ``orbits``.
    trans_perms : ndarray
        Translation permutations of shape (n_l, N); only N is used here.
    orbits : ndarray
        Orbit links. Updated in place and also returned.
    n_perms_group : int
        Number of groups each row's permuted positions are split into.
    n_batch : int
        Number of batches over ``combinations``. A value that would give an
        empty batch size is clamped to a batch size of 1.
    verbose : bool
        Print progress messages.

    Returns
    -------
    ndarray
        The updated ``orbits`` (the same object that was passed in).
    """
    natom = trans_perms.shape[1]
    n_comb = combinations.shape[0]
    if n_comb == 0:
        # Nothing to link; also avoids a zero batch size below.
        return orbits
    n_perms_sym = len(permutations) // n_perms_group
    batch_size = max(1, n_comb // max(1, n_batch))
    for begin, end in zip(*get_batch_slice(n_comb, batch_size)):
        if verbose:
            print("Permutation basis:", str(end) + "/" + str(n_comb), flush=True)
        combs_perm = combinations[begin:end][:, permutations].reshape((-1, 2))
        combs_perm, combs33 = _N3N3_to_NNand33(combs_perm, natom)
        cols = atomic_decompr_idx[combs_perm] * 9 + combs33
        cols = cols.reshape(-1, n_perms_sym)
        # Link every position in a group to the group's first position.
        # Later assignments overwrite earlier ones for the same position.
        for c in cols.T:
            orbits[c] = cols[:, 0]
    return orbits
Loading
Loading
0