Source code for qcportal.optimization.record_models

from __future__ import annotations

from copy import deepcopy
from typing import Optional, Union, Any, List, Dict

try:
    from pydantic.v1 import BaseModel, Field, constr, validator, Extra, PrivateAttr
except ImportError:
    from pydantic import BaseModel, Field, constr, validator, Extra, PrivateAttr
from qcelemental.models import Molecule
from qcelemental.models.procedures import (
    OptimizationResult,
    OptimizationProtocols,
    QCInputSpecification,
    Model as AtomicResultModel,
)
from typing_extensions import Literal
from typing import Iterable

from qcportal.record_models import BaseRecord, RecordAddBodyBase, RecordQueryFilters, RecordStatusEnum
from qcportal.cache import get_records_with_cache
from qcportal.singlepoint import SinglepointProtocols, SinglepointRecord, QCSpecification, SinglepointDriver


class OptimizationSpecification(BaseModel):
    """
    An OptimizationSpecification as stored on the server

    This is the same as the input specification, with a few ids added
    """

    class Config:
        # Unknown fields are a validation error instead of being ignored
        extra = Extra.forbid

    program: constr(to_lower=True) = Field(..., description="The program to use for an optimization")
    qc_specification: QCSpecification
    keywords: Dict[str, Any] = Field({})
    protocols: OptimizationProtocols = Field(OptimizationProtocols())

    @validator("qc_specification", pre=True)
    def force_qcspec(cls, v):
        """
        Force the driver and protocols of the embedded QC specification.

        Runs before field validation (``pre=True``), so ``v`` is the raw input:
        either a ``QCSpecification`` instance or a dictionary. Whatever driver
        and protocols were supplied, they are replaced with
        ``SinglepointDriver.deferred`` and a default ``SinglepointProtocols()``.

        Returns a dictionary that pydantic then validates as a ``QCSpecification``.
        """
        if isinstance(v, QCSpecification):
            # .dict() already returns a new dictionary
            v = v.dict()
        elif isinstance(v, dict):
            # Work on a shallow copy so the caller's dictionary is not modified
            # (otherwise a user-supplied "driver" would be silently overwritten
            # in the caller's own data)
            v = dict(v)

        v["driver"] = SinglepointDriver.deferred
        v["protocols"] = SinglepointProtocols()
        return v
# Client-side record of an optimization calculation.
#
# The initial/final molecules and the trajectory are not always included when
# the record is fetched. The corresponding properties retrieve them through
# the attached client the first time they are accessed.
class OptimizationRecord(BaseRecord):
    record_type: Literal["optimization"] = "optimization"
    specification: OptimizationSpecification
    initial_molecule_id: int
    final_molecule_id: Optional[int]
    energies: Optional[List[float]]

    ######################################################
    # Fields not always included when fetching the record
    ######################################################
    initial_molecule_: Optional[Molecule] = Field(None, alias="initial_molecule")
    final_molecule_: Optional[Molecule] = Field(None, alias="final_molecule")
    trajectory_ids_: Optional[List[int]] = Field(None, alias="trajectory_ids")

    ########################################
    # Caches
    ########################################
    _trajectory_records: Optional[List[SinglepointRecord]] = PrivateAttr(None)

    @classmethod
    def _fetch_children_multi(
        cls,
        client,
        record_cache,
        records: Iterable[OptimizationRecord],
        include: Iterable[str],
        force_fetch: bool = False,
    ):
        """
        Fetch the trajectory singlepoint records for several optimizations at once.

        Nothing is done unless ``"trajectory"`` or ``"**"`` is in ``include``.
        Otherwise the trajectory ids of all ``records`` are collected, fetched in
        a single ``get_records_with_cache`` call (``client``, ``record_cache``,
        ``include`` and ``force_fetch`` are forwarded to it), and the results are
        stored in each record's ``_trajectory_records``.
        """
        # Should be checked by the calling function
        assert records
        assert all(isinstance(x, OptimizationRecord) for x in records)

        if "trajectory" in include or "**" in include:
            # collect all singlepoint ids for all optimizations
            sp_ids = set()
            for r in records:
                if r.trajectory_ids_:
                    sp_ids.update(r.trajectory_ids_)

            sp_ids = list(sp_ids)
            sp_records = get_records_with_cache(
                client, record_cache, SinglepointRecord, sp_ids, include=include, force_fetch=force_fetch
            )
            sp_map = {r.id: r for r in sp_records}

            for r in records:
                if r.trajectory_ids_ is None:
                    r._trajectory_records = None
                else:
                    # Keep the order given by trajectory_ids_
                    r._trajectory_records = [sp_map[x] for x in r.trajectory_ids_]

                # Hand the record's own client to the newly attached children
                r.propagate_client(r._client)

    def propagate_client(self, client):
        """Attach ``client`` to this record and to every cached trajectory record."""
        BaseRecord.propagate_client(self, client)

        if self._trajectory_records is not None:
            for sp in self._trajectory_records:
                sp.propagate_client(client)

    def _fetch_initial_molecule(self):
        """Retrieve the initial molecule through the client and store it on the record."""
        self._assert_online()
        self.initial_molecule_ = self._client.get_molecules([self.initial_molecule_id])[0]

    def _fetch_final_molecule(self):
        """Retrieve the final molecule through the client; does nothing if there is no final molecule id."""
        if self.final_molecule_id is not None:
            self._assert_online()
            self.final_molecule_ = self._client.get_molecules([self.final_molecule_id])[0]

    def _fetch_trajectory(self):
        """Retrieve the trajectory ids if they are missing, then fetch the trajectory records."""
        if self.trajectory_ids_ is None:
            self._assert_online()

            self.trajectory_ids_ = self._client.make_request(
                "get",
                f"api/v1/records/optimization/{self.id}/trajectory",
                List[int],
            )

        self.fetch_children(["trajectory"])

    @property
    def initial_molecule(self) -> Molecule:
        """The initial molecule, fetched on first access if it was not included."""
        if self.initial_molecule_ is None:
            self._fetch_initial_molecule()
        return self.initial_molecule_

    @property
    def final_molecule(self) -> Optional[Molecule]:
        """The final molecule, fetched on first access; ``None`` if the record has no final molecule id."""
        if self.final_molecule_ is None:
            self._fetch_final_molecule()
        return self.final_molecule_

    @property
    def trajectory(self) -> Optional[List[SinglepointRecord]]:
        """The cached trajectory records; a fetch is attempted whenever the cache is ``None``."""
        if self._trajectory_records is None:
            self._fetch_trajectory()

        return self._trajectory_records

    def trajectory_element(self, trajectory_index: int) -> SinglepointRecord:
        """
        Return a single record of the trajectory.

        If the trajectory records are already cached, the element is taken from
        the cache. Otherwise only the requested record is retrieved (after
        fetching the list of trajectory ids if necessary).

        Raises ``RuntimeError`` if the trajectory ids are still unavailable
        after asking the server.
        """
        if self._trajectory_records is not None:
            return self._trajectory_records[trajectory_index]
        else:
            self._assert_online()

            if self.trajectory_ids_ is None:
                self.trajectory_ids_ = self._client.make_request(
                    "get",
                    f"api/v1/records/optimization/{self.id}/trajectory",
                    List[int],
                )

            if self.trajectory_ids_ is not None:
                traj_id = self.trajectory_ids_[trajectory_index]
                sp_rec = self._get_child_records([traj_id], SinglepointRecord)[0]
                sp_rec.propagate_client(self._client)
                return sp_rec
            else:
                raise RuntimeError(f"Cannot find trajectory for record {self.id}")

    def to_qcschema_result(self) -> OptimizationResult:
        """
        Build a QCSchema ``OptimizationResult`` from this record.

        Raises ``RuntimeError`` if the record's status is not ``complete``.
        """
        if self.status != RecordStatusEnum.complete:
            raise RuntimeError(f"Cannot create QCSchema result from record with status {self.status}")

        extras = deepcopy(self.extras)
        extras["_qcfractal_modified_on"] = self.compute_history[0].modified_on

        # Evaluate the lazy property once: every access while the cache is
        # still None triggers another fetch attempt.
        traj_records = self.trajectory
        if traj_records is not None:
            trajectory = [x.to_qcschema_result() for x in traj_records]
        else:
            trajectory = None

        # TODO - correct?
        # NOTE(review): this TODO comes from the original source; it presumably
        # questions placing the QC program inside the optimization keywords.
        new_keywords = deepcopy(self.specification.keywords)
        new_keywords["program"] = self.specification.qc_specification.program

        return OptimizationResult(
            initial_molecule=self.initial_molecule,
            final_molecule=self.final_molecule,
            trajectory=trajectory,
            energies=self.energies,
            keywords=new_keywords,
            input_specification=QCInputSpecification(
                driver=SinglepointDriver.gradient,  # forced
                model=AtomicResultModel(
                    method=self.specification.qc_specification.method,
                    basis=self.specification.qc_specification.basis,
                ),
                keywords=self.specification.qc_specification.keywords,
            ),
            protocols=self.specification.protocols,
            extras=extras,
            stdout=self.stdout,
            provenance=self.provenance,
            success=True,  # Status has been checked above
        )
# Filters accepted when querying optimization records, in addition to the
# ones declared on RecordQueryFilters.
class OptimizationQueryFilters(RecordQueryFilters):
    program: Optional[List[str]] = None
    qc_program: Optional[List[constr(to_lower=True)]] = None
    qc_method: Optional[List[constr(to_lower=True)]] = None
    qc_basis: Optional[List[Optional[constr(to_lower=True)]]] = None
    initial_molecule_id: Optional[List[int]] = None
    final_molecule_id: Optional[List[int]] = None

    @validator("qc_basis")
    def _convert_basis(cls, v):
        # Replace None entries with the empty string; a missing list stays None.
        # Lowercasing is handled by constr
        if v is None:
            return None

        return [x if x is not None else "" for x in v]


# Request body for adding optimization records: one specification together
# with the initial molecules, each given as a molecule id or a Molecule.
class OptimizationAddBody(RecordAddBodyBase):
    specification: OptimizationSpecification
    initial_molecules: List[Union[int, Molecule]]