Source code for pydft_qmmm.system.selection_utils

"""A module containing helper functions for selection query parsing.

Attributes:
    FAST_SELECTORS: The system attribute name and type indexed by the
        query attribute name for attributes that are expected to change
        every step of a simulation.
    SLOW_SELECTORS: The system attribute name and type indexed by the
        query attribute name for attributes that are not expected to
        change every step of a simulation.
    SELECT_KEYWORDS: The list of other keywords over which the keyword
        takes precedence and a function applying the keyword indexed by
        the keyword.
    FAST_VARIABLES: The system attribute name and a column index
        indexed by variable name for variables that are expected to
        change every step of a simulation.
    SLOW_VARIABLES: The system attribute name and a column index
        indexed by variable name for variables that are not expected to
        change every step of a simulation.
    OPERATORS: The list of other operators over which the operator
        takes precedence and a function applying the operator indexed
        by the operator's symbol.
    FUNCTIONS: Functions indexed by their possible names in the query.
    SELECTORS: The system attribute name and type indexed by the
        query attribute name for all attributes.
    VARIABLES: The system attribute name and a column index indexed
        by variable name for all variables.
    FAST_KEYWORDS: The system attribute name indexed by variable name
        or query attribute name for variables or selectors that are
        expected to change every step of a simulation.
    SLOW_KEYWORDS: The system attribute name indexed by variable name
        or query attribute name for variables or selectors that are
        not expected to change every step of a simulation.
    MATH_KEYWORDS: A set of variable names, operator symbols, and
        function names.
    KEYWORDS: A set of selection keywords and math keywords.
"""
from __future__ import annotations

__all__ = [
    "decompose",
    "interpret",
]

import operator
import re
from typing import TYPE_CHECKING

import numpy as np

from pydft_qmmm.utils import Subsystem

if TYPE_CHECKING:
    from pydft_qmmm import System
    from numpy.typing import NDArray

# Selector tables: query word -> (name of the system attribute holding the
# per-atom data, callable used to convert the query's string arguments
# before they are compared with that data).
FAST_SELECTORS = {
    "subsystem": ("subsystems", Subsystem),
}

# Several query words deliberately map onto the same system attribute
# (e.g. "atom"/"index" and "residue"/"resid").
SLOW_SELECTORS = {
    "element": ("elements", str),
    "atom": ("atoms", int),
    "index": ("atoms", int),
    "name": ("names", str),
    "residue": ("residues", int),
    "resid": ("residues", int),
    "resname": ("residue_names", str),
    "chain": ("chains", str),
}

# Boolean connectives: keyword -> (tokens at which the right-hand operand
# of the keyword stops, function combining the selection accumulated so
# far with the selection of the right-hand operand).
SELECT_KEYWORDS = {
    "and": (["and", "or"], lambda sel, pred: sel & pred),
    "or": (["or"], lambda sel, pred: sel | pred),
}

# Variable tables: variable name -> (name of the system attribute, tuple of
# trailing indices).  The tuple is unpacked after a leading ``:`` when the
# attribute is indexed, so ``(0,)`` picks column 0 of a two-dimensional
# array.
FAST_VARIABLES = {
    "x": ("positions", (0,)),
    "y": ("positions", (1,)),
    "z": ("positions", (2,)),
    "vx": ("velocities", (0,)),
    "vy": ("velocities", (1,)),
    "vz": ("velocities", (2,)),
    "fx": ("forces", (0,)),
    "fy": ("forces", (1,)),
    "fz": ("forces", (2,)),
}

# An empty index tuple leaves only the leading ``:``, i.e. the whole
# attribute is used.
SLOW_VARIABLES: dict[str, tuple[str, tuple[int, ...]]] = {
    "mass": ("masses", ()),
    "charge": ("charges", ()),
}

# Binary operators: symbol -> (tokens at which the right-hand operand of
# the operator stops, function applying the operator).  The comparison
# operators list no stop tokens, so their right-hand operand runs to the
# end of the expression being evaluated.
OPERATORS = {
    "^": (["+", "-", "*", "/", "^", "=", ">=", "<=", ">", "<"], operator.pow),
    "*": (["+", "-", "*", "/", "=", ">=", "<=", ">", "<"], operator.mul),
    "/": (["+", "-", "*", "/", "=", ">=", "<=", ">", "<"], operator.truediv),
    "+": (["+", "-", "=", ">=", "<=", ">", "<"], operator.add),
    "-": (["+", "-", "=", ">=", "<=", ">", "<"], operator.sub),
    "=": ([], operator.eq),
    ">=": ([], operator.ge),
    "<=": ([], operator.le),
    ">": ([], operator.gt),
    "<": ([], operator.lt),
}

# Functions available in mathematical selection expressions, indexed by
# the name used in the query.
FUNCTIONS = {
    "sqrt": np.sqrt,
    # ``sqr`` squares its argument; it is not an alias of ``sqrt``.
    "sqr": np.square,
    "abs": np.abs,
}

# Merged lookup tables built from the fast/slow tables.
SELECTORS = FAST_SELECTORS | SLOW_SELECTORS
VARIABLES = FAST_VARIABLES | SLOW_VARIABLES
FAST_KEYWORDS = FAST_SELECTORS | FAST_VARIABLES
SLOW_KEYWORDS = SLOW_SELECTORS | SLOW_VARIABLES
# The union of dictionary key views is a plain set of names.
MATH_KEYWORDS = VARIABLES.keys() | OPERATORS.keys() | FUNCTIONS.keys()
# Every reserved token: math keywords, boolean connectives, and the
# structural words of the query language.
KEYWORDS = (MATH_KEYWORDS | SELECT_KEYWORDS.keys()
            | {"(", ")", "not", "within", "of", "same", "as"})


def isvalue(text: str) -> bool:
    """Determine if a string is a numerical value.

    A numerical value is an unsigned decimal literal: one or more
    decimal digits with at most one decimal point, e.g. ``"12"``,
    ``"1.5"``, ``".5"`` or ``"5."``.  Signs and exponents are rejected.

    Args:
        text: The string to evaluate.

    Returns:
        Whether or not the text is a numerical value.
    """
    if text.count(".") > 1:
        return False
    parts = [part for part in text.split(".") if part]
    # ``str.isdecimal`` is used instead of ``str.isnumeric`` because the
    # latter also accepts characters such as superscripts and vulgar
    # fractions, which ``float`` cannot convert.
    return bool(parts) and all(part.isdecimal() for part in parts)
def decompose(text: str) -> list[str]:
    """Decompose an atom selection query into meaningful components.

    The query is split on the structural keywords and on every operator
    symbol; the delimiters themselves are kept as components.

    Args:
        text: The atom selection query.

    Returns:
        The atom selection query broken into meaningful parts,
        demarcated by keywords, with surrounding whitespace stripped
        and empty parts discarded.
    """
    operator_alternatives = [rf"|\{symbol}" for symbol in OPERATORS]
    pattern = (
        r"(not| or | and |\(|\)|within| of |same| as "
        + r"".join(operator_alternatives)
        + r")"
    )
    # The capturing group makes ``re.split`` return the delimiters too.
    pieces = (piece.strip() for piece in re.split(pattern, text))
    return [piece for piece in pieces if piece]
def line_slice(
        line: list[str],
        start: int,
        low_priority: list[str] | None = None,
) -> slice:
    """Find the slice of a query that belongs to the token at ``start``.

    Scanning begins just after ``start`` and stops at the first token
    from ``low_priority`` that is found outside of any parentheses, or
    at the end of the line.

    Args:
        line: The atom selection query, broken into meaningful
            components.
        start: The index of the line where the statement begins.  A
            negative value means "before the first component", so the
            returned slice then starts at the first component.
        low_priority: Strings which have a lower operator precedence
            and therefore terminate the slice.  Defaults to none.

    Returns:
        The slice whose start and stop corresponds to the phrase
        following ``start``.

    Raises:
        ValueError: If a parenthesis opened in the scanned region is
            never closed.
    """
    stop_tokens = () if low_priority is None else low_priority
    depth_change = {"(": 1, ")": -1}
    # A negative start must not wrap around to the end of the line:
    # there is no token at that position to open a parenthesis.
    depth = depth_change.get(line[start], 0) if start >= 0 else 0
    index = start + 1
    while index < len(line):
        depth += depth_change.get(line[index], 0)
        # The token directly after ``start`` never terminates the slice;
        # this allows precedence of unary operators.
        if depth == 0 and index > start + 1 and line[index] in stop_tokens:
            break
        index += 1
    if depth > 0:
        raise ValueError("Unclosed parenthesis in atom selection query")
    return slice(start + 1, index)
def evaluate(text: str, system: System) -> frozenset[int]:
    """Evaluate a part of an atom selection query.

    The first word of the statement names the selector and the
    remaining words are the values to match.  The two-word selectors
    ``atom name`` and ``residue name`` are treated as ``name`` and
    ``resname``.

    Args:
        text: A single contained statement from an atom selection query.
        system: The system whose atoms will be selected by evaluating
            a single query statement.

    Returns:
        The set of atom indices selected by the query statement.
    """
    # Split on runs of whitespace so that repeated spaces do not create
    # empty tokens, which the integer selectors cannot convert.
    tokens = text.split()
    attribute, convert = SELECTORS[tokens[0].lower()]
    phrase = " ".join(tokens).lower()
    if phrase.startswith("atom name"):
        attribute, convert = SELECTORS["name"]
        del tokens[1]
    elif phrase.startswith("residue name"):
        attribute, convert = SELECTORS["resname"]
        del tokens[1]
    values = [convert(token) for token in tokens[1:]]
    if attribute == "atoms":
        # The values are themselves the atom indices.
        return frozenset(values)
    population = getattr(system, attribute)
    return frozenset(
        i for i, x in enumerate(population)
        if any(x == value for value in values)
    )
[docs] def evaluate_math(line: list[str], system: System) -> NDArray[np.float64]: """Evaluate strings corresponding to a mathematical expression. Args: line: The strings to evaluate. system: The system whose attributes will figure into the mathematical evaluation. Returns: The result of evaluating the mathematical expression in the context of the given system. """ value = np.zeros((len(system),)) count = 0 entry = line[count] if entry.lower() in VARIABLES: var = VARIABLES[entry.lower()] value += getattr(system, var[0])[:, *var[1]] elif isvalue(entry): value += float(entry) elif entry.split(" ")[0].lower() in SELECTORS: entry = entry.split(" ")[0] category = SELECTORS[entry.lower()] if category[0] == "atoms": value += np.array([i for i in range(len(system))]) elif category[0] == "residues": value += system.residues else: raise TypeError elif entry in KEYWORDS: if entry in FUNCTIONS: indices = line_slice(line, count, list(OPERATORS.keys())) predicate = evaluate_math(line[indices], system) value += FUNCTIONS[entry](predicate) elif entry in ["+", "-"]: indices = line_slice(line, count, ["+", "-", "*", "/", "^"]) predicate = evaluate_math(line[indices], system) value += predicate if entry == "+" else -predicate elif entry == "(": indices = line_slice(line, count, [")"]) value += evaluate_math(line[indices], system) else: raise ValueError( ("Two incompatable math operators have been placed " "next to each other in a query."), ) count = indices.stop - 1 else: raise ValueError(f"{entry=}") while count < len(line) - 1: count += 1 entry = line[count] if entry in OPERATORS: operator = OPERATORS[entry] indices = line_slice(line, count, operator[0]) predicate = evaluate_math(line[indices], system) value = operator[1](value, predicate) count = indices.stop - 1 elif entry not in KEYWORDS: raise ValueError(f"{entry=}") return value
[docs] def interpret(line: list[str], system: System) -> frozenset[int]: """Interpret a line of atom selection query language. This has been written to follow `VMD selection language`_. Args: line: The atom selection query, broken into meaningful components. system: The system whose atoms will be selected by interpreting the selection query. Returns: The set of atom indices selected by the query. """ selection: frozenset[int] = frozenset({}) count = 0 entry = line[count] if entry.split(" ")[0].lower() in SELECTORS: indices = line_slice(line, count - 1, ["and", "or"]) if any([x in MATH_KEYWORDS for x in line[indices]]): indices = slice(0, indices.stop) selection |= set( np.where(evaluate_math(line[indices], system))[0], ) else: selection = selection | evaluate(entry, system) count = indices.stop - 1 elif entry == "all": selection = selection | frozenset(range(len(system))) elif entry == "none": selection = selection | frozenset({}) elif entry in KEYWORDS or isvalue(entry): if entry == "(": indices = line_slice(line, count, [")"]) if all([isvalue(x) or x in MATH_KEYWORDS for x in line[indices]]): indices = line_slice(line, count, ["and", "or"]) indices = slice(0, indices.stop) selection |= set( np.where(evaluate_math(line[indices], system))[0], ) else: selection |= interpret(line[indices], system) elif entry == "not": indices = line_slice(line, count, ["and", "or"]) new_selection = (frozenset(range(len(system))) - interpret(line[indices], system)) selection |= new_selection elif entry == "within": # This does not currently support PBC, as in VMD. 
indices = line_slice(line, count, ["of"]) # This needs testing, the original is provided below # radius = evaluate_math(line[indices], [0])[0] radius = evaluate_math(line[indices], system)[0] atoms = interpret(line[indices.stop+1:], system) measure = np.min( np.linalg.norm( (system.positions.base[:, np.newaxis, :] - system.positions[sorted(atoms), :]), axis=2, ), axis=1, ) selection = selection | set(np.where(measure < radius)[0]) indices = line_slice(line, count) elif entry == "same": attribute = line[count+1] atoms = interpret(line[count+3:], system) if attribute.lower() in SELECTORS: text = attribute.split(" ") category = SELECTORS[text[0].lower()] if " ".join(text).lower().startswith("atom name"): category = SELECTORS["name"] elif " ".join(text).lower().startswith("residue name"): category = SELECTORS["resname"] if category[0] != "atoms": population = getattr(system, category[0]) atoms = frozenset( {i for i, x in enumerate(population) if x in population[sorted(atoms)]}, ) elif attribute.lower() in VARIABLES: var = VARIABLES[attribute.lower()] value = getattr(system, var[0])[:, *var[1]] atoms = frozenset( {i for i, x in enumerate(value) if x in value[sorted(atoms)]}, ) else: raise ValueError(f"Unrecognized attribute '{attribute}'") selection = selection | atoms indices = line_slice(line, count) elif entry in MATH_KEYWORDS or isvalue(entry): indices = line_slice(line, count, ["and", "or"]) indices = slice(0, indices.stop) selection |= set( np.where(evaluate_math(line[indices], system))[0], ) count = indices.stop - 1 else: raise ValueError(f"{entry=}") while count < len(line) - 1: count += 1 entry = line[count] if entry in SELECT_KEYWORDS: keyword = SELECT_KEYWORDS[entry] indices = line_slice(line, count, keyword[0]) predicate = interpret(line[indices], system) selection = keyword[1](selection, predicate) count = indices.stop - 1 elif entry not in KEYWORDS: raise ValueError(f"{entry=}") return selection