"""Record models for reaction calculations (qcportal.reaction.record_models)."""

from __future__ import annotations

from typing import List, Union, Optional, Tuple, Iterable

try:
    from pydantic.v1 import BaseModel, Extra, root_validator, constr, PrivateAttr, Field
except ImportError:
    from pydantic import BaseModel, Extra, root_validator, constr, PrivateAttr, Field
from typing_extensions import Literal

from qcportal.molecules import Molecule
from qcportal.cache import get_records_with_cache
from qcportal.record_models import BaseRecord, RecordAddBodyBase, RecordQueryFilters
from ..optimization.record_models import OptimizationRecord, OptimizationSpecification
from ..singlepoint.record_models import (
    QCSpecification,
    SinglepointRecord,
)


class ReactionKeywords(BaseModel):
    """Keywords controlling a reaction calculation (no keywords are currently defined)."""

    class Config:
        """Deliberately permissive model configuration.

        ``extra = Extra.forbid`` is intentionally NOT set: the dataset
        ``additional_keywords`` tests rely on being able to pass keywords this
        model does not declare. If real keywords are ever added here, update
        those tests and switch to ``extra = Extra.forbid``.
        """


class ReactionSpecification(BaseModel):
    """Specification of how a reaction is computed.

    ``program`` is lowercased on validation. At least one of
    ``singlepoint_specification`` / ``optimization_specification`` must be
    supplied; :meth:`required_spec` enforces this.
    """

    class Config:
        extra = Extra.forbid

    program: constr(to_lower=True) = "reaction"
    singlepoint_specification: Optional[QCSpecification]
    optimization_specification: Optional[OptimizationSpecification]
    keywords: ReactionKeywords

    @root_validator
    def required_spec(cls, v):
        """Reject specifications that provide neither sub-specification.

        Raises:
            ValueError: if both ``singlepoint_specification`` and
                ``optimization_specification`` are missing or ``None``.
        """
        sub_spec_keys = ("singlepoint_specification", "optimization_specification")
        if all(v.get(key) is None for key in sub_spec_keys):
            raise ValueError("singlepoint_specification or optimization_specification must be specified")
        return v
class ReactionAddBody(RecordAddBodyBase):
    """Request body for adding reaction records."""

    specification: ReactionSpecification
    # One entry per reaction; each entry is a list of (float, molecule) pairs where the
    # molecule is an int or a full Molecule. Presumably (coefficient, molecule id or
    # Molecule), matching ReactionComponentMeta -- TODO confirm against the server.
    stoichiometries: List[List[Tuple[float, Union[int, Molecule]]]]


class ReactionQueryFilters(RecordQueryFilters):
    """Filters for querying reaction records; string filters are lowercased."""

    # Lowercased like the other string filters, since ReactionSpecification.program
    # is itself lowercased on validation (constr(to_lower=True)).
    program: Optional[List[constr(to_lower=True)]] = None
    qc_program: Optional[List[constr(to_lower=True)]] = None
    qc_method: Optional[List[constr(to_lower=True)]] = None
    qc_basis: Optional[List[Optional[constr(to_lower=True)]]] = None
    optimization_program: Optional[List[constr(to_lower=True)]] = None
    molecule_id: Optional[List[int]] = None


class ReactionComponentMeta(BaseModel):
    """Metadata for one component of a reaction: a molecule, its coefficient, and child record ids."""

    class Config:
        extra = Extra.forbid

    molecule_id: int
    coefficient: float
    singlepoint_id: Optional[int]
    optimization_id: Optional[int]
    molecule: Optional[Molecule]


class ReactionComponent(ReactionComponentMeta):
    """A reaction component together with its resolved child records (if any)."""

    singlepoint_record: Optional[SinglepointRecord] = None
    # Holds an OptimizationRecord (ReactionRecord._fetch_children_multi assigns
    # OptimizationRecord instances here); it was previously mistyped as SinglepointRecord.
    optimization_record: Optional[OptimizationRecord] = None
class ReactionRecord(BaseRecord):
    """Record of a reaction calculation.

    A reaction is made of components (a molecule plus a coefficient), each of which
    may reference a singlepoint and/or an optimization child record. Those child
    records are resolved on demand through :attr:`components`.
    """

    record_type: Literal["reaction"] = "reaction"
    specification: ReactionSpecification

    total_energy: Optional[float]

    ######################################################
    # Fields not always included when fetching the record
    ######################################################
    components_meta_: Optional[List[ReactionComponentMeta]] = Field(None, alias="components")

    ########################################
    # Caches
    ########################################
    _components: Optional[List[ReactionComponent]] = PrivateAttr(None)

    def propagate_client(self, client):
        """Attach ``client`` to this record and to every already-resolved child record."""
        BaseRecord.propagate_client(self, client)

        if self._components is None:
            return

        for component in self._components:
            # Singlepoint child first, then optimization child
            for child in (component.singlepoint_record, component.optimization_record):
                if child:
                    child.propagate_client(self._client)

    @classmethod
    def _fetch_children_multi(
        cls,
        client,
        record_cache,
        records: Iterable[ReactionRecord],
        include: Iterable[str],
        force_fetch: bool = False,
    ):
        """Resolve the child records of several reaction records in bulk.

        Only acts when ``"components"`` or ``"**"`` is in ``include``. All referenced
        singlepoint and optimization records are fetched with one
        ``get_records_with_cache`` call per record type. Each record then gets a
        fresh ``_components`` list (or ``None`` if it has no component metadata),
        and the client is propagated to the new children.
        """
        # Should be checked by the calling function
        assert records
        assert all(isinstance(x, ReactionRecord) for x in records)

        if "components" not in include and "**" not in include:
            return

        # Collect the singlepoint & optimization ids referenced by all reactions
        singlepoint_ids = set()
        optimization_ids = set()
        for record in records:
            for meta in record.components_meta_ or ():
                if meta.singlepoint_id is not None:
                    singlepoint_ids.add(meta.singlepoint_id)
                if meta.optimization_id is not None:
                    optimization_ids.add(meta.optimization_id)

        singlepoint_records = get_records_with_cache(
            client,
            record_cache,
            SinglepointRecord,
            list(singlepoint_ids),
            include=include,
            force_fetch=force_fetch,
        )
        optimization_records = get_records_with_cache(
            client,
            record_cache,
            OptimizationRecord,
            list(optimization_ids),
            include=include,
            force_fetch=force_fetch,
        )

        singlepoints_by_id = {child.id: child for child in singlepoint_records}
        optimizations_by_id = {child.id: child for child in optimization_records}

        def _resolve(meta):
            # Promote the metadata to a full component and attach its child records
            component = ReactionComponent(**meta.dict())
            if component.singlepoint_id is not None:
                component.singlepoint_record = singlepoints_by_id[component.singlepoint_id]
            if component.optimization_id is not None:
                component.optimization_record = optimizations_by_id[component.optimization_id]
            return component

        for record in records:
            if record.components_meta_ is None:
                record._components = None
            else:
                # Fill the cached list in place, one component at a time
                resolved = record._components = []
                for meta in record.components_meta_:
                    resolved.append(_resolve(meta))

            record.propagate_client(record._client)

    def _fetch_components(self):
        """Make sure component metadata is present (requesting it if needed), then resolve children."""
        if self.components_meta_ is None:
            self._assert_online()

            # Will include molecules
            endpoint = f"api/v1/records/reaction/{self.id}/components"
            self.components_meta_ = self._client.make_request("get", endpoint, List[ReactionComponentMeta])

        self.fetch_children(["components"])

    @property
    def components(self) -> List[ReactionComponent]:
        """Components of this reaction with child records resolved; fetched on first access."""
        if self._components is None:
            self._fetch_components()
        return self._components