Skip to content
Snippets Groups Projects
Commit 5ac7fe55 authored by Florian Ziemen's avatar Florian Ziemen
Browse files

move to serde for serialization

parent 90cbea1a
No related branches found
No related tags found
1 merge request!2draft: Basic dataset
from collections.abc import Sequence
from dataclasses import dataclass
from serde import serde
@serde
@dataclass
class dataset_location:
method: str
......@@ -15,11 +16,12 @@ class dataset_location:
return vars(self)
@serde
@dataclass
class dataset:
name: str
serve: bool
locations: Sequence[dataset_location]
locations: list[dataset_location]
def __post_init__(self):
self.locations = list(self.locations)
......@@ -35,14 +37,3 @@ class dataset:
and all(type(x) is dataset_location for x in self.locations)
and all(x.is_valid() for x in self.locations)
)
def serialize(self):
rd = dict(self.vars())
rd["locations"] = set(x.serialize() for x in self.locations)
return rd
def from_dict(the_dict):
locations = set(dataset_location(**x) for x in the_dict["locations"])
return dataset(
name=the_dict["name"], serve=the_dict["serve"], locations=locations
)
from collections.abc import Sequence
from dataclasses import dataclass
from dataset import dataset
from serde import serde
@serde
@dataclass
class model_setup:
git_repo: str
commit_id: str
@serde
@dataclass
class icon_setup(model_setup):
configure_command: str
......@@ -16,16 +18,17 @@ class icon_setup(model_setup):
host: str
@serde
@dataclass
class experiment:
experiment_id: str
model: str
setup: model_setup
setup: icon_setup | model_setup
point_of_contact: str
copyright_info: str
cite_as: str
description: str
datasets: Sequence[dataset]
datasets: list[dataset]
def __post_init__(self):
self.datasets = list(self.datasets)
......
import dataset
import unittest
from serde.yaml import from_yaml, to_yaml
class test_dataset(unittest.TestCase):
dsl1 = dataset.dataset_location(method="direct", host="local", path="/scratch/")
dsl2 = dataset.dataset_location(method="direct", host="local", path="/work/")
ds = dataset.dataset(name="test_ds", serve=True, locations=(dsl1, dsl2))
def test_dataset(self):
dsl1 = dataset.dataset_location(method="direct", host="local", path="/scratch/")
dsl1 = test_dataset.dsl1
ds = dataset.dataset(name="test_ds", serve=True, locations=(dsl1,))
self.assertTrue(ds.is_valid())
dsl2 = dataset.dataset_location(method="direct", host="local", path="/work/")
ds = dataset.dataset(name="test_ds", serve=True, locations=(dsl1, dsl2))
ds = test_dataset.ds
self.assertTrue(ds.is_valid())
def test_bad_dataset_location(self):
......@@ -42,6 +46,14 @@ class test_dataset(unittest.TestCase):
with self.assertRaises(TypeError):
bad_location = dataset.dataset(name=None, serve=True, locations=(None,))
def test_to_from_yaml(self):
ds = test_dataset.ds
with open("test.yaml", "w") as of:
of.write(to_yaml(ds))
with open("test.yaml", "r") as infile:
obj = from_yaml(dataset.dataset, infile.read())
self.assertEqual(obj, ds)
if __name__ == "__main__":
unittest.main()
......@@ -2,6 +2,7 @@
import experiment
import unittest
from serde.yaml import from_yaml, to_yaml
class test_experiment(unittest.TestCase):
......@@ -12,25 +13,16 @@ class test_experiment(unittest.TestCase):
git_repo="gitlab.dkrz.de",
commit_id="abc123",
)
def test_experiment(self):
setup = experiment.icon_setup(
configure_command="configure",
config="sensible",
host="levante",
git_repo="gitlab.dkrz.de",
commit_id="abc123",
)
experiment.experiment(
experiment_id="test",
model="icon",
setup=setup,
point_of_contact="flo",
copyright_info="cc-by-4.0 Flo",
cite_as="Flo",
description="test experiment",
datasets=[],
)
experiment = experiment.experiment(
experiment_id="test",
model="icon",
setup=setup,
point_of_contact="flo",
copyright_info="cc-by-4.0 Flo",
cite_as="Flo",
description="test experiment",
datasets=[],
)
def test_experiment_fail(self):
with self.assertRaises(TypeError):
......@@ -66,6 +58,14 @@ class test_experiment(unittest.TestCase):
datasets=("BAD", "DATASET"), # !
)
def test_to_from_yaml(self):
test_exp = test_experiment.experiment
with open("test.yaml", "w") as of:
of.write(to_yaml(test_exp))
with open("test.yaml", "r") as infile:
obj = from_yaml(experiment.experiment, infile.read())
self.assertEqual(obj, test_exp)
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment