# Source code for molecules.data.dataloaders.dataset
import os, sys
import numpy as np
[docs]class ContactMapDataset:
"""Data handling for Numpy stored contact maps"""
def __init__(self, path, name=None):
self.path = path
self.name = name
def __repr__(self):
if self.name and sys.version[0] == 3:
identifier = 'Data handler for {self.name} contact maps'
return identifier
elif self.name:
identifier = 'Data handler for {} contact maps'.format(self.name)
return identifier
return 'Data handler for contact maps'
[docs] def load_data(self, shape=None):
"""Load numpy array data.
Parameters
----------
shape : tuple, optional
Shape of the data. Format: (H x W x C)
Returns
-------
X_train : np.ndarray
Training set.
X_test : np.ndarray
Test set.
"""
train = os.path.join(self.path, 'train')
test = os.path.join(self.path, 'test')
X_train = np.load(train)
X_test = np.load(test)
if shape:
X_train = X_train.reshape((-1, shape[0], shape[1], shape[2]))
X_test = X_test.reshape((-1, shape[0], shape[1], shape[2]))
return X_train, X_test