Source code for edafm.visualization
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
def _calc_plot_dim(n, f=0.3):
    '''
    Choose a subplot grid (rows, cols) with room for n panels.

    The row count is int(sqrt(n) - f), clamped to at least 1. The column count is the
    smallest value for which rows*cols >= n (and at least 1).

    Arguments:
        n: int. Number of panels that must fit in the grid.
        f: float. Offset subtracted from sqrt(n) before truncating to the row count;
            larger values give fewer rows and more columns.

    Returns:
        tuple (rows, cols).
    '''
    n_rows = int(np.sqrt(n) - f)
    if n_rows < 1:
        n_rows = 1
    n_cols = 1
    while n_rows * n_cols < n:
        n_cols += 1
    return n_rows, n_cols
def plot_input(X, constant_range=False, cmap='afmhot'):
    '''
    Plot a single stack of AFM images, one subplot (with its own colorbar) per slice
    along the last axis.

    Arguments:
        X: np.ndarray of shape (x, y, z). AFM image stack to plot; slice k is X[:, :, k].
        constant_range: Boolean. If True, every slice uses the global min/max of X as its
            color range; otherwise each slice is scaled independently.
        cmap: str or matplotlib colormap. Colormap to use for plotting.

    Returns:
        matplotlib.pyplot.figure. Figure on which the image was plotted.
    '''
    n_slices = X.shape[-1]
    n_rows, n_cols = _calc_plot_dim(n_slices)
    fig = plt.figure(figsize=(3.2 * n_cols, 2.5 * n_rows))

    # Global value range of the whole stack; only applied when constant_range is set.
    lo, hi = X.min(), X.max()
    range_kwargs = {'vmin': lo, 'vmax': hi} if constant_range else {}

    for slice_idx in range(n_slices):
        ax = fig.add_subplot(n_rows, n_cols, slice_idx + 1)
        # Transpose so the first array axis runs horizontally, with the origin at the bottom-left.
        image = ax.imshow(X[:, :, slice_idx].T, cmap=cmap, origin="lower", **range_kwargs)
        fig.colorbar(image, ax=ax)

    fig.tight_layout()
    return fig
def make_input_plots(Xs, outdir='./predictions/', start_ind=0, constant_range=False, cmap='afmhot', verbose=1):
    '''
    Plot multiple AFM image stacks to files 0_input.png, 1_input.png, ... etc.

    When Xs holds more than one input, the 1-based input number is appended to the file
    name, e.g. 0_input1.png, 0_input2.png, ...

    Arguments:
        Xs: list of np.ndarray of shape (batch, x, y, z). Input AFM images to plot. The number
            of samples is taken from Xs[0]. An empty list plots nothing.
        outdir: str. Directory where images are saved. Created if it does not exist. An empty
            string means the current directory.
        start_ind: int. Save index increments by one for each sample. The first index is start_ind.
        constant_range: Boolean. Whether the different slices should use the same value range or not.
        cmap: str or matplotlib colormap. Colormap to use for plotting.
        verbose: int 0 or 1. Whether to print output information.
    '''
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    if outdir:
        os.makedirs(outdir, exist_ok=True)
    if len(Xs) == 0:
        return

    multiple_inputs = len(Xs) > 1
    for sample_idx in range(Xs[0].shape[0]):
        img_ind = start_ind + sample_idx
        for input_idx, X in enumerate(Xs):
            fig = plot_input(X[sample_idx], constant_range, cmap=cmap)
            save_name = f'{img_ind}_input'
            if multiple_inputs:
                save_name += str(input_idx + 1)
            save_name = os.path.join(outdir, save_name) + '.png'
            # Save and close the figure we created, rather than pyplot's implicit current figure.
            fig.savefig(save_name)
            plt.close(fig)
            if verbose > 0:
                print(f'Input image saved to {save_name}')
def _prediction_titles(descriptor, loss):
    '''Return (prediction_title, reference_title); either is '' when there is nothing to show.'''
    pred_lines = []
    ref_title = ''
    if descriptor is not None:
        pred_lines.append(f'{descriptor} Prediction')
        ref_title = f'{descriptor} Reference'
    if loss is not None:
        pred_lines.append(f'MSE = {loss:.2E}')
    # Joining lines avoids a leading newline when only the loss is available.
    return '\n'.join(pred_lines), ref_title


def _prediction_value_range(arrays, symmetric=False):
    '''Return (vmin, vmax) over all arrays, optionally made symmetric around zero.'''
    # Reduce each array separately so maps need not have matching shapes (no concatenation).
    vmax = np.max([a.max() for a in arrays])
    vmin = np.min([a.min() for a in arrays])
    if symmetric:
        vmax = max(abs(vmax), abs(vmin))
        vmin = -vmax
    return vmin, vmax


def _add_column_colorbar(fig, column_axes, mappable):
    '''Narrow each axis of a subplot column to 80% width and attach one colorbar to its right.'''
    for ax in column_axes:
        pos = ax.get_position()
        ax.set_position([pos.x0, pos.y0, 0.8 * (pos.x1 - pos.x0), pos.y1 - pos.y0])
    top = column_axes[0].get_position()
    bottom = column_axes[-1].get_position()
    width = top.x1 - top.x0
    # The colorbar spans the full height of the column, from the bottom axis to the top axis.
    cbar_pos = [top.x1 + 0.1 * width, bottom.y0, 0.08 * width, top.y1 - bottom.y0]
    fig.colorbar(mappable, cax=fig.add_axes(cbar_pos))


def make_prediction_plots(preds=None, true=None, losses=None, descriptors=None, outdir='./predictions/', start_ind=0, verbose=1):
    '''
    Plot predictions/references for image descriptors.

    Saves one figure per sample, named {index}_pred.png. Each descriptor gets one column:
    the prediction is on the top row and the reference on the bottom row (or a single row
    if only one of them is given). Both panels share one colorbar.

    Arguments:
        preds: list of np.ndarray of shape (batch_size, x_dim, y_dim). Predicted maps.
            Each list element corresponds to one descriptor.
        true: list of np.ndarray of shape (batch_size, x_dim, y_dim). Reference maps.
            Each list element corresponds to one descriptor.
        losses: np.ndarray (or nested list) of shape (len(preds), batch_size). Losses for each
            prediction, shown in the prediction title.
        descriptors: list of str. Names of descriptors. The name "ES" causes the coolwarm colormap
            and a value range symmetric around zero to be used.
        outdir: str. Directory where images are saved. Created if it does not exist. An empty
            string means the current directory.
        start_ind: int. Starting index for saved images.
        verbose: int 0 or 1. Whether to print output information.

    Raises:
        ValueError: if both preds and true are None, if preds and true have different numbers
            of descriptors, or if descriptors does not match the number of descriptors.
    '''
    if preds is None and true is None:
        raise ValueError('preds and true cannot both be None.')
    if preds is not None and true is not None and len(preds) != len(true):
        raise ValueError(
            f'preds and true must have the same number of descriptors, got {len(preds)} and {len(true)}.'
        )
    rows = (preds is not None) + (true is not None)
    first = preds if preds is not None else true
    cols = len(first)
    if descriptors is not None and len(descriptors) != cols:
        raise ValueError(f'Expected {cols} descriptor names, got {len(descriptors)}.')
    if losses is not None:
        # Allows nested lists as well as arrays; losses[i, j] indexing needs a 2D array.
        losses = np.asarray(losses)

    if outdir:
        os.makedirs(outdir, exist_ok=True)
    if cols == 0:
        return

    batch_size = len(first[0])
    for j in range(batch_size):
        # squeeze=False always yields a 2D (rows, cols) array of axes.
        fig, axes = plt.subplots(rows, cols, squeeze=False)
        fig.set_size_inches(6 * cols, 5 * rows)

        for i in range(cols):
            descriptor = descriptors[i] if descriptors is not None else None
            loss = losses[i, j] if losses is not None else None
            pred_title, ref_title = _prediction_titles(descriptor, loss)

            # Panels in row order: prediction first (if given), then reference.
            panels = []
            if preds is not None:
                panels.append((preds[i][j], pred_title))
            if true is not None:
                panels.append((true[i][j], ref_title))

            is_es = descriptor == 'ES'
            vmin, vmax = _prediction_value_range([data for data, _ in panels], symmetric=is_es)
            cmap = cm.coolwarm if is_es else cm.viridis

            images = []
            for ax, (data, title) in zip(axes[:, i], panels):
                images.append(ax.imshow(data.T, vmax=vmax, vmin=vmin, cmap=cmap, origin='lower'))
                if title:
                    ax.set_title(title)

            _add_column_colorbar(fig, axes[:, i], images[0])

        save_name = os.path.join(outdir, f'{start_ind + j}_pred.png')
        fig.savefig(save_name, bbox_inches="tight")
        plt.close(fig)
        if verbose > 0:
            print(f'Prediction saved to {save_name}')