Source code for xenonpy.inverse.iqspr.modifier

#  Copyright (c) 2021. yoshida-lab. All rights reserved.
#  Use of this source code is governed by a BSD-style
#  license that can be found in the LICENSE file.

import re
import warnings
from copy import deepcopy

import numpy as np
import pandas as pd
from rdkit import Chem
from tqdm import tqdm

from xenonpy.inverse.base import BaseProposal, ProposalError


class GetProbError(ProposalError):

    def __init__(self, tmp_str, i_b, i_r):
        self.tmp_str = tmp_str
        self.iB = i_b
        self.iR = i_r
        self.old_smi = None
        super().__init__('get_prob: %s not found in NGram, iB=%i, iR=%i' % (tmp_str, i_b, i_r))
class MolConvertError(ProposalError):

    def __init__(self, new_smi):
        self.new_smi = new_smi
        self.old_smi = None
        super().__init__('cannot convert %s to Mol' % new_smi)
class NGramTrainingError(ProposalError):

    def __init__(self, error, smi):
        self.old_smi = smi
        super().__init__('training failed for %s, because of <%s>: %s' %
                         (smi, error.__class__.__name__, error))
class NGram(BaseProposal):

    def __init__(self, *, ngram_table=None, sample_order=(1, 10), del_range=(1, 10),
                 min_len=1, max_len=1000, reorder_prob=0):
        """
        N-Gram

        Parameters
        ----------
        ngram_table: NGram table
            NGram table used to modify SMILES.
        sample_order: tuple[int, int] or int
            Range of orders of the NGram table used during proposal;
            when given an int, sample_order = (1, int).
        del_range: tuple[int, int] or int
            Range of the number of characters randomly deleted from the SMILES string
            during proposal; when given an int, del_range = (1, int).
        min_len: int
            Minimum length of the extended SMILES; shall not be smaller than the lower
            bound of sample_order.
        max_len: int
            Maximum length of the extended SMILES; modification is terminated beyond
            this length.
        reorder_prob: float
            Probability of the SMILES being reordered during proposal.
        """
        self.sample_order = sample_order
        self.reorder_prob = reorder_prob
        self.min_len = min_len
        self.max_len = max_len
        self.del_range = del_range
        if ngram_table is None:
            self._table = None
            self._train_order = None
        else:
            self._table = deepcopy(ngram_table)
            self._train_order = (1, len(ngram_table))
            self._fit_sample_order()
            self._fit_min_len()

    @property
    def sample_order(self):
        return self._sample_order

    @sample_order.setter
    def sample_order(self, val):
        if isinstance(val, int):
            self._sample_order = (1, val)
        elif isinstance(val, tuple):
            self._sample_order = val
        elif isinstance(val, (list, np.ndarray, pd.Series)):
            self._sample_order = (val[0], val[1])
        else:
            raise TypeError('please input a <tuple> of two <int> or a single <int> for sample_order')
        if self._sample_order[0] < 1:
            raise RuntimeError('Min sample_order must be greater than 0')
        if self._sample_order[1] < self._sample_order[0]:
            raise RuntimeError('Max sample_order must not be smaller than min sample_order')

    @property
    def reorder_prob(self):
        return self._reorder_prob

    @reorder_prob.setter
    def reorder_prob(self, val):
        if isinstance(val, (int, float)):
            self._reorder_prob = val
        else:
            raise TypeError('please input a <float> for reorder_prob')

    @property
    def min_len(self):
        return self._min_len

    @min_len.setter
    def min_len(self, val):
        if isinstance(val, int):
            self._min_len = val
        else:
            raise TypeError('please input an <int> for min_len')

    @property
    def max_len(self):
        return self._max_len

    @max_len.setter
    def max_len(self, val):
        if isinstance(val, int):
            self._max_len = val
        else:
            raise TypeError('please input an <int> for max_len')

    @property
    def del_range(self):
        return self._del_range

    @del_range.setter
    def del_range(self, val):
        if isinstance(val, int):
            self._del_range = (1, val)
        elif isinstance(val, tuple):
            self._del_range = val
        elif isinstance(val, (list, np.ndarray, pd.Series)):
            self._del_range = (val[0], val[1])
        else:
            raise TypeError('please input a <tuple> of two <int> or a single <int> for del_range')
        if self._del_range[1] < self._del_range[0]:
            raise RuntimeError('Max del_range must not be smaller than min del_range')

    def _fit_sample_order(self):
        if self._train_order and self._train_order[1] < self.sample_order[1]:
            warnings.warn(
                'max <sample_order>: %s is greater than max <train_order>: %s, '
                'max <sample_order> will be reduced to max <train_order>'
                % (self.sample_order[1], self._train_order[1]), RuntimeWarning)
            self.sample_order = (self.sample_order[0], self._train_order[1])
        if self._train_order and self._train_order[0] > self.sample_order[0]:
            if self._train_order[0] > self.sample_order[1]:
                warnings.warn(
                    'max <sample_order>: %s is smaller than min <train_order>: %s, '
                    '<sample_order> will be replaced by the values of <train_order>'
                    % (self.sample_order[1], self._train_order[0]), RuntimeWarning)
                self.sample_order = (self._train_order[0], self._train_order[1])
            else:
                warnings.warn(
                    'min <sample_order>: %s is smaller than min <train_order>: %s, '
                    'min <sample_order> will be increased to min <train_order>'
                    % (self.sample_order[0], self._train_order[0]), RuntimeWarning)
                self.sample_order = (self._train_order[0], self.sample_order[1])

    def _fit_min_len(self):
        if self.sample_order[0] > self.min_len:
            warnings.warn(
                'min <sample_order>: %s is greater than min_len: %s, '
                'min_len will be increased to min <sample_order>'
                % (self.sample_order[0], self.min_len), RuntimeWarning)
            self.min_len = self.sample_order[0]
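    # Example (hypothetical usage; parameter values are illustrative only):
    # an NGram modifier that samples with orders 1-10, deletes 1-10 trailing
    # characters per step, and reorders the seed SMILES half of the time.
    #
    #   ngram = NGram(sample_order=(1, 10), del_range=(1, 10), reorder_prob=0.5)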
    def on_errors(self, error):
        """
        Error handler for proposal errors.

        Parameters
        ----------
        error: ProposalError
            Error object.

        Returns
        -------
        str or None
            The old SMILES as a fallback for recoverable errors.
        """
        if isinstance(error, GetProbError):
            return error.old_smi
        if isinstance(error, MolConvertError):
            return error.old_smi
        if isinstance(error, NGramTrainingError):
            pass
    @property
    def ngram_table(self):
        return deepcopy(self._table)

    @ngram_table.setter
    def ngram_table(self, value):
        self._table = deepcopy(value)
    def modify(self, ext_smi):
        # reorder with a given probability
        if np.random.random() < self.reorder_prob:
            ext_smi = self.reorder_esmi(ext_smi)
        # number of deletions (randomly picked from the given range)
        n_del = np.random.randint(self.del_range[0], self.del_range[1] + 1)
        # first delete, then add; leave at least min_len characters
        ext_smi = self.del_char(ext_smi, min(n_del + 1, len(ext_smi) - self.min_len))
        # add until reaching '!' or the given max length
        for i in range(self.max_len - len(ext_smi)):
            ext_smi, _ = self.sample_next_char(ext_smi)
            if ext_smi['esmi'].iloc[-1] == '!':
                return ext_smi  # stop when hitting '!', assumed to be a valid SMILES
        # check incomplete esmi
        ext_smi = self.validator(ext_smi)
        # fill in the '!'
        new_pd_row = {
            'esmi': '!',
            'n_br': 0,
            'n_ring': 0,
            'substr': ext_smi['substr'].iloc[-1] + ['!']
        }
        warnings.warn('Extended SMILES hits max length', RuntimeWarning)
        return ext_smi.append(new_pd_row, ignore_index=True)
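    # Example (hypothetical; requires a trained table, see `fit` below):
    #
    #   ext = NGram.smi2esmi('CC(=O)Cl')
    #   new_ext = ngram.modify(ext)          # extended-SMILES DataFrame out
    #   new_smi = NGram.esmi2smi(new_ext)    # back to a plain SMILES string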
    @classmethod
    def smi2list(cls, smiles):
        # tokenize a SMILES string: bracket atoms, Br/Cl, two-digit ring bonds (%nn),
        # and single characters
        # smi_pat = r'(-\[.*?\]|=\[.*?\]|#\[.*?\]|\[.*?\]|-Br|=Br|#Br|-Cl|=Cl|#Cl|Br|Cl|-.|=.|#.|\%[0-9][0-9]|\w|\W)'
        # smi_pat = r'(=\[.*?\]|#\[.*?\]|\[.*?\]|=Br|#Br|=Cl|#Cl|Br|Cl|=.|#.|\%[0-9][0-9]|\w|\W)'
        # smi_pat = r'(\[.*?\]|Br|Cl|\%[0-9][0-9]|\w|\W)'
        smi_pat = r'(\[.*?\]|Br|Cl|(?<=%)[0-9][0-9]|\w|\W)'
        # smi_list = list(filter(None, re.split(smi_pat, smiles)))
        smi_list = list(filter(lambda x: not ((x == "") or (x == "%")), re.split(smi_pat, smiles)))
        # combine a bond symbol with the next token only if the next token isn't a (ring) number
        # assume the SMILES does not end with a bond character!
        tmp_idx = [
            i for i, x in enumerate(smi_list)
            if ((x in "-=#") and (not smi_list[i + 1].isdigit()))
        ]
        if len(tmp_idx) > 0:
            for i in tmp_idx:
                smi_list[i + 1] = smi_list[i] + smi_list[i + 1]
            smi_list = np.delete(smi_list, tmp_idx).tolist()
        return smi_list
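    # Example (tokenization only, no table needed):
    #
    #   NGram.smi2list('CC(=O)Cl')  ->  ['C', 'C', '(', '=O', ')', 'Cl']
    #
    # two-character elements (Br/Cl), bracket atoms and '%nn' ring bonds stay single
    # tokens, and bond symbols are merged into the following atom token.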
    @classmethod
    def smi2esmi(cls, smi):
        smi_list = cls.smi2list(smi)
        esmi_list = smi_list + ['!']
        substr_list = []  # list of all contracted substrings (including the current char.)
        br_list = []  # whether an open branch exists at the current character position (including the current char.)
        ring_list = []  # number of open rings at the current character position (including the current char.)
        v_substr = []  # list of temporary contracted substrings
        v_ringn = []  # list of numbering of open rings
        c_br = 0  # tracking open-branch steps for recording contracted substrings
        n_br = 0  # tracking number of open branches
        tmp_ss = []  # current contracted substring
        for i in range(len(esmi_list)):
            if c_br == 2:
                v_substr.append(deepcopy(tmp_ss))  # contracted substring added w/o ')'
                c_br = 0
            elif c_br == 1:
                c_br = 2
            if esmi_list[i] == '(':
                c_br = 1
                n_br += 1
            elif esmi_list[i] == ')':
                tmp_ss = deepcopy(v_substr[-1])  # retrieve contracted substring added w/o ')'
                v_substr.pop()
                n_br -= 1
            elif esmi_list[i].isdigit():
                esmi_list[i] = int(esmi_list[i])
                if esmi_list[i] in v_ringn:
                    esmi_list[i] = v_ringn.index(esmi_list[i])
                    v_ringn.pop(esmi_list[i])
                else:
                    v_ringn.insert(0, esmi_list[i])
                    esmi_list[i] = '&'
            tmp_ss.append(esmi_list[i])
            substr_list.append(deepcopy(tmp_ss))
            br_list.append(n_br)
            ring_list.append(len(v_ringn))
        return pd.DataFrame({
            'esmi': esmi_list,
            'n_br': br_list,
            'n_ring': ring_list,
            'substr': substr_list
        })
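    # Example (ring handling): for 'C1CC1' the 'esmi' column becomes
    #
    #   ['C', '&', 'C', 'C', 0, '!']
    #
    # i.e. a ring opening is rewritten as '&' and a ring closure as the integer index
    # of the matching open ring; 'n_br' / 'n_ring' track open branches / rings.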
    # may add error check here in the future?
    @classmethod
    def esmi2smi(cls, ext_smi):
        smi_list = ext_smi['esmi'].tolist()
        num_open = []
        num_unused = list(range(99, 0, -1))
        for i in range(len(smi_list)):
            if smi_list[i] == '&':
                if num_unused[-1] > 9:
                    smi_list[i] = ''.join(['%', str(num_unused[-1])])
                else:
                    smi_list[i] = str(num_unused[-1])
                num_open.insert(0, num_unused[-1])
                num_unused.pop()
            elif isinstance(smi_list[i], int):
                tmp = int(smi_list[i])
                if num_open[tmp] > 9:
                    smi_list[i] = ''.join(['%', str(num_open[tmp])])
                else:
                    smi_list[i] = str(num_open[tmp])
                num_unused.append(num_open[tmp])
                num_open.pop(tmp)
        if smi_list[-1] == "!":  # cover cases of incomplete esmi_pd
            smi_list.pop()  # remove the final '!'
        return ''.join(smi_list)
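    # Example (round trip): esmi2smi is the inverse of smi2esmi for complete SMILES,
    # although ring labels may be renumbered, e.g.
    #
    #   NGram.esmi2smi(NGram.smi2esmi('c1ccccc1')) == 'c1ccccc1'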
    def remove_table(self, max_order=None):
        """
        Remove high-order sub-tables from the NGram table.

        Parameters
        ----------
        max_order: int
            Maximum order to be kept in the table; the rest is removed.
            If None, the whole table is removed.
        """
        if max_order:
            tmp = self._train_order[1] - max_order
            if tmp < 1:
                warnings.warn('Nothing removed', RuntimeWarning)
            else:
                self._table = self._table[:-tmp]
                self._train_order = (self._train_order[0], max_order)
        else:
            self._table = None
            self._train_order = None
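    # Example (hypothetical): keep only the sub-tables up to order 5 of a trained model
    #
    #   ngram.remove_table(max_order=5)
    #
    # or drop the whole table before re-training:
    #
    #   ngram.remove_table()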
    def fit(self, smiles, *, train_order=(1, 10)):
        """
        Train the NGram table from a list of SMILES.

        Parameters
        ----------
        smiles: list[str]
            SMILES for training.
        train_order: tuple[int, int] or int
            Range of orders used when training the NGram table;
            when given an int, train_order = (1, int), and train_order[0] must be > 0.

        Returns
        -------
        self
        """

        def _fit_one(ext_smi):
            for iB in [0, 1]:
                # index for open/closed branch char. positions; remove the last row for '!'
                idx_B = ext_smi.iloc[:-1].index[(ext_smi['n_br'].iloc[:-1] > 0) == iB]
                list_R = ext_smi['n_ring'][idx_B].unique().tolist()
                if len(list_R) > 0:
                    # expand the list of DataFrames up to max. num-of-rings + 1
                    if len(self._table[0][iB]) < (max(list_R) + 1):
                        for ii in range(len(self._table)):
                            self._table[ii][iB].extend([
                                pd.DataFrame()
                                for i in range((max(list_R) + 1) - len(self._table[ii][iB]))
                            ])
                    for iR in list_R:
                        # index for num-of-open-ring char. positions
                        idx_R = idx_B[ext_smi['n_ring'][idx_B] == iR]
                        # shift one down for 'next character given substring'
                        tar_char = ext_smi['esmi'][idx_R + 1].tolist()
                        tar_substr = ext_smi['substr'][idx_R].tolist()
                        for iO in range(self._train_order[0] - 1, self._train_order[1]):
                            # index for chars with substring length not less than the order
                            idx_O = [x for x in range(len(tar_substr)) if len(tar_substr[x]) > iO]
                            for iC in idx_O:
                                if not tar_char[iC] in self._table[iO][iB][iR].columns.tolist():
                                    self._table[iO][iB][iR][tar_char[iC]] = 0
                                tmp_row = str(tar_substr[iC][-(iO + 1):])
                                if tmp_row not in self._table[iO][iB][iR].index.tolist():
                                    self._table[iO][iB][iR].loc[tmp_row] = 0
                                # somehow 'at' is not ok with mixed char and int column names
                                self._table[iO][iB][iR].loc[tmp_row, tar_char[iC]] += 1

        if self._table:
            raise RuntimeError('NGram table already exists. '
                               'If you want to re-train the table, '
                               'please use the `remove_table()` method first.')
        if isinstance(train_order, int):
            tmp_train_order = (1, train_order)
        elif isinstance(train_order, tuple):
            tmp_train_order = train_order
        elif isinstance(train_order, (list, np.ndarray, pd.Series)):
            tmp_train_order = (train_order[0], train_order[1])
        else:
            raise TypeError('please input a <tuple> of two <int> or a single <int> for train_order')
        if tmp_train_order[0] < 1:
            raise RuntimeError('Min train_order must be greater than 0')
        if tmp_train_order[1] < tmp_train_order[0]:
            raise RuntimeError('Max train_order must not be smaller than min train_order')
        self._train_order = tmp_train_order
        self._table = [[[], []] for _ in range(self._train_order[1])]
        self._fit_sample_order()
        self._fit_min_len()
        for smi in tqdm(smiles):
            try:
                _fit_one(self.smi2esmi(smi))
            except Exception as e:
                warnings.warn('NGram training failed for %s' % smi, RuntimeWarning)
                e = NGramTrainingError(e, smi)
                self.on_errors(e)
        return self
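    # Example (hypothetical; SMILES list and orders are illustrative only):
    #
    #   ngram = NGram(sample_order=(1, 5))
    #   ngram.fit(['CC(=O)Cl', 'c1ccccc1O', 'CCN(CC)CC'], train_order=5)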
    # get the probability vector for sampling the next character; return the character list
    # and the corresponding probabilities as a normalized numpy.array
    # may cause an error if an empty string list is fed into 'tmp_str'
    # Warning: maybe can reduce the input of iB and iR - directly input the reduced list of self._ngram_tab (?)
    # Warning: may need to update this function with bisection search for faster speed (?)
    # Warning: may need to add a worst case that no pattern is found at all?
    def get_prob(self, tmp_str, iB, iR):
        # right now we use the back-off method; an alternative is Kneser–Ney smoothing
        cand_char = []
        cand_prob = 1
        iB = int(iB)
        for iO in range(self.sample_order[1] - 1, self.sample_order[0] - 2, -1):
            # if (len(tmp_str) > iO) & (str(tmp_str[-(iO + 1):]) in self._table[iO][iB][iR].index.tolist()):
            if len(tmp_str) > iO and str(
                    tmp_str[-(iO + 1):]) in self._table[iO][iB][iR].index.tolist():
                cand_char = self._table[iO][iB][iR].columns.tolist()
                cand_prob = np.array(self._table[iO][iB][iR].loc[str(tmp_str[-(iO + 1):])])
                break
        if len(cand_char) == 0:
            warnings.warn('get_prob: %s not found in NGram, iB=%i, iR=%i' % (tmp_str, iB, iR),
                          RuntimeWarning)
            raise GetProbError(tmp_str, iB, iR)
        return cand_char, cand_prob / np.sum(cand_prob)
    # get the next character, return the probability value
    def sample_next_char(self, ext_smi):
        iB = ext_smi['n_br'].iloc[-1] > 0
        iR = ext_smi['n_ring'].iloc[-1]
        cand_char, cand_prob = self.get_prob(ext_smi['substr'].iloc[-1], iB, iR)
        # here we assume cand_char is not empty
        idx = np.random.choice(range(len(cand_char)), p=cand_prob)
        next_char = cand_char[idx]
        ext_smi = self.add_char(ext_smi, next_char)
        return ext_smi, cand_prob[idx]
    @classmethod
    def add_char(cls, ext_smi, next_char):
        new_pd_row = ext_smi.iloc[-1]
        new_pd_row.at['substr'] = new_pd_row['substr'] + [next_char]
        new_pd_row.at['esmi'] = next_char
        if next_char == '(':
            new_pd_row.at['n_br'] += 1
        elif next_char == ')':
            new_pd_row.at['n_br'] -= 1
            # assume '(' must exist before if the extended SMILES is valid! (will fail if violated)
            # idx = next((x for x in range(len(new_pd_row['substr'])-1,-1,-1) if new_pd_row['substr'][x] == '('), None)
            # find index of the last unclosed '('
            tmp_c = 1
            for x in range(len(new_pd_row['substr']) - 2, -1, -1):  # exclude the already added "next_char"
                if new_pd_row['substr'][x] == '(':
                    tmp_c -= 1
                elif new_pd_row['substr'][x] == ')':
                    tmp_c += 1
                if tmp_c == 0:
                    idx = x
                    break
            # assume no '()' and '((' patterns that are not valid/possible in SMILES
            new_pd_row.at['substr'] = new_pd_row['substr'][:(idx + 2)] + [')']
        elif next_char == '&':
            new_pd_row.at['n_ring'] += 1
        elif isinstance(next_char, int):
            new_pd_row.at['n_ring'] -= 1
        return ext_smi.append(new_pd_row, ignore_index=True)
    @classmethod
    def del_char(cls, ext_smi, n_char):
        if n_char > 0:
            return ext_smi[:-n_char]
        else:
            return ext_smi
    # need to make sure esmi_pd is a completed SMILES to use this function
    # todo: kekuleSmiles?
    @classmethod
    def reorder_esmi(cls, ext_smi):
        # convert back to SMILES first, then to an rdkit Mol
        mol = Chem.MolFromSmiles(cls.esmi2smi(ext_smi))
        idx = np.random.choice(range(mol.GetNumAtoms())).item()
        # currently assume kekuleSmiles=True, i.e., no small letters but with ':' for aromatic rings
        ext_smi = cls.smi2esmi(Chem.MolToSmiles(mol, rootedAtAtom=idx))
        return ext_smi
    def validator(self, ext_smi):
        # delete all trailing '(' or '&'
        for i in range(len(ext_smi)):
            if not ((ext_smi['esmi'].iloc[-1] == '(') or (ext_smi['esmi'].iloc[-1] == '&')):
                break
            ext_smi = self.del_char(ext_smi, 1)
        # delete or fill in ring closings
        flag_ring = ext_smi['n_ring'].iloc[-1] > 0
        for i in range(len(ext_smi)):  # at most double the length of the current SMILES
            if flag_ring and (np.random.random() < 0.7):  # 70% chance to add a new char.
                # add a character
                ext_smi, _ = self.sample_next_char(ext_smi)
                flag_ring = ext_smi['n_ring'].iloc[-1] > 0
            else:
                break
        if flag_ring:  # prepare for deletion (1st letter shall not be '&')
            tmp_idx = ext_smi.iloc[1:].index
            tmp_count = np.array(ext_smi['n_ring'].iloc[1:]) - np.array(ext_smi['n_ring'].iloc[:-1])
            num_open = tmp_idx[tmp_count == 1].values.tolist()
            num_open.reverse()
            num_close = tmp_idx[tmp_count == -1].values.tolist()
            idx_pop = []
            for i in num_close:
                idx_pop.append(ext_smi['esmi'][i])
            for ii, i in enumerate(idx_pop):
                ext_smi['esmi'][num_close[ii]] += sum([x < i for x in idx_pop[ii + 1:]]) - i
                num_open.pop(i)
            # delete all irrelevant rows and reconstruct esmi
            ext_smi = self.smi2esmi(
                self.esmi2smi(ext_smi.drop(ext_smi.index[num_open]).reset_index(drop=True)))
            ext_smi = ext_smi.iloc[:-1]  # remove the '!'
        # delete ':' that are not inside a ring
        # tmp_idx = esmi_pd.index[(esmi_pd['esmi'] == ':') & (esmi_pd['n_ring'] < 1)]
        # if len(tmp_idx) > 0:
        #     esmi_pd = smi2esmi(esmi2smi(esmi_pd.drop(tmp_idx).reset_index(drop=True)))
        #     esmi_pd = esmi_pd.iloc[:-1]  # remove the '!'
        # fill in branch closings (the last letter shall not be '(')
        for i in range(ext_smi['n_br'].iloc[-1]):
            ext_smi = self.add_char(ext_smi, ')')
        return ext_smi
    def proposal(self, smiles):
        """
        Propose new SMILES based on the given SMILES.
        Make sure you always check the train_order against the sample_order before using the proposal!

        Parameters
        ----------
        smiles: list of SMILES
            Given SMILES for modification.

        Returns
        -------
        new_smiles: list of SMILES
            The proposed SMILES from the given SMILES.
        """
        new_smis = []
        for i, smi in enumerate(smiles):
            ext_smi = self.smi2esmi(smi)
            try:
                new_ext_smi = self.modify(ext_smi)
                new_smi = self.esmi2smi(new_ext_smi)
                if Chem.MolFromSmiles(new_smi) is None:
                    warnings.warn('cannot convert %s to Mol' % new_smi, RuntimeWarning)
                    raise MolConvertError(new_smi)
                new_smis.append(new_smi)
            except ProposalError as e:
                e.old_smi = smi
                new_smi = self.on_errors(e)
                new_smis.append(new_smi)
            except Exception as e:
                raise e
        return new_smis
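    # Example (hypothetical; `ngram` must hold a trained table):
    #
    #   seeds = ['CC(=O)Cl', 'c1ccccc1O']
    #   new_smiles = ngram.proposal(seeds)   # one proposed SMILES per seed;
    #                                        # failed proposals fall back to the seed via on_errors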
    def _merge_table(self, ngram_tab, weight=1):
        """
        Merge with a given NGram table.

        Parameters
        ----------
        ngram_tab: NGram
            The table in the given NGram object will be merged into the table in self.
        weight: float
            A scalar used to scale the frequencies in the given NGram table.

        Returns
        -------
        None
            The table in self is modified in place.
        """
        self._train_order = (min(self._train_order[0], ngram_tab._train_order[0]),
                             max(self._train_order[1], ngram_tab._train_order[1]))
        self._fit_sample_order()
        self._fit_min_len()
        n_gram_tab1 = self._table  # do not use deepcopy here
        n_gram_tab2 = ngram_tab.ngram_table  # default deepcopy used
        w = weight
        ord1 = len(n_gram_tab1)
        ord2 = len(n_gram_tab2)
        Bc1 = len(n_gram_tab1[0][0])
        Bc2 = len(n_gram_tab2[0][0])
        Bo1 = len(n_gram_tab1[0][1])
        Bo2 = len(n_gram_tab2[0][1])
        # fix the number-of-rings mismatch first
        if Bc1 < Bc2:
            for ii in range(ord1):
                n_gram_tab1[ii][0].extend([pd.DataFrame() for _ in range(Bc2 - Bc1)])
        elif Bc1 > Bc2:
            for ii in range(ord2):
                n_gram_tab2[ii][0].extend([pd.DataFrame() for _ in range(Bc1 - Bc2)])
        if Bo1 < Bo2:
            for ii in range(ord1):
                n_gram_tab1[ii][1].extend([pd.DataFrame() for _ in range(Bo2 - Bo1)])
        elif Bo1 > Bo2:
            for ii in range(ord2):
                n_gram_tab2[ii][1].extend([pd.DataFrame() for _ in range(Bo1 - Bo2)])
        # fix order mismatch
        if ord2 > ord1:
            n_gram_tab1.extend(n_gram_tab2[ord1:])
        # combine the overlapping orders (tab2 weighted)
        for i in range(min(ord1, ord2)):
            for j in range(len(n_gram_tab1[i])):
                for k in range(len(n_gram_tab1[i][j])):
                    n_gram_tab1[i][j][k] = n_gram_tab1[i][j][k].add(w * n_gram_tab2[i][j][k],
                                                                    fill_value=0).fillna(0)
    def merge_table(self, *ngram_tab: 'NGram', weight=1, overwrite=True):
        """
        Merge with the given NGram table(s).

        Parameters
        ----------
        ngram_tab
            The table(s) in the given NGram object(s) will be merged into the table in self.
        weight: int/float or list/tuple/np.ndarray/pd.Series[int/float]
            A scalar/vector used to scale the frequencies in the NGram tables to be merged;
            a vector must have the same length as ngram_tab.
        overwrite: boolean
            Whether to overwrite the original table (self). Setting this to False is not
            recommended (it may cause memory issues).

        Returns
        -------
        tmp_n_gram: NGram
            The merged NGram tables.
        """
        if not np.all([isinstance(x, NGram) for x in ngram_tab]):
            raise TypeError('each element in the input must be <NGram>')
        if isinstance(weight, (int, float)):
            weight = np.repeat(weight, len(ngram_tab))
        elif isinstance(weight, (tuple, list, np.ndarray, pd.Series)):
            if not np.all([isinstance(x, (int, float)) for x in weight]):
                raise TypeError('each element in weight must be <int> or <float>')
        else:
            raise TypeError('weight must be <int> or <float> or a list of them')
        if overwrite:
            tmp_n_gram = self  # do not use deepcopy here
        else:
            tmp_n_gram = deepcopy(self)
        for i, tab in enumerate(ngram_tab):
            tmp_n_gram._merge_table(ngram_tab=tab, weight=weight[i])
        return tmp_n_gram
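    # Example (hypothetical): merge a second trained model into this one,
    # down-weighting its counts by 0.5; use overwrite=False to keep self untouched.
    #
    #   merged = ngram.merge_table(other_ngram, weight=0.5, overwrite=False)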
    def split_table(self, cut_order):
        """
        Split the NGram table into two.

        Parameters
        ----------
        cut_order: int
            Split the NGram table between cut_order and cut_order + 1.

        Returns
        -------
        n_gram1: NGram
        n_gram2: NGram
        """
        n_gram1 = deepcopy(self)
        n_gram1.remove_table(max_order=cut_order)
        n_gram1._fit_sample_order()
        n_gram1._fit_min_len()
        n_gram2 = deepcopy(self)
        for iB in [0, 1]:
            for ii in range(cut_order):
                n_gram2._table[ii][iB] = [
                    pd.DataFrame() for _ in range(len(n_gram2._table[ii][iB]))
                ]
        n_gram2._train_order = (cut_order + 1, self._train_order[1])
        n_gram2._fit_sample_order()
        n_gram2._fit_min_len()
        return n_gram1, n_gram2
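# End-to-end sketch (hypothetical; SMILES, parameters and the import path are illustrative only):
#
#   from xenonpy.inverse.iqspr import NGram   # typical import path (may vary by version)
#
#   ngram = NGram(sample_order=(1, 5), del_range=(1, 8), reorder_prob=0.5)
#   ngram.fit(training_smiles, train_order=5)      # training_smiles: list[str]
#   candidates = ngram.proposal(seed_smiles)       # seed_smiles: list[str]
#
#   low, high = ngram.split_table(cut_order=3)     # orders 1-3 vs. orders 4-5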