# Copyright (c) 2021. yoshida-lab. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sb
from typing import Sequence, Union
from xenonpy.datatools import Scaler
from sklearn.preprocessing import power_transform
from sklearn.base import BaseEstimator
from sklearn.preprocessing import minmax_scale
class DescriptorHeatmap(BaseEstimator):
    """
    Clustered heatmap of a descriptor matrix.

    ``fit`` rescales the descriptors and ``draw`` renders them with
    ``seaborn.clustermap``, optionally next to a line plot of a property
    given per sample.
    """

    def __init__(self,
                 save=None,
                 bc=False,
                 cmap='RdBu',
                 pivot_kws=None,
                 method='average',
                 metric='euclidean',
                 figsize=None,
                 row_cluster=False,
                 col_cluster=True,
                 row_linkage=None,
                 col_linkage=None,
                 row_colors=None,
                 col_colors=None,
                 mask=None,
                 **kwargs):
        """
        Parameters
        ----------
        save
            Keyword arguments passed to ``matplotlib.pyplot.savefig``.
            Nothing is saved when this is ``None`` or empty.
        bc
            If ``True``, a Yeo-Johnson transform is chained after the
            min-max scaling in ``fit``.
        cmap
            Colormap passed to ``seaborn.clustermap``.
        pivot_kws
            Passed to ``seaborn.clustermap``.
        method
            Linkage method passed to ``seaborn.clustermap``.
        metric
            Distance metric passed to ``seaborn.clustermap``.
        figsize
            Figure size passed to ``seaborn.clustermap``.
        row_cluster
            Whether to cluster rows; passed to ``seaborn.clustermap``.
        col_cluster
            Whether to cluster columns; passed to ``seaborn.clustermap``.
        row_linkage
            Passed to ``seaborn.clustermap``.
        col_linkage
            Passed to ``seaborn.clustermap``.
        row_colors
            Passed to ``seaborn.clustermap``.
        col_colors
            Passed to ``seaborn.clustermap``.
        mask
            Passed to ``seaborn.clustermap``.
        kwargs
            Any other keyword arguments passed to ``seaborn.clustermap``.
        """
        self.cmap = cmap
        self.save = save
        self.bc = bc
        self.pivot_kws = pivot_kws
        self.method = method
        self.metric = metric
        self.figsize = figsize
        self.col_cluster = col_cluster
        self.row_linkage = row_linkage
        self.row_cluster = row_cluster
        self.col_linkage = col_linkage
        self.row_colors = row_colors
        self.col_colors = col_colors
        self.mask = mask
        self.kwargs = kwargs
        # Scaled descriptors; populated by ``fit``.
        self.desc = None

    def fit(self, desc):
        """
        Scale the descriptors and keep the result for drawing.

        Parameters
        ----------
        desc
            Descriptor matrix handed to the scaler's ``fit_transform``.

        Returns
        -------
        self
        """
        scaler = Scaler().min_max()
        if self.bc:
            scaler = scaler.yeo_johnson()
        self.desc = scaler.fit_transform(desc)
        return self

    def draw(self,
             y: Union[Sequence, pd.Series, None] = None,
             name: Union[str, None] = None,
             *,
             return_sorted_idx: bool = False):
        """
        Draw figure.

        Parameters
        ----------
        y
            Properties values corresponding to samples.
        name
            Property name. If ``name`` is ``None`` and ``y`` is not a named
            ``pandas.Series``, no name will be drawn in the figure.
        return_sorted_idx
            If ``True``, return sorted column index of descriptor.
            Default ``False``

        Returns
        -------
        idx: np.array
            sorted column index if ``return_sorted_idx`` is ``True``

        Raises
        ------
        RuntimeError
            If ``fit`` has not been called yet.
        """
        if self.desc is None:
            raise RuntimeError('no descriptors to draw, please call `fit` first')

        # Forward every clustering option captured by ``__init__``; the
        # defaults are the same as seaborn's, so default behaviour is unchanged.
        heatmap_ax = sb.clustermap(self.desc,
                                   cmap=self.cmap,
                                   pivot_kws=self.pivot_kws,
                                   method=self.method,
                                   metric=self.metric,
                                   figsize=self.figsize,
                                   row_cluster=self.row_cluster,
                                   col_cluster=self.col_cluster,
                                   row_linkage=self.row_linkage,
                                   col_linkage=self.col_linkage,
                                   row_colors=self.row_colors,
                                   col_colors=self.col_colors,
                                   mask=self.mask,
                                   **self.kwargs)
        heatmap_ax.cax.set_visible(False)
        heatmap_ax.ax_heatmap.yaxis.set_ticks_position('left')
        heatmap_ax.ax_heatmap.yaxis.set_label_position('left')

        if y is None:
            heatmap_ax.ax_col_dendrogram.set_position((0.1, 0.8, 0.9, 0.1))
            heatmap_ax.ax_heatmap.set_position((0.1, 0.2, 0.9, 0.6))
        else:
            # Leave a strip on the right-hand side for the property plot.
            heatmap_ax.ax_col_dendrogram.set_position((0.1, 0.8, 0.83, 0.1))
            heatmap_ax.ax_heatmap.set_position((0.1, 0.2, 0.84, 0.6))
            prop_ax = plt.axes([0.95, 0.2, 0.05, 0.6])

            # draw prop
            if isinstance(y, pd.Series):
                x_ = y.values
                name_ = y.name
            else:
                x_ = np.asarray(y)
                name_ = None
            if name is not None:
                name_ = name

            # When rows were clustered the heatmap rows are reordered, so the
            # property values must follow the same order to stay aligned.
            row_dendrogram = getattr(heatmap_ax, 'dendrogram_row', None)
            if row_dendrogram is not None:
                x_ = np.asarray(x_)[row_dendrogram.reordered_ind]

            # First sample at the top, matching the heatmap's row order.
            y_ = np.arange(len(x_))[::-1]
            prop_ax.plot(x_, y_, lw=4)
            prop_ax.get_yaxis().set_visible(False)
            prop_ax.spines['top'].set_visible(False)
            prop_ax.spines['right'].set_visible(False)
            # An unnamed Series has ``name`` of None and names may be
            # non-string; fall back to an empty label / ``str()``.
            label = '' if name_ is None else str(name_)
            prop_ax.set_xlabel(label, fontsize='large')

        if self.save:
            plt.savefig(**self.save)

        if return_sorted_idx:
            try:
                return heatmap_ax.dendrogram_col.reordered_ind
            except AttributeError:
                # No column dendrogram (e.g. ``col_cluster=False``): columns
                # keep their original order.
                return np.arange(heatmap_ax.data2d.shape[1])