from __future__ import annotations
import math
import os
from datetime import datetime
from typing import (
TYPE_CHECKING,
Optional,
Dict,
Any,
List,
Iterable,
Type,
Tuple,
Union,
Callable,
ClassVar,
Sequence,
Mapping,
)
try:
import pydantic.v1 as pydantic
from pydantic.v1 import BaseModel, Extra, validator, PrivateAttr, Field
except ImportError:
import pydantic
from pydantic import BaseModel, Extra, validator, PrivateAttr, Field
from qcelemental.models.types import Array
from tabulate import tabulate
from tqdm import tqdm
from qcportal.base_models import RestModelBase, validate_list_to_single, CommonBulkGetBody
from qcportal.metadata_models import DeleteMetadata
from qcportal.metadata_models import InsertMetadata
from qcportal.record_models import PriorityEnum, RecordStatusEnum, BaseRecord
from qcportal.utils import make_list, chunk_iterable
from qcportal.cache import DatasetCache, read_dataset_metadata, get_records_with_cache
if TYPE_CHECKING:
from qcportal.client import PortalClient
from pandas import DataFrame
class Citation(BaseModel):
"""A literature citation."""
class Config:
extra = Extra.forbid
allow_mutation = False
acs_citation: Optional[str] = None # hand-formatted citation in ACS style
bibtex: Optional[str] = None # bibtex blob for later use with bibtex-renderer
doi: Optional[str] = None
url: Optional[str] = None
def to_acs(self) -> str:
"""Returns an ACS-formatted citation"""
return self.acs_citation
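# A minimal usage sketch of the Citation model above (all field values hypothetical):
#
#     cite = Citation(
#         doi="10.0000/example-doi",
#         acs_citation="Doe, J.; Roe, R. Example Paper. J. Ex. Chem. 2020, 1, 1-10.",
#     )
#     print(cite.to_acs())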
class ContributedValues(BaseModel):
class Config:
extra = Extra.forbid
allow_mutation = False
name: str
values: Any
index: Array[str]
values_structure: Dict[str, Any] = {}
theory_level: Union[str, Dict[str, str]]
units: str
theory_level_details: Optional[Union[str, Dict[str, Optional[str]]]] = None
citations: Optional[List[Citation]] = None
external_url: Optional[str] = None
doi: Optional[str] = None
comments: Optional[str] = None
class BaseDataset(BaseModel):
class Config:
extra = Extra.forbid
allow_mutation = True
validate_assignment = True
id: int
dataset_type: str
name: str
description: str
tagline: str
tags: List[str]
group: str
visibility: bool
provenance: Dict[str, Any]
default_tag: str
default_priority: PriorityEnum
owner_user: Optional[str]
owner_group: Optional[str]
metadata: Dict[str, Any]
extras: Dict[str, Any]
########################################
# Caches of information
########################################
_entry_names: List[str] = PrivateAttr([])
_specification_names: List[str] = PrivateAttr([])
# All local cache data. May be backed by memory or disk
_cache_data: DatasetCache = PrivateAttr()
# Values computed outside QCA
_contributed_values: Optional[Dict[str, ContributedValues]] = PrivateAttr(None)
#############################
# Private non-pydantic fields
#############################
_client: Any = PrivateAttr(None)
# To be overridden by the derived classes
_entry_type: ClassVar[Optional[Type]] = None
_new_entry_type: ClassVar[Optional[Type]] = None
_specification_type: ClassVar[Optional[Type]] = None
_record_item_type: ClassVar[Optional[Type]] = None
_record_type: ClassVar[Optional[Type]] = None
# A dictionary of all subclasses (dataset types) to the actual class type
_all_subclasses: ClassVar[Dict[str, Type[BaseDataset]]] = {}
# Some dataset options
auto_fetch_missing: bool = True # Automatically fetch missing records from the server
def __init__(self, client: Optional[PortalClient] = None, cache_data: Optional[DatasetCache] = None, **kwargs):
BaseModel.__init__(self, **kwargs)
# Calls derived class propagate_client
# which should filter down to the ones in this (BaseDataset) class
self.propagate_client(client)
assert self._client is client, "Client not set in base dataset class?"
if cache_data is not None:
# Passed in cache data. That takes priority
self._cache_data = cache_data
elif self._client:
# Ask the client cache for our cache
self._cache_data = client.cache.get_dataset_cache(self.id, type(self))
else:
            # Memory-backed cache, not shared
# TODO - share? Use class id as a key? Would allow for threading
self._cache_data = DatasetCache("file:/?mode=memory", False, type(self))
if not self._cache_data.read_only:
# Add metadata to cache file (in case the user wants to share it)
self._cache_data.update_metadata("dataset_metadata", self)
# Only address, not username/password
self._cache_data.update_metadata("client_address", self._client.address)
def __init_subclass__(cls):
"""
Register derived classes for later use
"""
# Get the dataset type. This is kind of ugly, but works.
# We could use ClassVar, but in my tests it doesn't work for
# disambiguating (ie, via parse_obj_as)
dataset_type = cls.__fields__["dataset_type"].default
cls._all_subclasses[dataset_type] = cls
@classmethod
def get_subclass(cls, dataset_type: str):
subcls = cls._all_subclasses.get(dataset_type)
if subcls is None:
raise RuntimeError(f"Cannot find subclass for record type {dataset_type}")
return subcls
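    # Illustrative sketch (dataset type string hypothetical): __init_subclass__
    # registers each derived dataset class, so a dataset_type string can be
    # resolved to the concrete class without a hand-maintained mapping:
    #
    #     cls = BaseDataset.get_subclass("singlepoint")
    #     ds = cls(client=client, **raw_dataset_dict)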
def propagate_client(self, client):
"""
        Propagates a client to this dataset, and to any fields within this dataset that need it

        This may also be called from derived classes' propagate_client functions
"""
self._client = client
def _add_entries(self, entries: Union[BaseModel, Sequence[BaseModel]]) -> InsertMetadata:
"""
Internal function for adding entries to a dataset
This function handles batching and some type checking
Parameters
----------
entries
Entries to add. May be just a single entry or a sequence of entries
"""
entries = make_list(entries)
if len(entries) == 0:
return InsertMetadata()
assert all(isinstance(x, self._new_entry_type) for x in entries), "Incorrect entry type"
uri = f"api/v1/datasets/{self.dataset_type}/{self.id}/entries/bulkCreate"
batch_size: int = math.ceil(self._client.api_limits["get_dataset_entries"] / 4)
n_batches = math.ceil(len(entries) / batch_size)
all_meta: List[InsertMetadata] = []
for entry_batch in tqdm(chunk_iterable(entries, batch_size), total=n_batches, disable=None):
meta = self._client.make_request("post", uri, InsertMetadata, body=entry_batch)
            # Fetch the newly-added entries, updating the cache and the entry name list
            # (still ok if entry names have not been fetched yet - they will be fetched on first access)
added_names = [x.name for x in entry_batch]
all_meta.append(meta)
self._internal_fetch_entries(added_names)
return InsertMetadata.merge(all_meta)
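    # The batching above assumes chunk_iterable yields successive chunks of at
    # most batch_size items; a sketch of that behavior:
    #
    #     >>> list(chunk_iterable([1, 2, 3, 4, 5], 2))
    #     [[1, 2], [3, 4], [5]]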
def _add_specifications(self, specifications: Union[BaseModel, Sequence[BaseModel]]) -> InsertMetadata:
"""
Internal function for adding specifications to a dataset
Parameters
----------
specifications
            Specifications to add. May be just a single specification or a sequence of specifications
"""
specifications = make_list(specifications)
if len(specifications) == 0:
return InsertMetadata()
assert all(isinstance(x, self._specification_type) for x in specifications), "Incorrect specification type"
uri = f"api/v1/datasets/{self.dataset_type}/{self.id}/specifications"
ret = self._client.make_request("post", uri, InsertMetadata, body=specifications)
added_names = [x.name for x in specifications]
self._internal_fetch_specifications(added_names)
return ret
def _update_metadata(self, **kwargs):
self.assert_online()
new_body = {
"name": self.name,
"description": self.description,
"tagline": self.tagline,
"tags": self.tags,
"group": self.group,
"visibility": self.visibility,
"provenance": self.provenance,
"default_tag": self.default_tag,
"default_priority": self.default_priority,
"metadata": self.metadata,
}
new_body.update(**kwargs)
body = DatasetModifyMetadata(**new_body)
self._client.make_request("patch", f"api/v1/datasets/{self.dataset_type}/{self.id}", None, body=body)
self.name = body.name
self.description = body.description
self.tagline = body.tagline
self.tags = body.tags
self.group = body.group
self.visibility = body.visibility
self.provenance = body.provenance
self.default_tag = body.default_tag
self.default_priority = body.default_priority
self.metadata = body.metadata
self._cache_data.update_metadata("dataset_metadata", self)
def submit(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
tag: Optional[str] = None,
        priority: Optional[PriorityEnum] = None,
find_existing: bool = True,
):
self.assert_is_not_view()
self.assert_online()
entry_names = make_list(entry_names)
specification_names = make_list(specification_names)
# Do automatic batching here
# (will be removed when we move to async)
if entry_names is None:
entry_names = self.entry_names
if specification_names is None:
specification_names = self.specification_names
batch_size = math.ceil(self._client.api_limits["get_records"] / 4)
n_batches = math.ceil(len(entry_names) / batch_size)
for spec in specification_names:
for entry_batch in tqdm(chunk_iterable(entry_names, batch_size), total=n_batches, disable=None):
body_data = DatasetSubmitBody(
entry_names=entry_batch,
specification_names=[spec],
tag=tag,
priority=priority,
find_existing=find_existing,
)
self._client.make_request(
"post", f"api/v1/datasets/{self.dataset_type}/{self.id}/submit", Any, body=body_data
)
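    # Usage sketch (specification name hypothetical): submit computations for
    # all entries under a single specification at normal priority:
    #
    #     ds.submit(specification_names="spec_1", priority=PriorityEnum.normal)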
#########################################
# Various properties and getters/setters
#########################################
@property
def is_view(self) -> bool:
return self._cache_data is not None and self._cache_data.read_only
def status(self) -> Dict[str, Any]:
self.assert_online()
return self._client.make_request(
"get",
f"api/v1/datasets/{self.dataset_type}/{self.id}/status",
Dict[str, Dict[RecordStatusEnum, int]],
)
def status_table(self) -> str:
"""
Returns the status of the dataset's computations as a table (in a string)
"""
ds_status = self.status()
all_status = {x for y in ds_status.values() for x in y}
ordered_status = RecordStatusEnum.make_ordered_status(all_status)
headers = ["specification"] + [x.value for x in ordered_status]
table = []
for spec, spec_statuses in sorted(ds_status.items()):
row = [spec]
row.extend(spec_statuses.get(s, "") for s in ordered_status)
table.append(row)
return tabulate(table, headers=headers, stralign="right")
def print_status(self) -> None:
print(self.status_table())
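    # Illustrative print_status() output (specification names and counts hypothetical):
    #
    #     specification      complete    error
    #     -------------      --------    -----
    #            spec_1           100        2
    #            spec_2            98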
def detailed_status(self) -> List[Tuple[str, str, RecordStatusEnum]]:
self.assert_online()
return self._client.make_request(
"get",
f"api/v1/datasets/{self.dataset_type}/{self.id}/detailed_status",
List[Tuple[str, str, RecordStatusEnum]],
)
@property
def offline(self) -> bool:
return self._client is None
def assert_online(self):
if self.offline:
raise RuntimeError("Dataset is not connected to a QCFractal server")
@property
def record_count(self) -> int:
self.assert_online()
return self._client.make_request(
"get",
f"api/v1/datasets/{self.dataset_type}/{self.id}/record_count",
int,
)
@property
def computed_properties(self):
self.assert_online()
return self._client.make_request(
"get",
f"api/v1/datasets/{self.dataset_type}/{self.id}/computed_properties",
Dict[str, List[str]],
)
def assert_is_not_view(self):
if self.is_view:
raise RuntimeError("Dataset loaded from an offline view")
def set_name(self, new_name: str):
self._update_metadata(name=new_name)
def set_description(self, new_description: str):
self._update_metadata(description=new_description)
def set_visibility(self, new_visibility: bool):
self._update_metadata(visibility=new_visibility)
def set_group(self, new_group: str):
self._update_metadata(group=new_group)
def set_tagline(self, new_tagline: str):
self._update_metadata(tagline=new_tagline)
def set_provenance(self, new_provenance: Dict[str, Any]):
self._update_metadata(provenance=new_provenance)
def set_default_tag(self, new_default_tag: str):
self._update_metadata(default_tag=new_default_tag)
def set_default_priority(self, new_default_priority: PriorityEnum):
self._update_metadata(default_priority=new_default_priority)
###################################
# Specifications
###################################
def fetch_specification_names(self) -> None:
"""
        Fetch all specification names from the remote server
These are fetched and then stored internally, and not returned.
"""
self.assert_is_not_view()
self.assert_online()
self._specification_names = self._client.make_request(
"get",
f"api/v1/datasets/{self.dataset_type}/{self.id}/specification_names",
List[str],
)
def _internal_fetch_specifications(
self,
specification_names: Iterable[str],
) -> None:
"""
Fetches specification information from the remote server, storing it internally
This does not do any checking for existing specifications, but is used to actually
request the data from the server.
        Note: This function does not do any batching w.r.t. server API limits. It is expected that this is done
before this function is called.
Parameters
----------
specification_names
Names of the specifications to fetch
"""
if not specification_names:
return
fetched_specifications = self._client.make_request(
"post",
f"api/v1/datasets/{self.dataset_type}/{self.id}/specifications/bulkFetch",
Dict[str, self._specification_type],
body=DatasetFetchSpecificationBody(names=specification_names),
)
# The specifications contain their own names, so we don't need the keys
self._cache_data.update_specifications(fetched_specifications.values())
if self._specification_names is None:
self._specification_names = list(fetched_specifications.keys())
else:
self._specification_names.extend(
x for x in fetched_specifications.keys() if x not in self._specification_names
)
def fetch_specifications(
self, specification_names: Optional[Union[str, Iterable[str]]] = None, force_refetch: bool = False
) -> None:
"""
Fetch specifications from the remote server, storing them internally
Parameters
----------
specification_names
Names of specifications to fetch. If None, fetch all specifications
force_refetch
If true, fetch data from the server even if it already exists locally
"""
self.assert_is_not_view()
self.assert_online()
if force_refetch:
self.fetch_specification_names()
# we make copies because _internal_fetch_specifications modifies _specification_names
if specification_names is None:
specification_names = self.specification_names.copy()
else:
specification_names = make_list(specification_names).copy()
# Strip out existing specifications if we aren't forcing refetching
if force_refetch:
specifications_tofetch = specification_names
else:
cached_specifications = set(self._cache_data.get_specification_names())
specifications_tofetch = set(specification_names) - cached_specifications
batch_size: int = math.ceil(self._client.api_limits["get_dataset_entries"] / 4)
for specification_names_batch in chunk_iterable(specifications_tofetch, batch_size):
self._internal_fetch_specifications(specification_names_batch)
@property
def specification_names(self) -> List[str]:
if not self._specification_names:
if self.is_view:
self._specification_names = self._cache_data.get_specification_names()
else:
self.fetch_specification_names()
return self._specification_names
@property
def specifications(self) -> Mapping[str, Any]:
specs = self._cache_data.get_all_specifications()
if not specs and not self.is_view:
self.fetch_specifications()
specs = self._cache_data.get_all_specifications()
return {s.name: s for s in specs}
def rename_specification(self, old_name: str, new_name: str):
self.assert_is_not_view()
self.assert_online()
if old_name == new_name:
return
name_map = {old_name: new_name}
self._client.make_request(
"patch", f"api/v1/datasets/{self.dataset_type}/{self.id}/specifications", None, body=name_map
)
# rename locally cached entries and stuff
self._specification_names = [name_map.get(x, x) for x in self._specification_names]
self._cache_data.rename_specification(old_name, new_name)
def delete_specification(self, name: str, delete_records: bool = False) -> DeleteMetadata:
self.assert_is_not_view()
self.assert_online()
body = DatasetDeleteStrBody(names=[name], delete_records=delete_records)
ret = self._client.make_request(
"post",
f"api/v1/datasets/{self.dataset_type}/{self.id}/specifications/bulkDelete",
DeleteMetadata,
body=body,
)
# Delete locally-cached stuff
self._specification_names = [x for x in self._specification_names if x != name]
self._cache_data.delete_specification(name)
return ret
###################################
# Entries
###################################
def fetch_entry_names(self) -> None:
"""
Fetch all entry names from the remote server
These are fetched and then stored internally, and not returned.
"""
self.assert_is_not_view()
self.assert_online()
self._entry_names = self._client.make_request(
"get",
f"api/v1/datasets/{self.dataset_type}/{self.id}/entry_names",
List[str],
)
def _internal_fetch_entries(
self,
entry_names: Iterable[str],
) -> None:
"""
Fetches entry information from the remote server, storing it internally
This does not do any checking for existing entries, but is used to actually
request the data from the server.
        Note: This function does not do any batching w.r.t. server API limits. It is expected that this is done
before this function is called.
Parameters
----------
entry_names
Names of the entries to fetch
"""
if not entry_names:
return
body = DatasetFetchEntryBody(names=entry_names)
fetched_entries = self._client.make_request(
"post",
f"api/v1/datasets/{self.dataset_type}/{self.id}/entries/bulkFetch",
Dict[str, self._entry_type],
body=body,
)
# The entries contain their own names, so we don't need the keys
self._cache_data.update_entries(fetched_entries.values())
if self._entry_names is None:
self._entry_names = list(fetched_entries.keys())
else:
self._entry_names.extend(x for x in fetched_entries.keys() if x not in self._entry_names)
def fetch_entries(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
force_refetch: bool = False,
) -> None:
"""
Fetches entry information from the remote server, storing it internally
By default, already-fetched entries will not be fetched again, unless
`force_refetch` is True.
Parameters
----------
entry_names
Names of entries to fetch. If None, fetch all entries
force_refetch
If true, fetch data from the server even if it already exists locally
"""
self.assert_is_not_view()
self.assert_online()
# Reload entry names if we are forcing refetching
if force_refetch:
self.fetch_entry_names()
# if not specified, do all entries
# we make copies because _internal_fetch_entries modifies _entry_names
if entry_names is None:
entry_names = self.entry_names.copy()
else:
entry_names = make_list(entry_names).copy()
# Strip out existing entries if we aren't forcing refetching
if force_refetch:
entries_tofetch = entry_names
else:
cached_entries = set(self._cache_data.get_entry_names())
entries_tofetch = set(entry_names) - cached_entries
batch_size: int = math.ceil(self._client.api_limits["get_dataset_entries"] / 4)
for entry_names_batch in chunk_iterable(entries_tofetch, batch_size):
self._internal_fetch_entries(entry_names_batch)
def get_entry(
self,
entry_name: str,
force_refetch: bool = False,
) -> Optional[Any]:
"""
Obtain entry information
The entry will be automatically fetched from the remote server if needed.
"""
entry = self._cache_data.get_entry(entry_name)
if entry is None and not self.is_view:
self.fetch_entries(entry_name, force_refetch=force_refetch)
entry = self._cache_data.get_entry(entry_name)
return entry
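    # Usage sketch (entry name hypothetical); returns None if no such entry:
    #
    #     entry = ds.get_entry("water_0")
    #     if entry is not None:
    #         print(entry.name)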
def iterate_entries(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
force_refetch: bool = False,
):
"""
Iterate over all entries
This is used as a generator, and automatically fetches entries as needed
Parameters
----------
entry_names
Names of entries to iterate over. If None, iterate over all entries
force_refetch
If true, fetch data from the server even if it already exists locally
"""
#########################################################
# We duplicate a little bit of fetch_entries here, since
# we want to yield in the middle
#########################################################
# Reload entry names if we are forcing refetching
# Nothing to fetch if this is a view
if force_refetch and not self.is_view:
self.fetch_entry_names()
# if not specified, do all entries
# we make copies because fetching records can modify _entry_names
if entry_names is None:
entry_names = self.entry_names.copy()
else:
entry_names = make_list(entry_names).copy()
if self.is_view:
# Go one at a time. No need to "fetch"
for entry_name in entry_names:
entry = self._cache_data.get_entry(entry_name)
if entry is not None:
yield entry
else:
# Check local cache, but fetch from server
batch_size: int = math.ceil(self._client.api_limits["get_dataset_entries"] / 4)
# What we have cached already
cached_entries = set(self._cache_data.get_entry_names())
for entry_names_batch in chunk_iterable(entry_names, batch_size):
# If forcing refetching, then use the whole batch. Otherwise, strip out
# any existing entries
if force_refetch:
entries_tofetch = entry_names_batch
else:
# get what we have in the local cache
entries_tofetch = set(entry_names_batch) - cached_entries
if entries_tofetch:
self._internal_fetch_entries(entries_tofetch)
# Loop over the whole batch (not just what we fetched)
entry_data = self._cache_data.get_entries(entry_names_batch)
for entry in entry_data:
yield entry
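    # Usage sketch: entries are fetched in batches and yielded lazily, so this
    # scales to large datasets:
    #
    #     for entry in ds.iterate_entries():
    #         print(entry.name)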
@property
def entry_names(self) -> List[str]:
if not self._entry_names:
if self.is_view:
self._entry_names = self._cache_data.get_entry_names()
else:
self.fetch_entry_names()
return self._entry_names
def rename_entries(self, name_map: Dict[str, str]):
self.assert_is_not_view()
self.assert_online()
# Remove renames which aren't actually different
name_map = {old_name: new_name for old_name, new_name in name_map.items() if old_name != new_name}
self._client.make_request(
"patch", f"api/v1/datasets/{self.dataset_type}/{self.id}/entries", None, body=name_map
)
# rename locally cached entries and stuff
self._entry_names = [name_map.get(x, x) for x in self._entry_names]
for old_name, new_name in name_map.items():
self._cache_data.rename_entry(old_name, new_name)
def delete_entries(self, names: Union[str, Iterable[str]], delete_records: bool = False) -> DeleteMetadata:
self.assert_is_not_view()
self.assert_online()
names = make_list(names)
body = DatasetDeleteStrBody(names=names, delete_records=delete_records)
ret = self._client.make_request(
"post",
f"api/v1/datasets/{self.dataset_type}/{self.id}/entries/bulkDelete",
DeleteMetadata,
body=body,
)
# Delete locally-cached stuff
self._entry_names = [x for x in self._entry_names if x not in names]
for entry_name in names:
self._cache_data.delete_entry(entry_name)
return ret
###########################
# Records
###########################
def _internal_fetch_records(
self,
entry_names: Iterable[str],
specification_names: Iterable[str],
status: Optional[Iterable[RecordStatusEnum]],
include: Optional[Iterable[str]],
) -> List[Tuple[str, str, BaseRecord]]:
"""
Fetches records from the remote server
This does not do any checking for existing records, but is used to actually
request the data from the server.
        Note: This function does not do any batching w.r.t. server API limits. It is expected that this is done
before this function is called. This function also does not look up records in the cache, but does
attach the cache to the records.
Note: Records are not returned in any particular order
Parameters
----------
entry_names
            Names of the entries whose records to fetch
specification_names
            Names of the specifications whose records to fetch
status
Fetch only records with these statuses
include
            Additional fields/data to include when fetching the records
Returns
-------
:
List of tuples (entry_name, spec_name, record)
"""
if not (entry_names and specification_names):
return []
# First, we need the corresponding entries and specifications
self.fetch_entries(entry_names)
self.fetch_specifications(specification_names)
body = DatasetFetchRecordsBody(entry_names=entry_names, specification_names=specification_names, status=status)
record_info = self._client.make_request(
"post",
f"api/v1/datasets/{self.dataset_type}/{self.id}/records/bulkFetch",
List[Tuple[str, str, int]], # (entry_name, spec_name, record_id)
body=body,
)
record_ids = [x[2] for x in record_info]
# This function always fetches, so force_fetch = True
        # But records will be attached to the cache
records = get_records_with_cache(self._client, self._cache_data, self._record_type, record_ids, include, True)
# Update the locally-stored metadata for these dataset records
# zip(record_info, records) = ((entry_name, spec_name, record_id), record)
update_info = [(ename, sname, r) for (ename, sname, _), r in zip(record_info, records)]
self._cache_data.update_dataset_records(update_info)
return update_info
def _internal_update_records(
self,
entry_names: Iterable[str],
specification_names: Iterable[str],
status: Optional[Iterable[RecordStatusEnum]],
include: Optional[Iterable[str]],
) -> List[Tuple[str, str, BaseRecord]]:
"""
Update local record information if the record has been modified on the server
Parameters
----------
entry_names
            Names of the entries whose records to update
specification_names
            Names of the specifications whose records to update
status
Update records that have this status on the server. If None, update records with any status on the server
include
            Additional fields/data to include when fetching the records
"""
if not (entry_names and specification_names):
return []
        # Tuples of (entry_name, spec_name, record_id, status, modified_on) for
        # records we have in our local cache
updateable_record_info = self._cache_data.get_dataset_record_info(entry_names, specification_names, None)
# print(f"UPDATEABLE RECORDS: {len(updateable_record_info)}")
if not updateable_record_info:
return []
batch_size = math.ceil(self._client.api_limits["get_records"] / 4)
server_modified_time: Dict[int, datetime] = {} # record_id -> modified time on server
# Find out which records have been updated on the server
for record_info_batch in chunk_iterable(updateable_record_info, batch_size):
record_id_batch = [x[2] for x in record_info_batch]
# Do a raw call to the records/bulkGet endpoint. This allows us to only get
# the 'modified_on' and 'status' fields
server_record_info = self._client.make_request(
"post",
f"api/v1/records/bulkGet",
List[Dict[str, Any]],
body=CommonBulkGetBody(ids=record_id_batch, include=["id", "modified_on", "status"]),
)
            # Use pydantic to parse the datetime representation returned by the server
for sri in server_record_info:
# Only store if the status on the server matches what the caller wants
if status is None or sri["status"] in status:
server_modified_time[sri["id"]] = pydantic.parse_obj_as(datetime, sri["modified_on"])
# Which ones need to be fully updated
need_updating: Dict[str, List[str]] = {} # key is specification, value is list of entry names
for entry_name, spec_name, record_id, _, local_mtime in updateable_record_info:
server_mtime = server_modified_time.get(record_id, None)
            # The record may have been deleted on the server, or its status did not match
if server_mtime is None:
continue
if local_mtime < server_mtime:
need_updating.setdefault(spec_name, [])
need_updating[spec_name].append(entry_name)
# Update from the server one spec at a time
# print(f"Updated on server: {len(needs_updating)}")
updated_records = []
for spec_name, entries_to_update in need_updating.items():
for entries_batch in chunk_iterable(entries_to_update, batch_size):
# Updates dataset record metadata if needed
r = self._internal_fetch_records(entries_batch, [spec_name], None, include)
updated_records.extend(r)
return updated_records
def fetch_records(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
status: Optional[Union[RecordStatusEnum, Iterable[RecordStatusEnum]]] = None,
include: Optional[Iterable[str]] = None,
fetch_updated: bool = True,
force_refetch: bool = False,
):
"""
Fetches record information from the remote server, storing it internally
        By default, this function will only fetch records that have not been fetched previously.
If `force_refetch` is True, then this will always fetch the records.
Parameters
----------
entry_names
Names of the entries whose records to fetch. If None, fetch all entries
specification_names
Names of the specifications whose records to fetch. If None, fetch all specifications
status
Fetch only records with these statuses
include
Additional fields to include in the returned record
fetch_updated
Fetch any records that exist locally but have been updated on the server
force_refetch
If true, fetch data from the server even if it already exists locally
"""
self.assert_is_not_view()
self.assert_online()
# Reload entry names if we are forcing refetching
if force_refetch:
self.fetch_entry_names()
self.fetch_specifications()
status = make_list(status)
# if not specified, do all entries and specs
# we make copies because fetching records can modify _specification_names and _entry_names members
if entry_names is None:
entry_names = self.entry_names.copy()
else:
entry_names = make_list(entry_names).copy()
if specification_names is None:
specification_names = self.specification_names.copy()
else:
specification_names = make_list(specification_names).copy()
# Determine the number of entries in each batch
# Assume there are many more entries than specifications, and that
# everything has been submitted
batch_size: int = math.ceil(self._client.api_limits["get_records"])
n_batches = math.ceil(len(entry_names) / batch_size)
# Do all entries for one spec. This simplifies things, especially with handling
# existing or update-able records
for spec_name in specification_names:
for entry_names_batch in tqdm(chunk_iterable(entry_names, batch_size), total=n_batches, disable=None):
records_batch = []
# Handle existing records that need to be updated
if force_refetch:
r = self._internal_fetch_records(entry_names_batch, [spec_name], status, include)
records_batch.extend(r)
else:
missing_entries = entry_names_batch.copy()
if fetch_updated:
updated_records = self._internal_update_records(missing_entries, [spec_name], status, include)
records_batch.extend(updated_records)
# what wasn't updated
updated_entries = [x for x, _, _ in updated_records]
missing_entries = [e for e in entry_names_batch if e not in updated_entries]
# Check if we have any cached records
cached_records = self._cache_data.get_dataset_records(missing_entries, [spec_name])
for _, _, cr in cached_records:
cr.propagate_client(self._client)
records_batch.extend(cached_records)
# what we need to fetch from the server
cached_entries = [x[0] for x in cached_records]
missing_entries = [e for e in missing_entries if e not in cached_entries]
fetched_records = self._internal_fetch_records(missing_entries, [spec_name], status, include)
records_batch.extend(fetched_records)
# Write the record batch to the cache at once. Also marks the records as clean (no need to writeback)
self._cache_data.update_records([r for _, _, r in records_batch])
def get_record(
self,
entry_name: str,
specification_name: str,
include: Optional[Iterable[str]] = None,
fetch_updated: bool = True,
force_refetch: bool = False,
) -> Optional[BaseRecord]:
"""
Obtain a calculation record related to this dataset
The record will be automatically fetched from the remote server if needed.
If a record does not exist for this entry and specification, None is returned
"""
if self.is_view:
fetch_updated = False
force_refetch = False
record = None
if force_refetch:
            records = self._internal_fetch_records([entry_name], [specification_name], None, include)
            if records:
                record = records[0][2]
elif fetch_updated:
records = self._internal_update_records([entry_name], [specification_name], None, include)
if records:
record = records[0][2]
if record is None:
# Attempt to get from cache
record = self._cache_data.get_dataset_record(entry_name, specification_name)
if record is None and not self.is_view:
# not in cache
                records = self._internal_fetch_records([entry_name], [specification_name], None, include)
                if records:
                    record = records[0][2]
if record is not None and self._client is not None:
record.propagate_client(self._client)
return record
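    # Usage sketch (entry/specification names hypothetical); returns None if no
    # record exists for the pair:
    #
    #     rec = ds.get_record("water_0", "spec_1")
    #     if rec is not None:
    #         print(rec.id, rec.status)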
def iterate_records(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
status: Optional[Union[RecordStatusEnum, Iterable[RecordStatusEnum]]] = None,
include: Optional[Iterable[str]] = None,
fetch_updated: bool = True,
force_refetch: bool = False,
):
#########################################################
# We duplicate a little bit of fetch_records here, since
# we want to yield in the middle
#########################################################
if self.is_view:
fetch_updated = False
force_refetch = False
# Get an up-to-date list of entry names and specifications
# Nothing to fetch if this is a view
if force_refetch:
self.fetch_entry_names()
self.fetch_specifications()
status = make_list(status)
# if not specified, do all entries and specs
# we make copies because fetching records can modify _specification_names and _entry_names members
if entry_names is None:
entry_names = self.entry_names.copy()
else:
entry_names = make_list(entry_names).copy()
if specification_names is None:
specification_names = self.specification_names.copy()
else:
specification_names = make_list(specification_names).copy()
if self.is_view:
for spec_name in specification_names:
for entry_names_batch in chunk_iterable(entry_names, 125):
record_data = self._cache_data.get_dataset_records(entry_names_batch, [spec_name], status)
for e, s, r in record_data:
yield e, s, r
else:
batch_size: int = math.ceil(self._client.api_limits["get_records"])
for spec_name in specification_names:
for entry_names_batch in chunk_iterable(entry_names, batch_size):
records_batch = []
# Handle existing records that need to be updated
if force_refetch:
r = self._internal_fetch_records(entry_names_batch, [spec_name], status, include)
records_batch.extend(r)
else:
missing_entries = entry_names_batch.copy()
if fetch_updated:
updated_records = self._internal_update_records(
missing_entries, [spec_name], status, include
)
records_batch.extend(updated_records)
# what wasn't updated
updated_entries = [x for x, _, _ in updated_records]
missing_entries = [e for e in entry_names_batch if e not in updated_entries]
# Check if we have any cached records
cached_records = self._cache_data.get_dataset_records(missing_entries, [spec_name])
for _, _, cr in cached_records:
cr.propagate_client(self._client)
records_batch.extend(cached_records)
# what we need to fetch from the server
cached_entries = [x[0] for x in cached_records]
missing_entries = [e for e in missing_entries if e not in cached_entries]
fetched_records = self._internal_fetch_records(missing_entries, [spec_name], status, include)
records_batch.extend(fetched_records)
# Let the writeback mechanism handle writing to the cache
for e, s, r in records_batch:
if status is None or r.status in status:
yield e, s, r
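    # Usage sketch: iterate over the completed records of one specification
    # (specification name hypothetical):
    #
    #     for entry_name, spec_name, record in ds.iterate_records(
    #         specification_names="spec_1", status=RecordStatusEnum.complete
    #     ):
    #         print(entry_name, record.id)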
def remove_records(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
delete_records: bool = False,
) -> DeleteMetadata:
self.assert_is_not_view()
self.assert_online()
entry_names = make_list(entry_names)
specification_names = make_list(specification_names)
body = DatasetRemoveRecordsBody(
entry_names=entry_names,
specification_names=specification_names,
delete_records=delete_records,
)
ret = self._client.make_request(
"post",
f"api/v1/datasets/{self.dataset_type}/{self.id}/records/bulkDelete",
None,
body=body,
)
if delete_records:
record_info = self._cache_data.get_dataset_records(entry_names, specification_names)
self._cache_data.delete_records([r.id for _, _, r in record_info])
self._cache_data.delete_dataset_records(entry_names, specification_names)
return ret
def modify_records(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
new_tag: Optional[str] = None,
new_priority: Optional[PriorityEnum] = None,
new_comment: Optional[str] = None,
*,
refetch_records: bool = False,
):
self.assert_is_not_view()
self.assert_online()
entry_names = make_list(entry_names)
specification_names = make_list(specification_names)
body = DatasetRecordModifyBody(
entry_names=entry_names,
specification_names=specification_names,
tag=new_tag,
priority=new_priority,
comment=new_comment,
)
ret = self._client.make_request(
"patch",
f"api/v1/datasets/{self.dataset_type}/{self.id}/records",
None,
body=body,
)
if refetch_records:
self.fetch_records(entry_names, specification_names, force_refetch=True)
return ret
def reset_records(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
*,
refetch_records: bool = False,
):
self.assert_is_not_view()
self.assert_online()
entry_names = make_list(entry_names)
specification_names = make_list(specification_names)
body = DatasetRecordModifyBody(
entry_names=entry_names,
specification_names=specification_names,
status=RecordStatusEnum.waiting,
)
ret = self._client.make_request(
"patch",
f"api/v1/datasets/{self.dataset_type}/{self.id}/records",
None,
body=body,
)
if refetch_records:
self.fetch_records(entry_names, specification_names, force_refetch=True)
return ret
def cancel_records(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
*,
refetch_records: bool = False,
):
self.assert_is_not_view()
self.assert_online()
entry_names = make_list(entry_names)
specification_names = make_list(specification_names)
body = DatasetRecordModifyBody(
entry_names=entry_names,
specification_names=specification_names,
status=RecordStatusEnum.cancelled,
)
ret = self._client.make_request(
"patch",
f"api/v1/datasets/{self.dataset_type}/{self.id}/records",
None,
body=body,
)
if refetch_records:
self.fetch_records(entry_names, specification_names, force_refetch=True)
return ret
def uncancel_records(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
*,
refetch_records: bool = False,
):
self.assert_is_not_view()
self.assert_online()
entry_names = make_list(entry_names)
specification_names = make_list(specification_names)
body = DatasetRecordRevertBody(
entry_names=entry_names,
specification_names=specification_names,
revert_status=RecordStatusEnum.cancelled,
)
ret = self._client.make_request(
"post",
f"api/v1/datasets/{self.dataset_type}/{self.id}/records/revert",
None,
body=body,
)
if refetch_records:
self.fetch_records(entry_names, specification_names, force_refetch=True)
return ret
def invalidate_records(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
*,
refetch_records: bool = False,
):
self.assert_is_not_view()
self.assert_online()
entry_names = make_list(entry_names)
specification_names = make_list(specification_names)
body = DatasetRecordModifyBody(
entry_names=entry_names,
specification_names=specification_names,
status=RecordStatusEnum.invalid,
)
ret = self._client.make_request(
"patch",
f"api/v1/datasets/{self.dataset_type}/{self.id}/records",
None,
body=body,
)
if refetch_records:
self.fetch_records(entry_names, specification_names, force_refetch=True)
return ret
def uninvalidate_records(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
*,
refetch_records: bool = False,
):
self.assert_is_not_view()
self.assert_online()
entry_names = make_list(entry_names)
specification_names = make_list(specification_names)
body = DatasetRecordRevertBody(
entry_names=entry_names,
specification_names=specification_names,
revert_status=RecordStatusEnum.invalid,
)
ret = self._client.make_request(
"post",
f"api/v1/datasets/{self.dataset_type}/{self.id}/records/revert",
None,
body=body,
)
if refetch_records:
self.fetch_records(entry_names, specification_names, force_refetch=True)
return ret
def compile_values(
self,
value_call: Callable,
value_names: Union[Sequence[str], str] = "value",
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
unpack: bool = False,
) -> "DataFrame":
"""
Compile values from records into a pandas DataFrame.
Parameters
-----------
value_call
Function to call on each record to extract the desired value. Must return a scalar value or
a sequence of values if 'unpack' is set to True.
value_names
Column name(s) for the extracted value(s). If a string is provided and multiple values are
returned by 'value_call', columns are named by appending an index to this string. If a list
of strings is provided, it must match the length of the sequence returned by 'value_call'.
Default is "value".
entry_names
Entry names to filter records. If not provided, considers all entries.
specification_names
Specification names to filter records. If not provided, considers all specifications.
unpack
If True, unpack the sequence of values returned by 'value_call' into separate columns.
Default is False.
Returns
--------
pandas.DataFrame
            A multi-index DataFrame where each row corresponds to an entry. Each column has a top-level
            index giving the specification and a second-level index giving the value name.
Values are extracted from records using 'value_call'.
Raises
-------
ValueError
If the length of 'value_names' does not match the number of values returned by 'value_call' when
'unpack' is set to True.
Notes
------
1. The DataFrame is structured such that the rows are entries and columns are specifications.
2. If 'unpack' is True, the function assumes 'value_call' returns a sequence of values that need
to be distributed across columns in the resulting DataFrame. 'value_call' should always return the
same number of values for each record if unpack is True.
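        Examples
        --------
        A sketch, assuming completed records expose a ``properties`` dict
        containing a (hypothetical) "scf_total_energy" key:

        >>> df = ds.compile_values(
        ...     lambda r: r.properties["scf_total_energy"],
        ...     value_names="energy",
        ... )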
"""
import pandas as pd
def _data_generator(unpack=False):
for entry_name, spec_name, record in self.iterate_records(
entry_names=entry_names,
specification_names=specification_names,
status=RecordStatusEnum.complete,
fetch_updated=True,
force_refetch=False,
):
if unpack:
yield entry_name, spec_name, *value_call(record)
else:
yield entry_name, spec_name, value_call(record)
def _check_first():
gen = _data_generator()
_, _, first_value = next(gen)
return first_value
first_value = _check_first()
if unpack and isinstance(first_value, Sequence) and not isinstance(first_value, str):
if isinstance(value_names, str):
column_names = [value_names + str(i) for i in range(len(first_value))]
else:
if len(first_value) != len(value_names):
raise ValueError(
"Number of column names must match number of values returned by provided function."
)
column_names = value_names
df = pd.DataFrame(_data_generator(unpack=True), columns=("entry", "specification", *column_names))
else:
column_names = [value_names]
df = pd.DataFrame(_data_generator(), columns=("entry", "specification", value_names))
return_val = df.pivot(index="entry", columns="specification", values=column_names)
# Make specification top level index.
return return_val.swaplevel(axis=1)
def get_properties_df(self, properties_list: Sequence[str]) -> "DataFrame":
"""
Retrieve a DataFrame populated with the specified properties from dataset records.
This function uses the provided list of property names to extract corresponding
        values from each record's properties. It returns a DataFrame whose rows
        correspond to entries. Each column has a top-level index giving the
        specification and a second-level index giving the property name. Columns
        with all NaN values are dropped.
        Parameters
        ----------
properties_list
List of property names to retrieve from the records.
        Returns
        -------
pandas.DataFrame
A DataFrame populated with the specified properties for each record.
"""
# create lambda function to get all properties at once
extract_properties = lambda x: [x.properties.get(property_name) for property_name in properties_list]
# retrieve values.
result = self.compile_values(extract_properties, value_names=properties_list, unpack=True)
# Drop columns with all nan values. This will occur if a property that is not part of a
# specification is requested.
result.dropna(how="all", axis=1, inplace=True)
return result
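    # Usage sketch (property names hypothetical); properties missing from a
    # specification simply produce all-NaN columns, which are dropped:
    #
    #     df = ds.get_properties_df(["scf_total_energy", "nuclear_repulsion_energy"])
    #     df.head()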
##############################
# Caching
##############################
def refresh_cache(
self,
entry_names: Optional[Union[str, Iterable[str]]] = None,
specification_names: Optional[Union[str, Iterable[str]]] = None,
):
"""
Refreshes some information in the cache with information on the server
This can be used to fix some inconsistencies in the cache without deleting and starting over.
For example, this can fix instances where the record attached to a given entry & specification
has changed (new record id) due to renaming specifications and entries, or via remove_records followed
by a submit without duplicate checking.
This will also fetch any updated records
Parameters
----------
entry_names
Names of the entries whose records to fetch. If None, fetch all entries
specification_names
Names of the specifications whose records to fetch. If None, fetch all specifications
"""
self.assert_is_not_view()
self.assert_online()
# Reload all entry names and specifications
self.fetch_entry_names()
self.fetch_specification_names()
# Delete anything in the cache that doesn't correspond to these entries/specs
local_specifications = self._cache_data.get_specification_names()
local_entries = self._cache_data.get_entry_names()
        deleted_specifications = set(local_specifications) - set(self.specification_names)
        deleted_entries = set(local_entries) - set(self.entry_names)
for spec_name in deleted_specifications:
self._cache_data.delete_specification(spec_name)
for entry_name in deleted_entries:
self._cache_data.delete_entry(entry_name)
###############################
# Now for the actual fetching
# if not specified, do all entries and specs
if entry_names is None:
entry_names = self.entry_names
else:
entry_names = make_list(entry_names)
if specification_names is None:
specification_names = self.specification_names
else:
specification_names = make_list(specification_names)
# Determine the number of entries in each batch
# Assume there are many more entries than specifications, and that
# everything has been submitted
# Divide by 4 to go easy on the server
batch_size: int = math.ceil(self._client.api_limits["get_records"] / 4)
# Do all entries for one spec. This simplifies things, especially with handling
# existing or update-able records
for spec_name in specification_names:
# Fetch the specification itself
self.fetch_specifications(spec_name)
for entry_names_batch in chunk_iterable(entry_names, batch_size):
# Fetch the entries themselves
self.fetch_entries(entry_names_batch, force_refetch=True)
                # What info do we have stored locally?
                # (entry_name, spec_name, record_id, status, modified_on)
cached_records = self._cache_data.get_dataset_record_info(entry_names_batch, [spec_name], None)
# Get the record info corresponding to this specification & these entries
                body = DatasetFetchRecordsBody(entry_names=entry_names_batch, specification_names=[spec_name])
server_ds_records = self._client.make_request(
"post",
f"api/v1/datasets/{self.dataset_type}/{self.id}/records/bulkFetch",
List[Tuple[str, str, int]], # (entry_name, spec_name, record_id)
body=body,
)
# Also get basic information about the records themselves
server_ds_records_map = {(e, s): rid for e, s, rid in server_ds_records}
server_record_ids = list(set(server_ds_records_map.values()))
# Do a raw call to the records/bulkGet endpoint. This allows us to only get
# the 'modified_on' and 'status' fields
server_record_info = self._client.make_request(
"post",
f"api/v1/records/bulkGet",
List[Dict[str, Any]],
                    body=CommonBulkGetBody(ids=server_record_ids, include=["id", "modified_on", "status"]),
)
server_record_info_map = {r["id"]: r for r in server_record_info}
# Check for any different record_ids, or for deleted records
records_tofetch = []
for ename, sname, record_id, status, modified_on in cached_records:
server_ds_record_id = server_ds_records_map.get((ename, sname), None)
                    # If the record no longer exists on the server, or now has a different id,
                    # delete it locally from the cache (and refetch it if it has a new id)
                    if server_ds_record_id is None or record_id != server_ds_record_id:
                        self._cache_data.delete_dataset_record(ename, sname)
                        if server_ds_record_id is not None:
                            records_tofetch.append(server_ds_record_id)
                        continue
                    # record_id matches the server id, so it is present in server_record_info_map
rinfo = server_record_info_map[record_id]
                    rinfo_modified = pydantic.parse_obj_as(datetime, rinfo["modified_on"])
if rinfo_modified > modified_on or rinfo["status"] != status:
                        records_tofetch.append(record_id)  # same as server_ds_record_id

                # Fetch the stale or replaced records flagged above;
                # _internal_fetch_records also refreshes the cached dataset record metadata
                if records_tofetch:
                    tofetch_set = set(records_tofetch)
                    stale_entries = [e for (e, _), rid in server_ds_records_map.items() if rid in tofetch_set]
                    self._internal_fetch_records(stale_entries, [spec_name], None, None)
##############################
# Contributed values
##############################
def fetch_contributed_values(self):
self.assert_is_not_view()
self.assert_online()
self._contributed_values = self._client.make_request(
"get",
f"api/v1/datasets/{self.id}/contributed_values",
Optional[Dict[str, ContributedValues]],
)
@property
def contributed_values(self) -> Dict[str, ContributedValues]:
        if not self._contributed_values:
            self.fetch_contributed_values()
        return self._contributed_values
class DatasetAddBody(RestModelBase):
name: str
description: str
tagline: str
tags: List[str]
group: str
provenance: Dict[str, Any]
visibility: bool
default_tag: str
default_priority: PriorityEnum
metadata: Dict[str, Any]
owner_group: Optional[str]
existing_ok: bool = False
class DatasetModifyMetadata(RestModelBase):
name: str
description: str
tags: List[str]
tagline: str
group: str
visibility: bool
provenance: Optional[Dict[str, Any]]
metadata: Optional[Dict[str, Any]]
default_tag: str
default_priority: PriorityEnum
class DatasetQueryModel(RestModelBase):
dataset_type: Optional[str] = None
dataset_name: Optional[str] = None
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class DatasetFetchSpecificationBody(RestModelBase):
names: List[str]
missing_ok: bool = False
class DatasetFetchEntryBody(RestModelBase):
names: List[str]
missing_ok: bool = False
class DatasetDeleteStrBody(RestModelBase):
names: List[str]
delete_records: bool = False
class DatasetRemoveRecordsBody(RestModelBase):
entry_names: List[str]
specification_names: List[str]
delete_records: bool = False
class DatasetDeleteParams(RestModelBase):
delete_records: bool = False
@validator("delete_records", pre=True)
def validate_lists(cls, v):
return validate_list_to_single(v)
class DatasetFetchRecordsBody(RestModelBase):
entry_names: List[str]
specification_names: List[str]
status: Optional[List[RecordStatusEnum]] = None
class DatasetSubmitBody(RestModelBase):
entry_names: Optional[List[str]] = None
specification_names: Optional[List[str]] = None
tag: Optional[str] = None
priority: Optional[PriorityEnum] = None
owner_group: Optional[str] = None
find_existing: bool = True
class DatasetRecordModifyBody(RestModelBase):
entry_names: Optional[List[str]] = None
specification_names: Optional[List[str]] = None
status: Optional[RecordStatusEnum] = None
priority: Optional[PriorityEnum] = None
tag: Optional[str] = None
comment: Optional[str] = None
class DatasetRecordRevertBody(RestModelBase):
entry_names: Optional[List[str]] = None
specification_names: Optional[List[str]] = None
    revert_status: Optional[RecordStatusEnum] = None
class DatasetQueryRecords(RestModelBase):
record_id: List[int]
dataset_type: Optional[List[str]] = None
class DatasetDeleteEntryBody(RestModelBase):
names: List[str]
delete_records: bool = False
class DatasetDeleteSpecificationBody(RestModelBase):
names: List[str]
delete_records: bool = False
def dataset_from_dict(data: Dict[str, Any], client: Any, cache_data: Optional[DatasetCache] = None) -> BaseDataset:
"""
Create a dataset object from a datamodel
This determines the appropriate dataset class (deriving from BaseDataset)
and creates an instance of that class.
    This works whether the given data is already a datamodel object or a plain dictionary
"""
dataset_type = data["dataset_type"]
cls = BaseDataset.get_subclass(dataset_type)
return cls(client=client, cache_data=cache_data, **data)
def load_dataset_view(file_path: str) -> BaseDataset:
# Reads this as a read-only "view"
ds_meta = read_dataset_metadata(file_path)
ds_type = BaseDataset.get_subclass(ds_meta["dataset_type"])
file_path = os.path.abspath(file_path)
cache_uri = f"file:{file_path}?mode=ro"
ds_cache = DatasetCache(cache_uri, True, ds_type)
# Views never have a client attached
return dataset_from_dict(ds_meta, None, cache_data=ds_cache)
def dataset_from_cache(file_path: str) -> BaseDataset:
    # Backwards-compatible alias for load_dataset_view
return load_dataset_view(file_path)
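# Usage sketch (path hypothetical): open a previously created view file read-only
#
#     ds = load_dataset_view("/path/to/dataset.sqlite")
#     print(ds.entry_names)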
def create_dataset_view(
client: PortalClient,
dataset_id: int,
file_path: str,
include: Optional[Iterable[str]] = None,
overwrite: bool = False,
):
file_path = os.path.abspath(file_path)
if os.path.exists(file_path) and not os.path.isfile(file_path):
raise ValueError(f"Path {file_path} exists and is not a file")
if os.path.exists(file_path) and not overwrite:
raise ValueError(f"File {file_path} exists and overwrite is False")
os.makedirs(os.path.dirname(file_path), exist_ok=True)
# Manually get it, because we want to use a different cache file
ds_dict = client.make_request("get", f"api/v1/datasets/{dataset_id}", Dict[str, Any])
ds_cache = DatasetCache(f"file:{file_path}", False, BaseDataset.get_subclass(ds_dict["dataset_type"]))
ds = dataset_from_dict(ds_dict, client, ds_cache)
ds.fetch_records(include=include, force_refetch=True)
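# Usage sketch (server address, dataset id, and path hypothetical):
#
#     client = PortalClient("https://qcarchive.example.org")
#     create_dataset_view(client, 123, "/tmp/ds_123.sqlite", overwrite=True)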