Source code for pyfemtet.opt.visualization.plotter.contour_creator

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from pyfemtet.logger import get_module_logger


__all__ = [
    'contour_creator',
]


def _is_supported_dtype(dtype: object) -> bool:
    """Return True if *dtype* is an integer or floating-point dtype.

    Uses pandas' dtype introspection instead of substring-matching the dtype
    name, so nullable ``Int64``/``Float64`` dtypes are accepted while e.g.
    ``interval[int64, right]`` (whose name merely contains "int"), bool,
    object and categorical dtypes are rejected.
    """
    return (
        pd.api.types.is_integer_dtype(dtype)
        or pd.api.types.is_float_dtype(dtype)
    )


def contour_creator(df: pd.DataFrame) -> go.Figure | str:
    """Build a pair-wise contour matrix of the last column of *df*.

    The last column of *df* is the target; every other column is an
    explanatory variable. The figure is a square grid of subplots, one row and
    one column per explanatory variable. The cell at row ``r`` / column ``c``
    plots explanatory column ``c`` on the x-axis against column ``r`` on the
    y-axis. Off-diagonal cells show a contour of the target overlaid with the
    sample points; diagonal cells show only the points.

    Args:
        df: Data whose last column is the target. Every other column must have
            an integer or floating-point dtype.

    Returns:
        The assembled figure, or a short message string when an explanatory
        column has an unsupported dtype (the problem is also logged).
    """
    logger = get_module_logger('opt.contour_creator')

    target_column = tuple(df.columns)[-1]
    explain_columns = [
        column for column in df.columns if column != target_column
    ]

    # Validate every explanatory column once, before any figure work; the
    # first unsupported column (in column order) is the one reported.
    for column in explain_columns:
        dtype = df.dtypes[column]
        if not _is_supported_dtype(dtype):
            logger.error(f'dtype is {dtype}. Not implemented.')
            return 'Not implemented: including categorical parameters.'

    subplots = make_subplots(
        rows=len(explain_columns),
        cols=len(explain_columns),
        shared_xaxes=True,
        shared_yaxes=True,
        row_titles=explain_columns,
        column_titles=explain_columns,
        start_cell='bottom-left',
    )

    z = df[target_column]
    # Only the very first scatter trace gets a legend entry; the rest share
    # its legendgroup so clicking that entry toggles all point layers.
    is_first = True
    for r, r_key in enumerate(explain_columns):
        y = df[r_key]
        for c, c_key in enumerate(explain_columns):
            x = df[c_key]

            scatter = go.Scatter(
                x=x,
                y=y,
                mode='markers',
                marker=go.scatter.Marker(
                    symbol='circle',
                    color='black',
                    size=5,
                    line=dict(
                        color='white',
                        width=1,
                    ),
                ),
                name='points (click to switch visibility)',
                legendgroup='points (click to switch visibility)',
                showlegend=is_first,
            )
            is_first = False

            # A diagonal cell plots a column against itself, so a contour of
            # the target would be meaningless there; draw only the points.
            if r != c:
                contour = go.Contour(
                    x=x,
                    y=y,
                    z=z,
                    connectgaps=True,
                    name=f'contour of\n{target_column}',
                    colorscale='Turbo',
                )
                subplots.add_trace(contour, row=r + 1, col=c + 1)

            # Added after the contour so the points are drawn on top of it.
            subplots.add_trace(scatter, row=r + 1, col=c + 1)

    subplots.update_layout(
        legend=dict(
            orientation='h',
            xanchor='center',
            x=0.5,
            yanchor='bottom',
            y=-0.2,
            bgcolor='rgba(0, 0, 0, 0.15)',
        ),
    )

    return subplots
if __name__ == '__main__':
    # Quick visual smoke test: four random explanatory variables and a random
    # target ('y', the last column) over a small number of samples.
    import numpy as np

    n_samples = 20
    demo_frame = pd.DataFrame({
        column_name: np.random.rand(n_samples)
        for column_name in ('x1', 'x2', 'x3', 'x4', 'y')
    })
    contour_creator(demo_frame).show()