from pathlib import Path
from typing import Any, Optional, Union, Callable, List
__all__ = ["save", "register_saver"]
# Savers and extension mapping
SAVERS = {}
EXTENSION_MAPPING = {}
[docs]
def save(data: Any, file_path: Union[str, Path], saver: Optional[str] = None, **kwargs):
"""
Save data to a file using the appropriate saver.
This function saves data to various file formats by automatically determining the appropriate saver based on the file extension. You can also specify the saver explicitly if desired.
Supported Savers
----------------
- "pickle": For .pkl files.
- "joblib": For .joblib files.
- "pandas": For .csv, .parquet, .xlsx, .xls, .feather files.
- "json": For .json files.
- "str": For .txt, .html, .log, .md, .rst files.
- "hdf5": For .h5, .hdf5, .hdf files.
- "numpy": For .npz, .npy files.
- "pillow": For image files (.jpg, .jpeg, .png, .bmp, .gif, .tiff, .tif, .webp).
- "pytorch": For PyTorch model files (.pt, .pth).
- "yaml": For .yaml, .yml files.
- "ini": For .ini, .cfg files.
- "matlab": For .mat files.
- "audio": For audio files (.wav, .mp3, .flac, .ogg).
- "video": For video files (.mp4, .avi, .mov, .mkv).
Parameters
----------
data : Any
The data to be saved. The type of data should match the saver being used.
file_path : Union[str, Path]
The path to the file where the data should be saved. The file extension will be used to determine the appropriate saver if not specified.
saver : Optional[str], default=None
The saver type to use. If not provided, it will be inferred from the file extension.
kwargs : dict
Additional keyword arguments to pass to the saver function.
Returns
-------
None
Raises
------
ValueError
If the file extension or specified saver is not supported, or if the data type does not match the expected type.
ImportError
If the required library for the specified saver is not installed.
Examples
--------
Saving a DataFrame to a CSV file:
.. code-block:: python
import pandas as pd
from dmf.io import save
df = pd.DataFrame({"a": [1, 2, 3]})
save(df, "data.csv")
Saving an image using Pillow:
.. code-block:: python
from PIL import Image
from dmf.io import save
img = Image.new("RGB", (100, 100), color="red")
save(img, "image.png")
Saving a NumPy array to an NPZ file:
.. code-block:: python
import numpy as np
from dmf.io import save
arr = np.array([1, 2, 3])
save(arr, "data.npz")
"""
file_path = Path(file_path)
ext = file_path.suffix.lstrip(".").lower()
if saver and saver not in SAVERS:
raise ValueError(f"Saver '{saver}' is not supported. "
f"Use one of {list(SAVERS.keys())}.")
elif not saver:
saver = EXTENSION_MAPPING.get(ext, None)
if not saver:
raise ValueError(
f"File extension '{ext}' is not supported. "
f"Supported extensions: {list(EXTENSION_MAPPING.keys())}. "
"Please specify a supported saver."
)
saver_func = SAVERS[saver]
return saver_func(data, file_path, **kwargs)
def register_saver(saver_name: str, extensions: List[str]):
"""
Decorator to register a custom saver.
Parameters
----------
saver_name : str
The name of the saver (must be unique).
extensions : List[str]
The list of file extensions that the saver should handle (without leading dot).
"""
def decorator(saver_function: Callable):
# Register the saver function
if saver_name in SAVERS:
raise ValueError(f"Saver '{saver_name}' is already registered.")
SAVERS[saver_name] = saver_function
# Register the extensions
for extension in extensions:
EXTENSION_MAPPING[extension] = saver_name
return saver_function
return decorator
@register_saver("pickle", ["pkl", "pickle"])
def save_pickle(data: Any, file_path: Path, **kwargs):
"""Save data using the pickle saver."""
import pickle
with open(file_path, "wb") as file:
pickle.dump(data, file, **kwargs)
@register_saver("joblib", ["joblib"])
def save_joblib(data: Any, file_path: Path, **kwargs):
"""Save data using the joblib saver."""
try:
import joblib
except ImportError:
raise ImportError("joblib package is required for joblib saving. "
"Install it using `pip install joblib`.")
joblib.dump(data, file_path, **kwargs)
@register_saver("hdf5", ["h5", "hdf5", "hdf"])
def save_hdf5(data: Any, file_path: Path, dataset_name: str = "dataset", **kwargs):
"""
Save data using the HDF5 saver.
Parameters
----------
data : Any
The data to save. Can be a dictionary of arrays or a single array.
file_path : Path
The path to the HDF5 file.
dataset_name : str, optional
The name of the dataset if `data` is not a dictionary. Default is "dataset".
kwargs : dict
Additional keyword arguments to pass to the h5py dataset creation.
Raises
------
ValueError
If the data type is not supported.
"""
try:
import h5py
except ImportError:
raise ImportError("h5py package is required for HDF5 saving. "
"Install it using `pip install h5py`.")
with h5py.File(file_path, "w") as file:
if isinstance(data, dict):
for key, value in data.items():
file.create_dataset(key, data=value, **kwargs)
else:
file.create_dataset(dataset_name, data=data, **kwargs)
@register_saver("json", ["json"])
def save_json(data: Any, file_path: Path, **kwargs):
"""Save data using the json saver."""
import json
with open(file_path, "w") as file:
json.dump(data, file, **kwargs)
@register_saver("str", ["txt", "html", "log", "md", "rst"])
def save_str(data, file_path: Path, **kwargs):
"""Save data using the txt saver."""
with open(file_path, "w") as file:
file.write(str(data), **kwargs)
@register_saver("numpy", ["npz", "npy"])
def save_numpy(data: Any, file_path: Path, **kwargs):
"""Save data using the numpy saver."""
try:
import numpy as np
except ImportError:
raise ImportError("numpy package is required for numpy saving. "
"Install it using `pip install numpy`.")
# Check if data is a torch.Tensor by checking for 'cpu' and 'numpy' methods
if hasattr(data, 'cpu') and hasattr(data, 'numpy'):
data = data.cpu().numpy()
# If data is not already a NumPy array, convert it
if not isinstance(data, np.ndarray):
data = np.array(data)
ext = file_path.suffix.lstrip(".").lower()
if ext == "npz":
if not isinstance(data, dict):
raise ValueError("NPZ saver expects data to be a dictionary of arrays.")
np.savez(file_path, **data, **kwargs)
elif ext == "npy":
np.save(file_path, data, **kwargs)
else:
raise ValueError(f"Extension {ext} is not supported for numpy saving. "
f"Use one of {EXTENSION_MAPPING.keys()} or use directly the numpy saver.")
@register_saver("pandas", ["csv", "parquet", "xlsx", "xls", "feather"])
def save_pandas(data: Any, file_path: Path, **kwargs):
"""Save data using the pandas saver."""
import pandas as pd
data = pd.DataFrame(data)
ext = file_path.suffix.lstrip(".").lower()
if ext == "csv":
data.to_csv(file_path, **kwargs)
elif ext == "parquet":
data.to_parquet(file_path, **kwargs)
elif ext == "xlsx" or ext == "xls":
data.to_excel(file_path, **kwargs)
elif ext == "feather":
data.to_feather(file_path, **kwargs)
else:
raise ValueError(f"Extension {ext} is not supported for pandas saving. "
f"Use one of {EXTENSION_MAPPING.keys()} or use directly the pandas saver.")
@register_saver("pillow", ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "tif", "webp"])
def save_pillow(data: Any, file_path: Path, **kwargs):
"""Save data using the pillow saver."""
try:
from PIL import Image
except ImportError:
raise ImportError("Pillow package is required for pillow saving. "
"Install it using `pip install pillow`.")
if not isinstance(data, Image.Image):
raise ValueError("Pillow saver expects data to be an instance of PIL.Image.")
data.save(file_path, **kwargs)
@register_saver("pytorch", ["pt", "pth"])
def save_pytorch(data: Any, file_path: Path, **kwargs):
"""Save data using the pytorch saver."""
try:
import torch
except ImportError:
raise ImportError("torch package is required for pytorch saving. "
"Install it using `pip install torch`.")
torch.save(data, file_path, **kwargs)
@register_saver("yaml", ["yaml", "yml"])
def save_yaml(data: Any, file_path: Path, **kwargs):
"""Save data using the YAML saver."""
try:
import yaml
except ImportError:
raise ImportError("PyYAML is required to save .yaml files. "
"Install it using `pip install pyyaml`.")
with open(file_path, "w") as file:
yaml.safe_dump(data, file, **kwargs)
@register_saver("ini", ["ini", "cfg"])
def save_ini(data: Any, file_path: Path, **kwargs):
"""Save data using the INI saver."""
if not isinstance(data, dict):
raise ValueError("INI saver expects data to be a dictionary.")
import configparser
config = configparser.ConfigParser()
for section, params in data.items():
config[section] = params
with open(file_path, "w") as file:
config.write(file)
@register_saver("matlab", ["mat"])
def save_matlab(data: Any, file_path: Path, **kwargs):
"""Save data using the MATLAB saver."""
if not isinstance(data, dict):
raise ValueError("MATLAB saver expects data to be a dictionary.")
try:
import scipy.io
except ImportError:
raise ImportError("scipy.io is required to save .mat files. "
"Install it using `pip install scipy`.")
scipy.io.savemat(file_path, data, **kwargs)
@register_saver("audio", ["wav", "mp3", "flac", "ogg"])
def save_audio(data: Any, file_path: Path, **kwargs):
"""Save data using the audio saver."""
try:
import soundfile as sf
except ImportError:
raise ImportError("soundfile is required to save audio files. "
"Install it using `pip install soundfile`.")
sf.write(file_path, data, **kwargs)
@register_saver("video", ["mp4", "avi", "mov", "mkv"])
def save_video(data: Any, file_path: Path, **kwargs):
"""Save data using the video saver."""
from ..video.video_writer import write_video
write_video(file_path=file_path, frames=data, **kwargs)