from __future__ import annotations
import os
import sys
from datetime import datetime
from enum import Enum
from typing import Optional, Dict, Any, List, Union, Iterable, Tuple, Type, Sequence, ClassVar, TypeVar
from dateutil.parser import parse as date_parser
try:
from pydantic.v1 import BaseModel, Extra, constr, validator, PrivateAttr, Field, parse_obj_as
except ImportError:
from pydantic import BaseModel, Extra, constr, validator, PrivateAttr, Field, parse_obj_as
from qcelemental.models.results import Provenance
from qcportal.base_models import (
RestModelBase,
QueryModelBase,
QueryIteratorBase,
)
from qcportal.cache import RecordCache, get_records_with_cache
from qcportal.compression import CompressionEnum, decompress, get_compressed_ext
# Unbounded generic type variable.
# NOTE(review): nothing in the visible part of this module references it; confirm whether
# it is still needed (e.g. by importers) before removing it.
_T = TypeVar("_T")
[docs]
class PriorityEnum(int, Enum):
    """
    The priority of a Task. Higher priority will be pulled first.

    Members can be obtained from their integer value (``PriorityEnum(2)``) or,
    through :meth:`_missing_`, from their name in any letter case (``PriorityEnum("HIGH")``).
    """

    high = 2
    normal = 1
    low = 0

    @classmethod
    def _missing_(cls, priority):
        """Attempts to find the correct priority in a case-insensitive way

        ``Enum`` calls this hook when ``priority`` does not match the value of any member.
        If a string is given, it is lowercased and compared against the member names.

        Returns
        -------
        :
            The matching member, or None if there is no match. Returning None makes
            ``Enum`` raise its usual ``ValueError``.
        """

        # Only strings can name a priority. Out-of-range integers, None, floats, etc.
        # are reported as "no match" so the caller sees a ValueError rather than an
        # AttributeError from calling .lower() on a non-string.
        if not isinstance(priority, str):
            return None

        priority = priority.lower()

        # Search this way rather than doing 'in' since we are comparing
        # a string to an enum
        for p in cls:
            if priority == p.name:
                return p

        return None
[docs]
class RecordStatusEnum(str, Enum):
    """
    The state of a record object. The states which are available are a finite set.

    Members can be obtained from their string value in any letter case
    (``RecordStatusEnum("COMPLETE")``) through :meth:`_missing_`.
    """

    # This ordering shouldn't change in the near future, as it conflicts
    # a bit with some migration testing
    complete = "complete"
    invalid = "invalid"
    running = "running"
    error = "error"
    waiting = "waiting"
    cancelled = "cancelled"
    deleted = "deleted"

    @classmethod
    def _missing_(cls, name):
        """Attempts to find the correct status in a case-insensitive way

        ``Enum`` calls this hook when ``name`` does not match the value of any member.
        If a string is given, it is lowercased and compared against the member values.

        Returns
        -------
        :
            The matching member, or None if there is no match. Returning None makes
            ``Enum`` raise its usual ``ValueError``.
        """

        # Non-strings (None, integers, ...) can never match. Report "no match" so the
        # caller gets a ValueError instead of an AttributeError from .lower().
        if not isinstance(name, str):
            return None

        name = name.lower()

        # Search this way rather than doing 'in' since we are comparing
        # a string to an enum
        for status in cls:
            if name == status.value:
                return status

        return None

    @classmethod
    def make_ordered_status(cls, statuses: Iterable[RecordStatusEnum]) -> List[RecordStatusEnum]:
        """Returns a list of the given statuses but in a defined order

        The order is: complete, error, running, waiting, cancelled, invalid, deleted.

        Raises
        ------
        ValueError
            If an element of ``statuses`` is not equal to any status
        """

        order = [cls.complete, cls.error, cls.running, cls.waiting, cls.cancelled, cls.invalid, cls.deleted]
        return sorted(statuses, key=order.index)
class OutputTypeEnum(str, Enum):
    """
    What type of data is stored

    Used as the ``output_type`` of an :class:`OutputStore`; the string value also
    forms part of the URL from which the output data is fetched.
    """

    # Names suggest captured standard output / standard error streams
    stdout = "stdout"
    stderr = "stderr"

    # Error information. NOTE(review): BaseRecord.error is annotated as returning a
    # dictionary for this type -- confirm the exact payload against the server.
    error = "error"
class OutputStore(BaseModel):
    """
    Storage of outputs and error messages, with optional compression

    The raw (possibly compressed) payload lives in ``data_``. If it was not supplied
    when the object was created, it is downloaded through the attached client the
    first time :attr:`data` is accessed.
    """

    class Config:
        extra = Extra.forbid

    output_type: OutputTypeEnum = Field(..., description="The type of output this is (stdout, error, etc)")
    compression_type: CompressionEnum = Field(CompressionEnum.none, description="Compression method (such as lzma)")
    data_: Optional[bytes] = Field(None, alias="data")

    # URL of the raw data on the server and the client used to reach it.
    # Both are filled in by propagate_client()
    _data_url: Optional[str] = PrivateAttr(None)
    _client: Any = PrivateAttr(None)

    def propagate_client(self, client, history_base_url):
        """
        Attach a client and compute the URL this output's data can be fetched from

        Parameters
        ----------
        client
            Client object used for making requests to the server
        history_base_url
            Base URL of the compute history entry that owns this output
        """

        self._client = client
        self._data_url = f"{history_base_url}/outputs/{self.output_type.value}/data"

    def _fetch_raw_data(self):
        """
        Download the raw data from the server, unless it is already present

        Raises
        ------
        RuntimeError
            If there is no client attached, or if the server reports a compression type
            that differs from ``compression_type``
        """

        if self.data_ is not None:
            return

        if self._client is None:
            raise RuntimeError("No client to fetch output data from")

        cdata, ctype = self._client.make_request(
            "get",
            self._data_url,
            Tuple[bytes, CompressionEnum],
        )

        # Explicit check rather than an assert: asserts are stripped with `python -O`,
        # and a mismatch would otherwise only show up later as a confusing decompression failure
        if ctype != self.compression_type:
            raise RuntimeError(
                f"Compression type of fetched output data ({ctype}) does not match "
                f"the expected compression type ({self.compression_type})"
            )

        self.data_ = cdata

    @property
    def data(self) -> Any:
        """The decompressed output, fetching the raw data from the server first if needed"""

        self._fetch_raw_data()
        return decompress(self.data_, self.compression_type)
class ComputeHistory(BaseModel):
    """
    One entry in the compute history of a record

    The outputs attached to the entry are optional; when they were not included,
    they are requested from the server the first time :attr:`outputs` is accessed.
    """

    class Config:
        extra = Extra.forbid

    id: int
    record_id: int
    status: RecordStatusEnum
    manager_name: Optional[str]
    modified_on: datetime
    provenance: Optional[Provenance]

    outputs_: Optional[Dict[str, OutputStore]] = Field(None, alias="outputs")

    # Client and URL of this history entry. Filled in by propagate_client()
    _client: Any = PrivateAttr(None)
    _base_url: Optional[str] = PrivateAttr(None)

    def propagate_client(self, client, record_base_url):
        """
        Attach a client to this entry and to every output it already holds

        Parameters
        ----------
        client
            Client object used for making requests to the server
        record_base_url
            Base URL of the record that owns this history entry
        """

        self._client = client
        self._base_url = f"{record_base_url}/compute_history/{self.id}"

        for store in (self.outputs_ or {}).values():
            store.propagate_client(self._client, self._base_url)

    def fetch_all(self):
        """Fetch all outputs (including their raw data) from the server"""

        self._fetch_outputs()

    def _fetch_outputs(self):
        """
        Replace ``outputs_`` with the outputs obtained from the server

        Every fetched output gets the client attached and its raw data downloaded
        (if the response did not already contain it).

        Raises
        ------
        RuntimeError
            If no client is attached
        """

        if self._client is None:
            raise RuntimeError("This compute history is not connected to a client")

        outputs_url = f"{self._base_url}/outputs"
        self.outputs_ = self._client.make_request("get", outputs_url, Dict[str, OutputStore])

        for store in self.outputs_.values():
            store.propagate_client(self._client, self._base_url)
            store._fetch_raw_data()

    @property
    def outputs(self) -> Dict[str, OutputStore]:
        """All outputs of this entry, keyed by output type. Fetched on first access if missing"""

        if self.outputs_ is None:
            self._fetch_outputs()
        return self.outputs_

    def get_output(self, output_type: OutputTypeEnum) -> Any:
        """
        Obtain the decompressed data of a single output

        Returns
        -------
        :
            The output data, or None if this entry has no outputs or none of the given type
        """

        available = self.outputs
        if not available:
            return None

        store = available.get(output_type, None)
        return None if store is None else store.data

    @property
    def stdout(self) -> Any:
        """Data of the "stdout" output, or None if there is none"""
        return self.get_output("stdout")

    @property
    def stderr(self) -> Any:
        """Data of the "stderr" output, or None if there is none"""
        return self.get_output("stderr")

    @property
    def error(self) -> Any:
        """Data of the "error" output, or None if there is none"""
        return self.get_output("error")
class NativeFile(BaseModel):
    """
    Storage of native files, with compression

    The raw (compressed) payload lives in ``data_``. If it was not supplied when the
    object was created, it is downloaded through the attached client when needed.
    """

    class Config:
        extra = Extra.forbid

    name: str = Field(..., description="Name of the file")
    compression_type: CompressionEnum = Field(..., description="Compression method (such as lzma)")
    data_: Optional[bytes] = Field(None, alias="data")

    # URL of the raw data on the server and the client used to reach it.
    # Both are filled in by propagate_client()
    _data_url: Optional[str] = PrivateAttr(None)
    _client: Any = PrivateAttr(None)

    def propagate_client(self, client, record_base_url):
        """
        Attach a client and compute the URL this file's data can be fetched from

        Parameters
        ----------
        client
            Client object used for making requests to the server
        record_base_url
            Base URL of the record that owns this file
        """

        self._client = client
        self._data_url = f"{record_base_url}/native_files/{self.name}/data"

    def fetch_all(self):
        """Fetch the raw file data from the server (if not already present)"""

        self._fetch_raw_data()

    def _fetch_raw_data(self):
        """
        Download the raw data from the server, unless it is already present

        Raises
        ------
        RuntimeError
            If there is no client attached, or if the server reports a compression type
            that differs from ``compression_type``
        """

        if self.data_ is not None:
            return

        if self._client is None:
            raise RuntimeError("No client to fetch native file data from")

        cdata, ctype = self._client.make_request(
            "get",
            self._data_url,
            Tuple[bytes, CompressionEnum],
        )

        # Explicit check rather than an assert: asserts are stripped with `python -O`,
        # and a mismatch would otherwise only show up later as a confusing decompression failure
        if ctype != self.compression_type:
            raise RuntimeError(
                f"Compression type of fetched native file data ({ctype}) does not match "
                f"the expected compression type ({self.compression_type})"
            )

        self.data_ = cdata

    @property
    def data(self) -> Any:
        """The decompressed file contents, fetching the raw data from the server first if needed"""

        self._fetch_raw_data()
        return decompress(self.data_, self.compression_type)

    def save_file(
        self, directory: str, new_name: Optional[str] = None, keep_compressed: bool = False, overwrite: bool = False
    ):
        """
        Saves the file to the given directory

        Parameters
        ----------
        directory
            Directory to write the file into
        new_name
            File name to use instead of the stored name
        keep_compressed
            If True, write the raw compressed bytes and append the extension
            for the compression type to the file name
        overwrite
            If True, replace an existing file at the destination

        Raises
        ------
        RuntimeError
            If the destination exists and `overwrite` is False, or if the decompressed
            data is neither a string nor bytes
        """

        name = self.name if new_name is None else new_name

        if keep_compressed:
            name += get_compressed_ext(self.compression_type)

        full_path = os.path.join(directory, name)
        if os.path.exists(full_path) and not overwrite:
            raise RuntimeError(f"File {full_path} already exists. Not overwriting")

        if keep_compressed:
            # Write the raw bytes as stored. The `data` property would decompress them,
            # which is not what a file carrying the compressed extension should contain
            self._fetch_raw_data()
            with open(full_path, "wb") as f:
                f.write(self.data_)
            return

        # Decompress only once and reuse the result
        d = self.data

        # TODO - streaming decompression?
        if isinstance(d, str):
            with open(full_path, "wt") as f:
                f.write(d)
        elif isinstance(d, bytes):
            with open(full_path, "wb") as f:
                f.write(d)
        else:
            raise RuntimeError(f"Cannot write data of type {type(d)} to a file")
class RecordInfoBackup(BaseModel):
    """
    Saved copy of a record's previous status, tag and priority

    NOTE(review): going by the field names, this preserves the state a record had
    before a change so it can be restored later; where these are created and
    consumed is not visible in this module -- confirm.
    """

    class Config:
        # Unknown fields are rejected
        extra = Extra.forbid

    old_status: RecordStatusEnum
    old_tag: Optional[str]
    old_priority: Optional[PriorityEnum]
    modified_on: datetime
class RecordComment(BaseModel):
    """
    A comment attached to a record

    Instances are what :attr:`BaseRecord.comments` returns.
    """

    class Config:
        # Unknown fields are rejected
        extra = Extra.forbid

    id: int
    record_id: int  # ID of the record (presumably the one the comment belongs to)
    username: Optional[str]  # presumably the comment's author; may be missing
    timestamp: datetime
    comment: str  # The text of the comment
[docs]
class RecordTask(BaseModel):
    """
    Task information associated with a record

    The keyword arguments of the task function are stored compressed; use
    :attr:`function_kwargs` to obtain them in decompressed form.
    """

    class Config:
        extra = Extra.forbid

    id: int
    record_id: int

    function: Optional[str]
    function_kwargs_compressed: Optional[bytes]

    tag: str
    priority: PriorityEnum
    required_programs: List[str]

    @property
    def function_kwargs(self) -> Optional[Dict[str, Any]]:
        """The function keyword arguments (zstd-decompressed), or None if none are stored"""

        compressed = self.function_kwargs_compressed
        if compressed is not None:
            return decompress(compressed, CompressionEnum.zstd)
        return None
class ServiceDependency(BaseModel):
    """
    A single dependency of a service

    Instances make up the ``dependencies`` list of a :class:`RecordService`.
    """

    class Config:
        # Unknown fields are rejected
        extra = Extra.forbid

    record_id: int  # presumably the ID of the record the service depends on -- confirm
    extras: Dict[str, Any]  # Free-form additional data; contents are not defined here
[docs]
class RecordService(BaseModel):
    """
    Service information associated with a record

    Instances are what :attr:`BaseRecord.service` returns for records that are services.
    """

    class Config:
        # Unknown fields are rejected
        extra = Extra.forbid

    id: int
    record_id: int

    tag: str
    priority: PriorityEnum
    find_existing: bool

    # Free-form state of the service; optional and None by default
    service_state: Optional[Dict[str, Any]] = None

    dependencies: List[ServiceDependency]
[docs]
class BaseRecord(BaseModel):
    """
    Base class for all record types

    Holds the fields common to every record. Some fields (compute history, task, service,
    comments, native files) are not always included when a record is obtained; they are
    requested from the server through the attached client the first time the corresponding
    property is accessed.

    Derived classes are registered under the default value of their ``record_type`` field
    (see ``__init_subclass__``) and can be looked up with :meth:`get_subclass`.
    """

    class Config:
        extra = Extra.forbid
        allow_mutation = True
        validate_assignment = True

    id: int
    record_type: str
    is_service: bool

    properties: Optional[Dict[str, Any]]
    extras: Dict[str, Any] = Field({})

    status: RecordStatusEnum
    manager_name: Optional[str]

    created_on: datetime
    modified_on: datetime

    owner_user: Optional[str]
    owner_group: Optional[str]

    ######################################################
    # Fields not always included when fetching the record
    ######################################################
    compute_history_: Optional[List[ComputeHistory]] = Field(None, alias="compute_history")
    task_: Optional[RecordTask] = Field(None, alias="task")
    service_: Optional[RecordService] = Field(None, alias="service")
    comments_: Optional[List[RecordComment]] = Field(None, alias="comments")
    native_files_: Optional[Dict[str, NativeFile]] = Field(None, alias="native_files")

    # Private non-pydantic fields
    _client: Any = PrivateAttr(None)
    _base_url: str = PrivateAttr(None)

    # A dictionary of all subclasses (calculation types) to actual class type
    _all_subclasses: ClassVar[Dict[str, Type[BaseRecord]]] = {}

    # Local record cache we can use for child records
    # This record may also be part of the cache
    _record_cache: Optional[RecordCache] = PrivateAttr(None)
    _cache_dirty: bool = PrivateAttr(False)

    def __init__(self, client=None, **kwargs):
        """
        Construct a record from its fields and attach a client

        Parameters
        ----------
        client
            Client used for requesting missing data from the server. If None, the record is offline.
        **kwargs
            Field values, passed on to the pydantic constructor
        """

        BaseModel.__init__(self, **kwargs)

        # Calls derived class propagate_client
        # which should filter down to the ones in this (BaseRecord) class
        self.propagate_client(client)

        # Catches derived classes that override propagate_client without calling this class' version
        assert self._client is client, "Client not set in base record class?"

    @validator("extras", pre=True)
    def _validate_extras(cls, v):
        # For backwards compatibility. Older servers may have 'None' as the extras
        if v is None:
            return {}
        return v

    def __init_subclass__(cls):
        """
        Register derived classes for later use
        """

        # Get the record type. This is kind of ugly, but works.
        # We could use ClassVar, but in my tests it doesn't work for
        # disambiguating (ie, via parse_obj_as)
        record_type = cls.__fields__["record_type"].default

        cls._all_subclasses[record_type] = cls

    def __del__(self):
        """
        Write this record back to its cache (if it has unsaved changes) before it is destroyed
        """

        # Sometimes this won't exist if there is an exception during construction
        # TODO - we check sys.meta_path. Pydantic attempts an import of something, which is
        #        not good if the interpreter is shutting down. This is a hack to avoid that.
        #        Pydantic v2 may fix this
        if (
            hasattr(self, "_record_cache")
            and self._record_cache is not None
            and not self._record_cache.read_only
            and self._cache_dirty
            and sys.meta_path is not None
        ):
            self.sync_to_cache(True)  # Don't really *have* to detach, but why not

        # Chain to a base class finalizer if there is one. The attribute obtained through
        # super() is already bound to this instance, so it must be called without arguments
        # (passing self again would raise a TypeError).
        parent_del = getattr(super(), "__del__", None)
        if parent_del is not None:
            parent_del()

    @classmethod
    def get_subclass(cls, record_type: str) -> Type[BaseRecord]:
        """
        Obtain a subclass of this class given its record_type

        Raises
        ------
        RuntimeError
            If no subclass is registered for the given record type
        """

        subcls = cls._all_subclasses.get(record_type)
        if subcls is None:
            raise RuntimeError(f"Cannot find subclass for record type {record_type}")
        return subcls

    @classmethod
    def _fetch_children_multi(
        cls, client, record_cache, records: Iterable[BaseRecord], include: Iterable[str], force_fetch: bool = False
    ):
        """
        Fetches all children of the given records recursively

        This tries to work efficiently, fetching larger batches of children
        that can span multiple records

        Meant to be overridden by derived classes. This base implementation does nothing.
        """

        pass

    @classmethod
    def fetch_children_multi(
        cls, records: Iterable[Optional[BaseRecord]], include: Optional[Iterable[str]] = None, force_fetch: bool = False
    ):
        """
        Fetches all children of the given records

        This tries to work efficiently, fetching larger batches of children
        that can span multiple records

        Parameters
        ----------
        records
            Records to fetch the children of. Entries that are None are skipped.
        include
            Passed on to the derived class implementation. None is treated as an empty list.
        force_fetch
            Passed on to the derived class implementation

        Raises
        ------
        RuntimeError
            If the records are not all instances of the first record's type, or do not
            all share the same client and the same record cache
        """

        # Remove any None records
        # can happen if missing_ok=True in some function calls
        records = [r for r in records if r is not None]

        if not records:
            return

        # Get the first record (for the client and other info)
        template_record = next(iter(records))

        if not all(isinstance(r, type(template_record)) for r in records):
            raise RuntimeError("Fetching children of records with different types is not supported.")
        if not all(r._client is template_record._client for r in records):
            raise RuntimeError("Fetching children of records with different clients is not supported.")
        if not all(r._record_cache is template_record._record_cache for r in records):
            raise RuntimeError("Fetching children of records with different record caches is not supported.")

        # Call the derived class function
        if include is None:
            include = []

        cls._fetch_children_multi(
            template_record._client, template_record._record_cache, records, include=include, force_fetch=force_fetch
        )

    def fetch_children(self, include: Optional[Iterable[str]] = None, force_fetch: bool = False):
        """
        Fetches all children of this record recursively
        """

        self.fetch_children_multi([self], include, force_fetch)

    def sync_to_cache(self, detach: bool = False):
        """
        Syncs this record to the cache

        Does nothing if the record has no cache or the cache is read-only.

        If `detach` is True, then the record will be removed from the cache
        """

        if self._record_cache is None:
            return

        if self._record_cache.read_only:
            return

        self._record_cache.writeback_record(self)
        self._cache_dirty = False

        if detach:
            self._record_cache = None

    def __str__(self) -> str:
        return f"<{self.__class__.__name__} id={self.id} status={self.status}>"

    def propagate_client(self, client):
        """
        Propagates a client and related information to this record to any fields within this record that need it

        This is expected to be called from derived class propagate_client functions as well
        """

        self._client = client
        self._base_url = f"api/v1/records/{self.record_type}/{self.id}"

        if self.compute_history_ is not None:
            for ch in self.compute_history_:
                ch.propagate_client(self._client, self._base_url)

        if self.native_files_ is not None:
            for nf in self.native_files_.values():
                nf.propagate_client(self._client, self._base_url)

    def _get_child_records(
        self,
        child_record_ids: Sequence[int],
        child_record_type: Type[_Record_T],
        include: Optional[Iterable[str]] = None,
    ) -> List[_Record_T]:
        """
        Helper function for obtaining child records either from the cache or from the server

        The records are returned in the same order as the `child_record_ids` parameter.

        If `include` is specified, additional fields will be fetched from the server. However, if the records are in the
        cache already, they may be missing those fields.
        """

        return get_records_with_cache(
            self._client, self._record_cache, child_record_type, child_record_ids, include, force_fetch=False
        )

    def _assert_online(self):
        """Raises an exception if this record does not have an associated client"""
        if self.offline:
            raise RuntimeError("Record is not connected to a client")

    def _fetch_compute_history(self):
        """Fetch the compute history from the server and attach the client to the new entries"""

        self._assert_online()

        self.compute_history_ = self._client.make_request(
            "get", f"{self._base_url}/compute_history", List[ComputeHistory]
        )

        self.propagate_client(self._client)

    def _fetch_task(self):
        """Fetch the task from the server. A service record has no task, so it is set to None"""

        if self.is_service:
            self.task_ = None
        else:
            # Only needed when a request is actually made
            self._assert_online()
            self.task_ = self._client.make_request("get", f"{self._base_url}/task", Optional[RecordTask])

    def _fetch_service(self):
        """Fetch the service from the server. A non-service record has no service, so it is set to None"""

        if not self.is_service:
            self.service_ = None
        else:
            # Only needed when a request is actually made
            self._assert_online()
            self.service_ = self._client.make_request("get", f"{self._base_url}/service", Optional[RecordService])

    def _fetch_comments(self):
        """Fetch the comments from the server"""

        self._assert_online()

        self.comments_ = self._client.make_request("get", f"{self._base_url}/comments", List[RecordComment])

    def _fetch_native_files(self):
        """Fetch the native files from the server and attach the client to them"""

        self._assert_online()

        self.native_files_ = self._client.make_request("get", f"{self._base_url}/native_files", Dict[str, NativeFile])

        self.propagate_client(self._client)

    def _get_output(self, output_type: OutputTypeEnum) -> Optional[Union[str, Dict[str, Any]]]:
        """Obtain an output of the last compute history entry, or None if there is no history"""

        history = self.compute_history
        if not history:
            return None

        return history[-1].get_output(output_type)

    @property
    def offline(self) -> bool:
        """True if this record has no client attached"""
        return self._client is None

    @property
    def children_status(self) -> Dict[RecordStatusEnum, int]:
        """Returns a dictionary of the status of all children of this record"""

        self._assert_online()

        return self._client.make_request(
            "get",
            f"{self._base_url}/children_status",
            Dict[RecordStatusEnum, int],
        )

    @property
    def children_errors(self) -> List[BaseRecord]:
        """Returns errored child records"""

        self._assert_online()

        error_ids = self._client.make_request(
            "get",
            f"{self._base_url}/children_errors",
            List[int],
        )

        return self._client._get_records_by_type(None, error_ids)

    @property
    def compute_history(self) -> List[ComputeHistory]:
        """The compute history of this record. Fetched from the server on first access if missing"""

        if self.compute_history_ is None:
            self._fetch_compute_history()

        return self.compute_history_

    @property
    def task(self) -> Optional[RecordTask]:
        """The task of this record (may be None). Fetched from the server if it was never set"""

        # task_ may be None because it either hasn't been fetched or it doesn't exist.
        # Fetch only if it has never been set: pydantic adds a field to __fields_set__
        # once it has been assigned, even if the assigned value is None
        if self.task_ is None and "task_" not in self.__fields_set__:
            self._fetch_task()

        return self.task_

    @property
    def service(self) -> Optional[RecordService]:
        """The service of this record (may be None). Fetched from the server if it was never set"""

        # service_ may be None because it either hasn't been fetched or it doesn't exist.
        # Fetch only if it has never been set: pydantic adds a field to __fields_set__
        # once it has been assigned, even if the assigned value is None
        if self.service_ is None and "service_" not in self.__fields_set__:
            self._fetch_service()

        return self.service_

    def get_waiting_reason(self) -> Dict[str, Any]:
        """
        Request the reason this record is waiting from the server

        Raises
        ------
        RuntimeError
            If the record is not connected to a client
        """

        self._assert_online()

        return self._client.make_request("get", f"api/v1/records/{self.id}/waiting_reason", Dict[str, Any])

    @property
    def comments(self) -> Optional[List[RecordComment]]:
        """The comments of this record. Fetched from the server on first access if missing"""

        if self.comments_ is None:
            self._fetch_comments()

        return self.comments_

    @property
    def native_files(self) -> Optional[Dict[str, NativeFile]]:
        """The native files of this record, keyed by name. Fetched from the server on first access if missing"""

        if self.native_files_ is None:
            self._fetch_native_files()

        return self.native_files_

    @property
    def stdout(self) -> Optional[str]:
        """The stdout output of the last compute history entry, if any"""
        return self._get_output(OutputTypeEnum.stdout)

    @property
    def stderr(self) -> Optional[str]:
        """The stderr output of the last compute history entry, if any"""
        return self._get_output(OutputTypeEnum.stderr)

    @property
    def error(self) -> Optional[Dict[str, Any]]:
        """The error output of the last compute history entry, if any"""
        return self._get_output(OutputTypeEnum.error)

    @property
    def provenance(self) -> Optional[Provenance]:
        """The provenance of the last compute history entry, or None if there is no history"""

        history = self.compute_history
        if not history:
            return None

        return history[-1].provenance
# pydantic call that resolves string/forward-reference annotations on the model's fields.
# NOTE(review): none of ServiceDependency's visible fields refers to a class defined later,
# so this looks like a leftover or a safety measure -- confirm before removing.
ServiceDependency.update_forward_refs()

# Type variable for "some concrete record class". Used in the annotations of
# BaseRecord._get_child_records and RecordQueryIterator.
_Record_T = TypeVar("_Record_T", bound=BaseRecord)
class RecordAddBodyBase(RestModelBase):
    """
    Common request body fields for adding records

    NOTE(review): judging by the name, this is meant as a base for the bodies of
    type-specific "add record" requests; the derived classes are not in this module.
    """

    tag: constr(to_lower=True)  # Converted to lowercase during validation
    priority: PriorityEnum
    owner_group: Optional[str]
    find_existing: bool = True
class RecordModifyBody(RestModelBase):
    """
    Request body for modifying records

    All fields other than ``record_ids`` are optional and default to None.
    NOTE(review): presumably None means "leave unchanged" -- confirm against the server.
    """

    record_ids: List[int]
    status: Optional[RecordStatusEnum] = None
    priority: Optional[PriorityEnum] = None
    tag: Optional[str] = None
    comment: Optional[str] = None
class RecordDeleteBody(RestModelBase):
    """
    Request body for deleting records

    NOTE(review): the exact meaning of the flags is defined by the server; going by
    their names, ``soft_delete`` selects a reversible deletion and ``delete_children``
    extends the deletion to child records -- confirm.
    """

    record_ids: List[int]
    soft_delete: bool
    delete_children: bool
class RecordRevertBody(RestModelBase):
    """
    Request body for reverting records

    NOTE(review): presumably ``revert_status`` names the status change to undo for the
    given records -- confirm against the server implementation.
    """

    revert_status: RecordStatusEnum
    record_ids: List[int]
class RecordQueryFilters(QueryModelBase):
    """
    Filters that can be applied when querying for records

    Every filter is optional and defaults to None. The four date filters accept
    either datetime objects or strings; strings are parsed with dateutil.
    """

    record_id: Optional[List[int]] = None
    record_type: Optional[List[str]] = None
    manager_name: Optional[List[str]] = None
    status: Optional[List[RecordStatusEnum]] = None
    dataset_id: Optional[List[int]] = None
    parent_id: Optional[List[int]] = None
    child_id: Optional[List[int]] = None
    created_before: Optional[datetime] = None
    created_after: Optional[datetime] = None
    modified_before: Optional[datetime] = None
    modified_after: Optional[datetime] = None
    owner_user: Optional[List[Union[int, str]]] = None
    owner_group: Optional[List[Union[int, str]]] = None

    @validator("created_before", "created_after", "modified_before", "modified_after", pre=True)
    def parse_dates(cls, v):
        # Runs before pydantic's own validation: turn date strings into datetime
        # objects and hand anything else through untouched
        return date_parser(v) if isinstance(v, str) else v
[docs]
class RecordQueryIterator(QueryIteratorBase[_Record_T]):
    """
    Iterator for all types of record queries

    This iterator transparently handles batching and pagination over the results
    of a record query, and works with all kinds of records.
    """

    def __init__(
        self,
        client,
        query_filters: RecordQueryFilters,
        record_type: Type[_Record_T],
        include: Optional[Iterable[str]] = None,
    ):
        """
        Construct an iterator

        Parameters
        ----------
        client
            QCPortal client object used to contact/retrieve data from the server
        query_filters
            The actual query information to send to the server
        record_type
            What type of record we are querying for
        include
            Passed on to the client when the queried records are retrieved
        """

        # Each batch is a quarter of the server's limit for getting records
        records_per_batch = client.api_limits["get_records"] // 4

        self.record_type = record_type
        self.include = include

        QueryIteratorBase.__init__(self, client, query_filters, records_per_batch)

    def _request(self) -> List[_Record_T]:
        """Run the query for record IDs, then retrieve the corresponding records"""

        if self.record_type is None:
            # No specific type - use the general query endpoint
            query_url = "api/v1/records/query"
        else:
            # Get the record type string. This is kind of ugly, but works.
            type_name = self.record_type.__fields__["record_type"].default
            query_url = f"api/v1/records/{type_name}/query"

        record_ids = self._client.make_request(
            "post",
            query_url,
            List[int],
            body=self._query_filters,
        )

        return self._client._get_records_by_type(self.record_type, record_ids, include=self.include)
def record_from_dict(data: Dict[str, Any], client: Any = None) -> BaseRecord:
    """
    Build a record object out of a dictionary of record fields

    The ``record_type`` entry of the dictionary selects which class deriving
    from BaseRecord gets instantiated.

    Parameters
    ----------
    data
        Fields of the record. Must contain ``record_type``.
    client
        Client to attach to the new record (optional)
    """

    type_name = data["record_type"]
    record_cls = BaseRecord.get_subclass(type_name)
    return record_cls(client=client, **data)
def records_from_dicts(
    data: Sequence[Optional[Dict[str, Any]]],
    client: Any = None,
) -> List[Optional[BaseRecord]]:
    """
    Build record objects out of a sequence of dictionaries of record fields

    Each dictionary is converted with record_from_dict. Entries that are None
    stay None in the result, so the output lines up with the input.
    """

    return [None if fields is None else record_from_dict(fields, client) for fields in data]