Source code for qcportal.metadata_models
from __future__ import annotations
import dataclasses
from typing import List, Optional, Tuple, Dict, Sequence, Any
try:
from pydantic.v1 import validator, root_validator
from pydantic.v1.dataclasses import dataclass
except ImportError:
from pydantic import validator, root_validator
from pydantic.dataclasses import dataclass
[docs]
@dataclass
class InsertMetadata:
    """
    Metadata returned by insertion / adding functions

    The integers stored in ``errors``, ``inserted_idx`` and ``existing_idx``
    are indices in the input/output list. Taken together they must be
    non-negative, must not overlap, and must cover every index from 0 up to
    the largest index present (see :meth:`check_all_indices`).
    """

    # General error text; must be None for `success` to be True
    error_description: Optional[str] = None
    # (index, message) pairs
    errors: List[Tuple[int, str]] = dataclasses.field(default_factory=list)
    inserted_idx: List[int] = dataclasses.field(default_factory=list)  # inserted into the db
    existing_idx: List[int] = dataclasses.field(default_factory=list)  # existing but not updated

    @property
    def n_inserted(self) -> int:
        """Number of entries in ``inserted_idx``"""
        return len(self.inserted_idx)

    @property
    def n_existing(self) -> int:
        """Number of entries in ``existing_idx``"""
        return len(self.existing_idx)

    @property
    def n_errors(self) -> int:
        """Number of entries in ``errors``"""
        return len(self.errors)

    @property
    def error_idx(self) -> List[int]:
        """Indices (first element) of every entry in ``errors``"""
        return [entry[0] for entry in self.errors]

    @property
    def success(self) -> bool:
        """True when there is no ``error_description`` and ``errors`` is empty"""
        return self.error_description is None and not self.errors

    @property
    def error_string(self) -> str:
        """
        The error description (if any) followed by one line per entry in ``errors``
        """
        header = f"{self.error_description}\n" if self.error_description else ""
        details = "\n".join(f" Index {idx}: {msg}" for idx, msg in self.errors)
        return header + details

    @validator("errors", "inserted_idx", "existing_idx", pre=True)
    def sort_fields(cls, v):
        """Store each of the index-bearing fields in sorted order"""
        return sorted(v)

    @root_validator(pre=False, skip_on_failure=True)
    def check_all_indices(cls, values):
        """
        Check that the indices in the three index fields are consistent

        Raises
        ------
        ValueError
            If an index appears in more than one of ``inserted_idx``,
            ``existing_idx`` and ``errors``, if any index is negative, or if
            some index between 0 and the largest index is missing.
        """
        ins_idx = set(values["inserted_idx"])
        existing_idx = set(values["existing_idx"])
        error_idx = {entry[0] for entry in values["errors"]}

        # Pairwise overlap checks, in the same order as the messages below
        pairs = (
            ("inserted_idx", ins_idx, "existing_idx", existing_idx),
            ("inserted_idx", ins_idx, "error_idx", error_idx),
            ("existing_idx", existing_idx, "error_idx", error_idx),
        )
        for left_name, left, right_name, right in pairs:
            intersection = left.intersection(right)
            if intersection:
                raise ValueError(f"{left_name} and {right_name} are not disjoint: intersection={intersection}")

        all_idx = ins_idx | existing_idx | error_idx

        # Skip the rest if we don't have any data
        if not all_idx:
            return values

        # Negative indices can never be part of range(max + 1). Report them
        # explicitly; otherwise the completeness check below would fail with
        # an empty (and therefore misleading) set of "missing" indices.
        negative = sorted(i for i in all_idx if i < 0)
        if negative:
            raise ValueError(f"Indices must be non-negative, but found {negative}")

        # Are all the indices accounted for?
        highest = max(all_idx)
        all_possible = set(range(highest + 1))
        if all_idx != all_possible:
            missing = all_possible - all_idx
            raise ValueError(f"All indices are not accounted for. Max is {highest} and we are missing {missing}")

        return values

    def dict(self) -> Dict[str, Any]:
        """
        Returns the information from this dataclass as a dictionary
        """
        return dataclasses.asdict(self)

    @staticmethod
    def merge(metadata: Sequence[InsertMetadata]) -> InsertMetadata:
        """
        Combine several metadata objects into one

        The indices of each object are shifted by the total number of
        inserted, existing and errored entries of the objects before it.
        Error descriptions that are not None are joined with newlines; the
        merged description is None if none of the inputs has one.
        """
        new_inserted_idx: List[int] = []
        new_existing_idx: List[int] = []
        new_errors: List[Tuple[int, str]] = []
        descriptions: List[str] = []

        base_idx = 0
        for m in metadata:
            new_inserted_idx.extend(i + base_idx for i in m.inserted_idx)
            new_existing_idx.extend(i + base_idx for i in m.existing_idx)
            new_errors.extend((i + base_idx, e) for i, e in m.errors)

            if m.error_description is not None:
                descriptions.append(m.error_description)

            # The next object's indices start after everything seen so far
            base_idx += len(m.inserted_idx) + len(m.existing_idx) + len(m.errors)

        return InsertMetadata(
            inserted_idx=new_inserted_idx,
            existing_idx=new_existing_idx,
            errors=new_errors,
            error_description="\n".join(descriptions) if descriptions else None,
        )
[docs]
@dataclass
class DeleteMetadata:
    """
    Metadata returned by delete functions

    The integers stored in ``errors`` and ``deleted_idx`` are indices in the
    input/output list. Taken together they must be non-negative, must not
    overlap, and must cover every index from 0 up to the largest index
    present (see :meth:`check_all_indices`).
    """

    # General error text; must be None for `success` to be True
    error_description: Optional[str] = None
    # (index, message) pairs
    errors: List[Tuple[int, str]] = dataclasses.field(default_factory=list)
    deleted_idx: List[int] = dataclasses.field(default_factory=list)
    n_children_deleted: int = 0

    @property
    def n_deleted(self) -> int:
        """Number of entries in ``deleted_idx``"""
        return len(self.deleted_idx)

    @property
    def n_errors(self) -> int:
        """Number of entries in ``errors``"""
        return len(self.errors)

    @property
    def error_idx(self) -> List[int]:
        """Indices (first element) of every entry in ``errors``"""
        return [entry[0] for entry in self.errors]

    @property
    def success(self) -> bool:
        """True when there is no ``error_description`` and ``errors`` is empty"""
        return self.error_description is None and not self.errors

    @property
    def error_string(self) -> str:
        """
        The error description (if any) followed by one line per entry in ``errors``
        """
        header = f"{self.error_description}\n" if self.error_description else ""
        details = "\n".join(f" Index {idx}: {msg}" for idx, msg in self.errors)
        return header + details

    @root_validator(pre=False, skip_on_failure=True)
    def check_all_indices(cls, values):
        """
        Check that the indices in ``deleted_idx`` and ``errors`` are consistent

        Raises
        ------
        ValueError
            If an index appears in both ``deleted_idx`` and ``errors``, if any
            index is negative, or if some index between 0 and the largest
            index is missing.
        """
        del_idx = set(values["deleted_idx"])
        error_idx = {entry[0] for entry in values["errors"]}

        intersection = del_idx.intersection(error_idx)
        if intersection:
            raise ValueError(f"deleted_idx and error_idx are not disjoint: intersection={intersection}")

        all_idx = del_idx | error_idx

        # Skip the rest if we don't have any data
        if not all_idx:
            return values

        # Negative indices can never be part of range(max + 1). Report them
        # explicitly; otherwise the completeness check below would fail with
        # an empty (and therefore misleading) set of "missing" indices.
        negative = sorted(i for i in all_idx if i < 0)
        if negative:
            raise ValueError(f"Indices must be non-negative, but found {negative}")

        # Are all the indices accounted for?
        highest = max(all_idx)
        all_possible = set(range(highest + 1))
        if all_idx != all_possible:
            missing = all_possible - all_idx
            raise ValueError(f"All indices are not accounted for. Max is {highest} and we are missing {missing}")

        return values

    def dict(self) -> Dict[str, Any]:
        """
        Returns the information from this dataclass as a dictionary
        """
        return dataclasses.asdict(self)
[docs]
@dataclass
class UpdateMetadata:
    """
    Metadata returned by update functions

    The integers stored in ``errors`` and ``updated_idx`` are indices that,
    taken together, must be non-negative, must not overlap, and must cover
    every index from 0 up to the largest index present
    (see :meth:`check_all_indices`).
    """

    # General error text; must be None for `success` to be True
    error_description: Optional[str] = None
    # (index, message) pairs
    errors: List[Tuple[int, str]] = dataclasses.field(default_factory=list)
    # Indices counted by `n_updated`
    updated_idx: List[int] = dataclasses.field(default_factory=list)
    n_children_updated: int = 0

    @property
    def n_updated(self) -> int:
        """Number of entries in ``updated_idx``"""
        return len(self.updated_idx)

    @property
    def n_errors(self) -> int:
        """Number of entries in ``errors``"""
        return len(self.errors)

    @property
    def error_idx(self) -> List[int]:
        """Indices (first element) of every entry in ``errors``"""
        return [entry[0] for entry in self.errors]

    @property
    def success(self) -> bool:
        """True when there is no ``error_description`` and ``errors`` is empty"""
        return self.error_description is None and not self.errors

    @property
    def error_string(self) -> str:
        """
        The error description (if any) followed by one line per entry in ``errors``
        """
        header = f"{self.error_description}\n" if self.error_description else ""
        details = "\n".join(f" Index {idx}: {msg}" for idx, msg in self.errors)
        return header + details

    @root_validator(pre=False, skip_on_failure=True)
    def check_all_indices(cls, values):
        """
        Check that the indices in ``updated_idx`` and ``errors`` are consistent

        Raises
        ------
        ValueError
            If an index appears in both ``updated_idx`` and ``errors``, if any
            index is negative, or if some index between 0 and the largest
            index is missing.
        """
        upd_idx = set(values["updated_idx"])
        error_idx = {entry[0] for entry in values["errors"]}

        intersection = upd_idx.intersection(error_idx)
        if intersection:
            raise ValueError(f"updated_idx and error_idx are not disjoint: intersection={intersection}")

        all_idx = upd_idx | error_idx

        # Skip the rest if we don't have any data
        if not all_idx:
            return values

        # Negative indices can never be part of range(max + 1). Report them
        # explicitly; otherwise the completeness check below would fail with
        # an empty (and therefore misleading) set of "missing" indices.
        negative = sorted(i for i in all_idx if i < 0)
        if negative:
            raise ValueError(f"Indices must be non-negative, but found {negative}")

        # Are all the indices accounted for?
        highest = max(all_idx)
        all_possible = set(range(highest + 1))
        if all_idx != all_possible:
            missing = all_possible - all_idx
            raise ValueError(f"All indices are not accounted for. Max is {highest} and we are missing {missing}")

        return values

    def dict(self) -> Dict[str, Any]:
        """
        Returns the information from this dataclass as a dictionary
        """
        return dataclasses.asdict(self)
[docs]
@dataclass
class TaskReturnMetadata:
    """
    Metadata returned to managers that have sent completed tasks back to the server

    The integers stored in ``rejected_info`` and ``accepted_ids`` are task ids.
    """

    # General error text; `success` is True exactly when this is None
    error_description: Optional[str] = None
    # (task id, message) pairs
    rejected_info: List[Tuple[int, str]] = dataclasses.field(default_factory=list)
    accepted_ids: List[int] = dataclasses.field(default_factory=list)  # Accepted by the server

    @property
    def n_accepted(self) -> int:
        """Number of entries in ``accepted_ids``"""
        return len(self.accepted_ids)

    @property
    def n_rejected(self) -> int:
        """Number of entries in ``rejected_info``"""
        return len(self.rejected_info)

    @property
    def rejected_ids(self) -> List[int]:
        """Task ids (first element) of every entry in ``rejected_info``"""
        return [entry[0] for entry in self.rejected_info]

    @property
    def success(self) -> bool:
        """True when ``error_description`` is None (``rejected_info`` is not consulted)"""
        return self.error_description is None

    @property
    def error_string(self) -> str:
        """
        The error description (if any) followed by one line per entry in ``rejected_info``
        """
        header = f"{self.error_description}\n" if self.error_description else ""
        details = "\n".join(f" Task id {task_id}: {reason}" for task_id, reason in self.rejected_info)
        return header + details

    @validator("rejected_info", "accepted_ids", pre=True)
    def sort_fields(cls, v):
        """Store ``rejected_info`` and ``accepted_ids`` in sorted order"""
        return sorted(v)

    def dict(self) -> Dict[str, Any]:
        """
        Returns the information from this dataclass as a dictionary
        """
        return dataclasses.asdict(self)