# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function
import os
import json
import joblib
import numpy as np
from omegaconf import DictConfig, OmegaConf
from .data import DataHub
from .models import NNModel
from .tasks import Trainer
from .utils import YamlHandler, logger
[docs]
class MolPredict(object):
"""A :class:`MolPredict` class is responsible for interface of predicting process of molecular data."""
[docs]
def __init__(self, load_model=None, cfg: DictConfig | None = None):
"""
Initialize a :class:`MolPredict` class.
:param load_model: str, default=None, path of model to load.
"""
if cfg is not None:
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
load_model = cfg_dict.get("load_model", load_model)
if not load_model:
raise ValueError("load_model is empty")
self.load_model = load_model
config_path = os.path.join(load_model, 'config.yaml')
self.config = YamlHandler(config_path).read_yaml()
self.config.target_cols = self.config.target_cols.split(',')
self.task = self.config.task
self.target_cols = self.config.target_cols
[docs]
def predict(self, data, save_path=None, metrics='none'):
"""
Predict molecular data.
:param data: str, list, numpy, pandas.Series, pandas.DataFrame, dict of atoms and coordinates, input data for prediction.
- str: path of csv file.
- list: list of smiles strings.
- numpy.ndarray: numpy array of data.
- pandas.Series: series of smiles strings.
- pandas.DataFrame: dataframe of data.
- dict: dict of atoms and coordinates, e.g. {'atoms': ['C', 'C', 'C'], 'coordinates': [[0, 0, 0], [0, 0, 1], [0, 0, 2]]}
:param save_path: str, default=None, path to save predict result.
:param metrics: str, default='none', metrics to evaluate model performance.
currently support:
- classification: auc, auprc, log_loss, acc, f1_score, mcc, precision, recall, cohen_kappa.
- regression: mae, pearsonr, spearmanr, mse, r2.
- multiclass: log_loss, acc.
- multilabel_classification: auc, auprc, log_loss, acc, mcc.
- multilabel_regression: mae, mse, r2.
:return y_pred: numpy.ndarray, predict result.
"""
self.save_path = save_path
self.config['sdf_save_path'] = save_path
if not metrics or metrics != 'none':
self.config.metrics = metrics
## load test data
self.datahub = DataHub(
data=data, is_train=False, save_path=self.load_model, **self.config
)
self.config.use_ddp = False
self.trainer = Trainer(save_path=self.load_model, **self.config)
self.model = NNModel(self.datahub.data, self.trainer, **self.config)
self.model.evaluate(self.trainer, self.load_model)
y_pred = self.model.cv['test_pred']
scalar = self.datahub.data['target_scaler']
if scalar is not None:
y_pred = scalar.inverse_transform(y_pred)
df = self.datahub.data['raw_data'].copy()
predict_cols = ['predict_' + col for col in self.target_cols]
if self.task == 'multiclass' and self.config.multiclass_cnt is not None:
prob_cols = ['prob_' + str(i) for i in range(self.config.multiclass_cnt)]
df[prob_cols] = y_pred
df[predict_cols] = np.argmax(y_pred, axis=1).reshape(-1, 1)
elif self.task in ['classification', 'multilabel_classification']:
threshold = joblib.load(
open(os.path.join(self.load_model, 'threshold.dat'), "rb")
)
prob_cols = ['prob_' + col for col in self.target_cols]
df[prob_cols] = y_pred
df[predict_cols] = (y_pred > threshold).astype(int)
else:
prob_cols = predict_cols
df[predict_cols] = y_pred
if self.save_path:
os.makedirs(self.save_path, exist_ok=True)
if not (df[self.target_cols] == -1.0).all().all():
metrics = self.trainer.metrics.cal_metric(
df[self.target_cols].values, df[prob_cols].values
)
logger.info("final predict metrics score: \n{}".format(metrics))
if self.save_path:
joblib.dump(metrics, os.path.join(self.save_path, 'test_metric.result'))
with open(os.path.join(self.save_path, 'test_metric.json'), 'w') as f:
json.dump(metrics, f)
else:
df.drop(self.target_cols, axis=1, inplace=True)
if self.save_path:
prefix = (
data.split('/')[-1].split('.')[0] if isinstance(data, str) else 'test'
)
self.save_predict(df, self.save_path, prefix)
logger.info("pipeline finish!")
return y_pred
[docs]
def save_predict(self, data, dir, prefix):
"""
Save predict result to csv file.
:param data: pandas.DataFrame, predict result.
:param dir: str, directory to save predict result.
:param prefix: str, prefix of predict result file name.
"""
run_id = 0
if not os.path.exists(dir):
os.makedirs(dir)
else:
folders = [x for x in os.listdir(dir)]
while prefix + f'.predict.{run_id}' + '.csv' in folders:
run_id += 1
name = prefix + f'.predict.{run_id}' + '.csv'
path = os.path.join(dir, name)
data.to_csv(path)
logger.info("save predict result to {}".format(path))