Source code for unimol_tools.predict

# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function

import os
import json

import joblib
import numpy as np
from omegaconf import DictConfig, OmegaConf

from .data import DataHub
from .models import NNModel
from .tasks import Trainer
from .utils import YamlHandler, logger


class MolPredict(object):
    """Load a trained model directory and run prediction on new molecular data."""

    def __init__(self, load_model=None, cfg: DictConfig | None = None):
        """
        Initialize a :class:`MolPredict` class.

        :param load_model: str, default=None, path of the model directory to load.
            The directory must contain a ``config.yaml``.
        :param cfg: DictConfig, default=None, optional config. When it carries a
            non-empty ``load_model`` entry, that entry takes precedence over the
            ``load_model`` argument.
        :raises ValueError: if no model path is available from either source.
        """
        if cfg is not None:
            cfg_dict = OmegaConf.to_container(cfg, resolve=True)
            # Prefer the path from cfg. Fall back to the explicit argument when
            # cfg leaves ``load_model`` unset/null/empty.
            load_model = cfg_dict.get("load_model") or load_model
        if not load_model:
            raise ValueError("load_model is empty")
        self.load_model = load_model
        config_path = os.path.join(load_model, 'config.yaml')
        self.config = YamlHandler(config_path).read_yaml()
        # The saved config keeps target_cols as a comma-separated string.
        # Split it into a list, but tolerate a value that is already a list.
        if isinstance(self.config.target_cols, str):
            self.config.target_cols = self.config.target_cols.split(',')
        self.task = self.config.task
        self.target_cols = self.config.target_cols

    def _load_threshold(self):
        """Load the classification threshold stored as ``threshold.dat`` in the model directory."""
        threshold_path = os.path.join(self.load_model, 'threshold.dat')
        # A context manager guarantees the file handle is closed after loading.
        with open(threshold_path, "rb") as f:
            return joblib.load(f)

    def predict(self, data, save_path=None, metrics='none'):
        """
        Predict molecular data.

        :param data: str, list, numpy, pandas.Series, pandas.DataFrame, dict of atoms
            and coordinates, input data for prediction.
            - str: path of csv file.
            - list: list of smiles strings.
            - numpy.ndarray: numpy array of data.
            - pandas.Series: series of smiles strings.
            - pandas.DataFrame: dataframe of data.
            - dict: dict of atoms and coordinates, e.g.
              {'atoms': ['C', 'C', 'C'],
               'coordinates': [[0, 0, 0], [0, 0, 1], [0, 0, 2]]}
        :param save_path: str, default=None, directory to save predict results
            (metrics and a csv of predictions). Nothing is written when falsy.
        :param metrics: str, default='none', metrics to evaluate model performance.
            'none' keeps the metrics stored in the loaded config.
            Currently supported:
            - classification: auc, auprc, log_loss, acc, f1_score, mcc, precision,
              recall, cohen_kappa.
            - regression: mae, pearsonr, spearmanr, mse, r2.
            - multiclass: log_loss, acc.
            - multilabel_classification: auc, auprc, log_loss, acc, mcc.
            - multilabel_regression: mae, mse, r2.
        :return y_pred: numpy.ndarray, predict result (inverse-transformed by the
            target scaler when one is present).
        """
        self.save_path = save_path
        self.config['sdf_save_path'] = save_path
        # Only override the config's metrics when the caller asked for specific ones.
        if metrics != 'none':
            self.config.metrics = metrics

        ## load test data
        self.datahub = DataHub(
            data=data, is_train=False, save_path=self.load_model, **self.config
        )
        # Inference here always runs without distributed data parallel.
        self.config.use_ddp = False
        self.trainer = Trainer(save_path=self.load_model, **self.config)
        self.model = NNModel(self.datahub.data, self.trainer, **self.config)
        self.model.evaluate(self.trainer, self.load_model)

        y_pred = self.model.cv['test_pred']
        scaler = self.datahub.data['target_scaler']
        if scaler is not None:
            y_pred = scaler.inverse_transform(y_pred)

        df = self.datahub.data['raw_data'].copy()
        predict_cols = ['predict_' + col for col in self.target_cols]
        if self.task == 'multiclass' and self.config.multiclass_cnt is not None:
            prob_cols = ['prob_' + str(i) for i in range(self.config.multiclass_cnt)]
            df[prob_cols] = y_pred
            df[predict_cols] = np.argmax(y_pred, axis=1).reshape(-1, 1)
        elif self.task in ['classification', 'multilabel_classification']:
            threshold = self._load_threshold()
            prob_cols = ['prob_' + col for col in self.target_cols]
            df[prob_cols] = y_pred
            df[predict_cols] = (y_pred > threshold).astype(int)
        else:
            # Regression-style tasks: the predictions themselves are scored.
            prob_cols = predict_cols
            df[predict_cols] = y_pred

        if self.save_path:
            os.makedirs(self.save_path, exist_ok=True)

        # Target columns that are all -1.0 are treated as "no labels available".
        # In that case metrics are skipped and the placeholder columns dropped.
        if not (df[self.target_cols] == -1.0).all().all():
            metric_result = self.trainer.metrics.cal_metric(
                df[self.target_cols].values, df[prob_cols].values
            )
            logger.info("final predict metrics score: \n{}".format(metric_result))
            if self.save_path:
                joblib.dump(
                    metric_result, os.path.join(self.save_path, 'test_metric.result')
                )
                with open(os.path.join(self.save_path, 'test_metric.json'), 'w') as f:
                    json.dump(metric_result, f)
        else:
            df.drop(self.target_cols, axis=1, inplace=True)

        if self.save_path:
            # Name the csv after the input file (text before the first '.').
            # os.path.basename handles the platform's path separators.
            prefix = (
                os.path.basename(data).split('.')[0]
                if isinstance(data, str)
                else 'test'
            )
            self.save_predict(df, self.save_path, prefix)
            logger.info("pipeline finish!")

        return y_pred

    def save_predict(self, data, dir, prefix):
        """
        Save predict result to csv file.

        The file is named ``<prefix>.predict.<run_id>.csv``, where ``run_id`` is
        the smallest non-negative integer whose file name is not already present
        in ``dir``. Earlier results are therefore not overwritten.

        :param data: pandas.DataFrame, predict result.
        :param dir: str, directory to save predict result (created if missing).
        :param prefix: str, prefix of predict result file name.
        """
        os.makedirs(dir, exist_ok=True)
        existing = set(os.listdir(dir))
        run_id = 0
        while f'{prefix}.predict.{run_id}.csv' in existing:
            run_id += 1
        path = os.path.join(dir, f'{prefix}.predict.{run_id}.csv')
        data.to_csv(path)
        logger.info("save predict result to {}".format(path))