Source code for pyfemtet.opt.interface._surrogate_model_interface.base_surrogate_interface
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
from pyfemtet.opt.history import *
from pyfemtet.opt.interface import AbstractFEMInterface
from pyfemtet._i18n import _
if TYPE_CHECKING:
from pyfemtet.opt.optimizer._base_optimizer import (
AbstractOptimizer,
GlobalOptimizationData,
OptimizationDataPerFEM,
DIRECTION
)
# Names exported by ``from ... import *``; only the base class is public.
__all__ = ['AbstractSurrogateModelInterfaceBase']
[docs]
class AbstractSurrogateModelInterfaceBase(AbstractFEMInterface):
    """Base class for FEM interfaces backed by a surrogate model.

    The training data is a ``History`` (``train_history``). Objectives whose
    names appear in that history are registered with the optimizer so that
    each one simply looks up its value by name in ``current_obj_values``.
    This base class only initializes ``current_obj_values`` to an empty
    dict; presumably concrete subclasses fill it with surrogate predictions
    (not shown here -- confirm against the subclasses).
    """

    # Objective values keyed by objective name; read by the objective
    # functions installed in ``contact_to_optimizer``.
    current_obj_values: dict[str, float]
    # Training data; its ``obj_names`` / ``prm_names`` define which
    # objective and parameter names this interface accepts.
    train_history: History

    def __init__(
        self,
        history_path: str | None = None,
        train_history: History | None = None,
        _output_directions: (
            Sequence[DIRECTION]
            | dict[str, DIRECTION]
            | dict[int, DIRECTION]
            | None
        ) = None
    ):
        """Set up the training history and the requested output directions.

        Args:
            history_path: Path of a history CSV. When given, a new
                ``History`` is loaded from it (``with_finalize=True``) and
                any ``train_history`` argument is ignored.
            train_history: An already-built ``History``, used when
                ``history_path`` is ``None``.
            _output_directions: Objectives to register with the optimizer
                together with their directions. Either a sequence holding
                one direction per training objective (in ``obj_names``
                order), or a dict keyed by objective name or by integer
                index into ``obj_names``. ``None`` adds no new objectives.

        Raises:
            ValueError: If neither ``history_path`` nor ``train_history`` is
                given.
        """
        self._output_directions = _output_directions

        # A CSV path takes precedence over an explicitly passed history.
        if history_path is not None:
            train_history = History()
            train_history.load_csv(history_path, with_finalize=True)

        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently store ``None`` here.
        if train_history is None:
            raise ValueError(_(
                en_message='Either history_path or train_history '
                           'must be specified.',
                jp_message='history_path または train_history の'
                           'いずれかを指定してください。',
            ))

        self.train_history = train_history
        self.current_obj_values = {}

    @property
    def object_pass_to_fun(self):
        """Return ``current_obj_values`` itself (not a copy).

        The objective functions installed by ``contact_to_optimizer`` take
        this dict as their only positional argument and index it by
        objective name; presumably the optimizer passes this property's
        value to them (confirm against the optimizer).
        """
        return self.current_obj_values

    def _normalize_output_directions(
        self,
        obj_names: Sequence[str],
    ) -> dict[str, DIRECTION]:
        """Convert ``_output_directions`` to ``{objective name: direction}``.

        Integer dict keys are resolved as indices into *obj_names*; a
        sequence must contain exactly one direction per name in
        *obj_names*. String names are NOT validated here.

        Raises:
            ValueError: If a sequence of the wrong length was given.
            IndexError: If an integer dict key is out of range.
        """
        directions = self._output_directions

        if directions is None:
            return {}

        if isinstance(directions, dict):
            normalized: dict[str, DIRECTION] = {}
            for key, direction in directions.items():
                # Integer keys are positions in the history's objective list.
                name = obj_names[key] if isinstance(key, int) else key
                normalized[name] = direction
            return normalized

        if len(directions) != len(obj_names):
            raise ValueError(_(
                en_message='The length of _output_directions passed as a list '
                           'must be same with that of the history\'s objective '
                           'names.',
                jp_message='_output_directions をリストで渡す場合は'
                           'その長さが history の目的関数数と一致して'
                           'いなければなりません。'
            ))
        return dict(zip(obj_names, directions))

    def contact_to_optimizer(
        self,
        opt: AbstractOptimizer,
        global_data: GlobalOptimizationData,
        ctx: OptimizationDataPerFEM,
    ):
        """Register this interface's objectives with *opt*.

        Steps:

        1. Normalize ``_output_directions`` to ``{name: direction}`` and
           validate every name against the training history *before* any
           state is mutated, so an invalid name leaves *global_data* and
           *ctx* untouched.
        2. Move objectives of *global_data* whose names are training
           objectives into *ctx*. Per the original design note, global
           objectives receive a Sequence rather than ``object_pass_to_fun``.
        3. Add (or overwrite) an objective in *ctx* for every normalized name.
        4. Rewire every training-named objective in *ctx* to read its value
           by name from the dict passed to it, with no extra args/kwargs.
        5. Re-initialize ``opt.objectives`` from *ctx* and *global_data*,
           because the optimizer is not re-synchronized after this call.

        Raises:
            ValueError: If ``_output_directions`` is a sequence whose length
                differs from the number of training objectives, or names an
                objective that is not in the training history.
        """
        train_obj_names = self.train_history.obj_names
        train_obj_name_set = set(train_obj_names)

        # 1. Normalize and validate up front (no mutation yet).
        name_and_directions = self._normalize_output_directions(train_obj_names)
        for obj_name in name_and_directions:
            if obj_name not in train_obj_name_set:
                raise ValueError(_(
                    en_message="The objective name {obj_name} is "
                               "not in the train_history's objectives: "
                               "{obj_names}",
                    jp_message="目的関数名 {obj_name} は "
                               "train_history の目的関数: "
                               "{obj_names} に含まれていません。",
                    obj_name=obj_name,
                    obj_names=train_obj_names,
                ))

        # 2. Global objectives would be handed a Sequence instead of
        #    ``object_pass_to_fun``, so move the training ones into ctx.
        moved = {
            obj_name: obj
            for obj_name, obj in global_data.objectives.items()
            if obj_name in train_obj_name_set
        }
        ctx.objectives.update(moved)
        for obj_name in moved:
            global_data.objectives.pop(obj_name)

        # 3. Add or overwrite every requested objective in ctx.
        def dummy(*args, **kwargs):
            # Placeholder only: every name added below was validated as a
            # training objective, so step 4 always replaces this ``fun``.
            assert False

        for obj_name, direction in name_and_directions.items():
            ctx.add_objective(
                obj_name, dummy, direction,
                supress_duplicated_name_check=True,
            )

        # 4. Rewire training objectives. Those that came from global_data
        #    may carry arbitrary settings, so overwrite everything needed.
        for obj_name, obj in ctx.objectives.items():
            if obj_name not in train_obj_name_set:
                continue
            obj.direction = name_and_directions.get(obj_name, obj.direction)
            obj.args = tuple()
            obj.kwargs = dict()
            # Bind the name as a default argument to avoid late binding.
            obj.fun = lambda obj_values, obj_name_=obj_name: obj_values[obj_name_]

        # 5. The optimizer is not synchronized after this call, so do it here.
        opt._initialize_objectives()
        opt.objectives.update(ctx.objectives)
        opt.objectives.update(global_data.objectives)

    def _check_using_fem(self, fun: callable) -> bool:
        """Return ``False`` for any *fun*.

        NOTE(review): presumably because no objective of a surrogate
        interface needs the FEM; confirm against ``AbstractFEMInterface``.
        """
        return False

    def _check_param_and_raise(self, prm_name) -> None:
        """Raise ``KeyError`` if *prm_name* is not a training parameter."""
        if prm_name not in self.train_history.prm_names:
            raise KeyError(f'Parameter name {prm_name} is not in '
                           f'training input {self.train_history.prm_names}.')