# Source code for pydft_qmmm.common.utils.selection_utils

"""A module containing helper functions accessed by multiple classes.

Attributes:
    SELECTORS: Pairs of VMD selection keywords and the corresponding
        attribute and type to check in a system.
"""
from __future__ import annotations

import re
from typing import TYPE_CHECKING

from ..constants import Subsystem

if TYPE_CHECKING:
    from pydft_qmmm import System


# Each keyword maps to (attribute read from the System, converter applied
# to every value written after the keyword in a query).  The pseudo
# attribute "atoms" is never read from the System: its values are used
# directly as atom indices.
SELECTORS = dict(
    element=("elements", str),
    atom=("atoms", int),
    index=("atoms", int),
    name=("names", str),
    residue=("residues", int),
    resid=("residues", int),
    resname=("residue_names", str),
    subsystem=("subsystems", Subsystem),
)


def decompose(text: str) -> list[str]:
    """Decompose an atom selection query into meaningful components.

    The query is split around the logical keywords ``not``, ``or`` and
    ``and`` (case-sensitive) and around parentheses. The keywords are only
    recognized as whole words, so a value that merely contains one of them
    (e.g. an atom named ``Cor``) is kept intact.

    Args:
        text: The atom selection query.

    Returns:
        The atom selection query broken into meaningful parts, demarcated
        by keywords, with surrounding whitespace stripped and empty parts
        removed.
    """
    # The capturing group makes re.split keep the delimiters themselves;
    # \b prevents matching a keyword embedded inside a longer word.
    parts = re.split(r"(\bnot\b|\bor\b|\band\b|\(|\))", text)
    stripped = (part.strip() for part in parts)
    return [part for part in stripped if part]
def evaluate(text: str, system: System) -> frozenset[int]:
    """Evaluate a part of an atom selection query.

    The first word of ``text`` is a keyword from ``SELECTORS`` (matched
    case-insensitively) and every following word is a value to match.
    The two-word forms ``atom name`` and ``residue name`` are aliases of
    ``name`` and ``resname`` respectively. Words may be separated by any
    amount of whitespace.

    Args:
        text: A single contained statement from an atom selection query.
        system: The system whose atoms will be selected by evaluating a
            single query statement.

    Returns:
        The set of atom indices selected by the query statement.

    Raises:
        KeyError: If the first word is not a recognized keyword.
        ValueError: If a value is not a valid integer for an
            integer-valued keyword (raised by ``int``).
    """
    words = text.split()
    # An empty statement falls through to the KeyError raised below.
    keyword = words[0].lower() if words else ""
    attribute, converter = SELECTORS[keyword]
    values = words[1:]
    # "atom name ..." / "residue name ..." are two-word spellings of the
    # "name" / "resname" selectors; drop the "name" word itself.
    if values and values[0].lower() == "name":
        if keyword == "atom":
            attribute, converter = SELECTORS["name"]
            values = values[1:]
        elif keyword == "residue":
            attribute, converter = SELECTORS["resname"]
            values = values[1:]
    # NOTE(review): assumes converted values and population items are
    # hashable (true for str/int; Subsystem assumed to be an Enum).
    targets = {converter(value) for value in values}
    if not targets:
        return frozenset()
    if attribute == "atoms":
        # Index selectors name the atom indices directly.
        return frozenset(targets)
    population = getattr(system, attribute)
    # Single pass over the population instead of one pass per value.
    return frozenset(
        i for i, item in enumerate(population) if item in targets
    )
[docs] def parens_slice(line: list[str], start: int) -> slice: """Find the slice of a query within parentheses. Args: line: The atom selection query, broken into meaningful components. start: The index of the line where the statement within parentheses begins. Returns: The slice whose start and stop corresponds to the phrase contained by parentheses. """ flag = True count = 1 index = start while flag: if line[index] == "(": count += 1 if line[index] == ")": count -= 1 if count == 0: stop = index flag = False index += 1 return slice(start, stop)
def not_slice(line: list[str], start: int) -> slice:
    """Find the slice of a query modified by the 'not' keyword.

    The operand of ``not`` is a single statement or a parenthesized group,
    optionally preceded by further ``not`` keywords (so ``not not X`` is
    accepted).

    Args:
        line: The atom selection query, broken into meaningful components.
        start: The index of the line where the statement modified by the
            'not' keyword begins.

    Returns:
        The slice whose start and stop corresponds to the phrase modified
        by the 'not' keyword. ``stop`` is one past the operand's last token.

    Raises:
        ValueError: If the query ends before a complete operand is found.
    """
    depth = 0
    for index in range(start, len(line)):
        token = line[index]
        if token == "(":
            depth += 1
        elif token == ")":
            depth -= 1
        # A nested "not" is part of the operand, not the operand itself.
        if depth == 0 and token != "not":
            return slice(start, index + 1)
    raise ValueError(
        "Incomplete 'not' expression in atom selection query: "
        f"{' '.join(line)!r}"
    )
def and_slice(line: list[str], start: int) -> slice:
    """Find the slice of a query modified by the 'and' keyword.

    The operand of ``and`` is a single statement or a parenthesized group,
    optionally preceded by any number of ``not`` keywords.

    Args:
        line: The atom selection query, broken into meaningful components.
        start: The index of the line where the statement modified by the
            'and' keyword begins.

    Returns:
        The slice whose start and stop corresponds to the phrase modified
        by the 'and' keyword. ``stop`` is one past the operand's last token.

    Raises:
        ValueError: If the query ends before a complete operand is found
            (e.g. a trailing ``and`` or an unclosed parenthesis).
    """
    depth = 0
    for index in range(start, len(line)):
        token = line[index]
        if token == "(":
            depth += 1
        elif token == ")":
            depth -= 1
        # "not" binds tighter than "and", so it belongs to the operand.
        if depth == 0 and token != "not":
            return slice(start, index + 1)
    raise ValueError(
        "Incomplete 'and' expression in atom selection query: "
        f"{' '.join(line)!r}"
    )
def or_slice(line: list[str], start: int) -> slice:
    """Find the slice of a query modified by the 'or' keyword.

    Because ``and`` binds more tightly than ``or``, the operand of ``or``
    extends across any chain of ``and`` clauses (``a or b and c`` yields
    the operand ``b and c``). Parenthesized groups and leading ``not``
    keywords are also kept whole.

    Args:
        line: The atom selection query, broken into meaningful components.
        start: The index of the line where the statement modified by the
            'or' keyword begins.

    Returns:
        The slice whose start and stop corresponds to the phrase modified
        by the 'or' keyword. ``stop`` is one past the operand's last token.

    Raises:
        ValueError: If the query ends before a complete operand is found
            (e.g. a trailing ``and`` or an unclosed parenthesis).
    """
    depth = 0
    last = len(line) - 1
    for index in range(start, len(line)):
        token = line[index]
        if token == "(":
            depth += 1
        elif token == ")":
            depth -= 1
        # An "and" after this token opens an implicit group...
        if index < last and line[index + 1] == "and":
            depth += 1
        # ...which closes once the token after that "and" is reached.
        if index >= 1 and line[index - 1] == "and":
            depth -= 1
        if depth == 0 and token != "not":
            return slice(start, index + 1)
    raise ValueError(
        "Incomplete 'or' expression in atom selection query: "
        f"{' '.join(line)!r}"
    )
def interpret(line: list[str], system: System) -> frozenset[int]:
    """Interpret a line of atom selection query language.

    Operator precedence, from tightest to loosest, is
    ``()`` > ``not`` > ``and`` > ``or``. A statement or group that is not
    preceded by an operator is combined with the running selection by
    union.

    Args:
        line: The atom selection query, broken into meaningful components.
        system: The system whose atoms will be selected by interpreting
            the selection query.

    Returns:
        The set of atom indices selected by the query.

    Raises:
        ValueError: If a component of the query is not recognized.

    .. note:: Based on the VMD atom selection rules.
    """
    full = frozenset(range(len(system)))
    selection: frozenset[int] = frozenset()
    count = 0
    while count < len(line):
        entry = line[count]
        if entry == "all":
            selection = selection | full
        elif entry == "none":
            # Selects nothing, so the running selection is unchanged.
            pass
        elif entry.split(" ")[0].lower() in SELECTORS:
            selection = selection | evaluate(entry, system)
        elif entry == "(":
            indices = parens_slice(line, count + 1)
            selection = selection | interpret(line[indices], system)
            # indices.stop is the matching ")", skipped by the increment.
            count = indices.stop
        elif entry == "not":
            indices = not_slice(line, count + 1)
            selection = selection | (full - interpret(line[indices], system))
            # indices.stop is the first token AFTER the operand; step back
            # so the increment below lands on it instead of skipping it
            # (skipping a following "and" would turn it into a union).
            count = indices.stop - 1
        elif entry == "and":
            indices = and_slice(line, count + 1)
            selection = selection & interpret(line[indices], system)
            count = indices.stop - 1
        elif entry == "or":
            indices = or_slice(line, count + 1)
            selection = selection | interpret(line[indices], system)
            count = indices.stop - 1
        else:
            raise ValueError(
                f"Unrecognized atom selection component: {entry!r}"
            )
        count += 1
    return selection