from pathlib import Path
import pickle
from neer_match.matching_model import DLMatchingModel, NSMatchingModel
from neer_match.similarity_map import SimilarityMap
import tensorflow as tf
from tensorflow.keras import layers as _layers
from graphviz import Digraph
from typing import Dict, List
import typing
import shutil
import sys
class Model:
    """
    A class for saving and loading matching models.

    Methods
    -------
    save(model, target_directory, name, save_architecture):
        Save the specified model to a target directory.
    load(model_directory):
        Load a model from a given directory.
    """

    @staticmethod
    def save(
        model: typing.Union["DLMatchingModel", "NSMatchingModel"],
        target_directory: Path,
        name: str,
        save_architecture: bool = False,
    ) -> None:
        """
        Save the model to a specified directory.

        Parameters
        ----------
        model : DLMatchingModel or NSMatchingModel
            The model to be saved.
        target_directory : Path
            The directory where the model should be saved.
        name : str
            Name of the model directory.
        save_architecture : bool, optional
            If True, additionally render a high-level architecture diagram
            (DLMatchingModel only) as ``architecture.png``. Best-effort:
            rendering failures only print a warning.

        Raises
        ------
        ValueError
            If ``model`` is neither a DLMatchingModel nor an NSMatchingModel.
        """
        target_directory = Path(target_directory) / name / "model"
        if target_directory.exists():
            # NOTE(review): interactive prompt — this blocks in non-interactive
            # runs (CI, headless scripts); confirm this is the intended UX.
            replace = input(
                f"Directory '{target_directory}' already exists. Replace the old model? (y/n): "
            ).strip().lower()
            if replace == "y":
                shutil.rmtree(target_directory)
                print(f"Old model at '{target_directory}' has been replaced.")
            elif replace == "n":
                print("Execution halted as per user request.")
                sys.exit(0)
            else:
                print("Invalid input. Please type 'y' or 'n'. Aborting operation.")
                return
        target_directory.mkdir(parents=True, exist_ok=True)

        # --- Build composite similarity info ---
        # Use the original instructions stored in the SimilarityMap.
        # Assumes model.similarity_map.instructions is a dict:
        # { field: [metric1, metric2, ...], ... }
        instructions = model.similarity_map.instructions
        fields = list(instructions.keys())
        association_sizes = model.similarity_map.association_sizes()  # aggregated sizes per field
        composite_similarity_info = {}
        for i, field in enumerate(fields):
            agg_size = association_sizes[i]
            metrics = instructions[field]  # list of metric names as originally provided
            composite_similarity_info[field] = {
                "metrics": metrics,
                "aggregated_size": agg_size,
                # Integer division: assumes each metric contributes equally
                # to the aggregated size — TODO confirm against SimilarityMap.
                "per_metric_size": agg_size // len(metrics),
            }

        # --- Save model initialization parameters from the record pair network ---
        model_params = {
            "initial_feature_width_scales": model.record_pair_network.initial_feature_width_scales,
            "feature_depths": model.record_pair_network.feature_depths,
            "initial_record_width_scale": model.record_pair_network.initial_record_width_scale,
            "record_depth": model.record_pair_network.record_depth,
        }

        # Persist one dictionary holding both similarity info and model
        # hyperparameters; Model.load() reads this file back.
        composite_save = {
            "similarity_info": composite_similarity_info,
            "model_params": model_params,
        }
        with open(target_directory / "model_info.pkl", "wb") as f:
            pickle.dump(composite_save, f)

        # --- Save weights; the file layout tells load() which class to rebuild ---
        if isinstance(model, DLMatchingModel):
            model.save_weights(target_directory / "model.weights.h5")
            if hasattr(model, "optimizer") and model.optimizer:
                optimizer_config = {
                    "class_name": model.optimizer.__class__.__name__,
                    "config": model.optimizer.get_config(),
                }
                with open(target_directory / "optimizer.pkl", "wb") as f:
                    pickle.dump(optimizer_config, f)
        elif isinstance(model, NSMatchingModel):
            model.record_pair_network.save_weights(
                target_directory / "record_pair_network.weights.h5"
            )
            if hasattr(model, "optimizer") and model.optimizer:
                optimizer_config = {
                    "class_name": model.optimizer.__class__.__name__,
                    "config": model.optimizer.get_config(),
                }
                with open(target_directory / "optimizer.pkl", "wb") as f:
                    pickle.dump(optimizer_config, f)
        else:
            raise ValueError("The model must be an instance of DLMatchingModel or NSMatchingModel")

        # --- Optionally save architecture diagram ---
        if save_architecture:
            try:
                # High-level field + record network diagram (DLMatchingModel only).
                # (Removed an unused `base_model` selection block that was dead code.)
                if isinstance(model, DLMatchingModel):
                    field_sizes, record_sizes = _extract_two_stage_arch(model)
                    _draw_highlevel_architecture(
                        field_sizes,
                        record_sizes,
                        out_path=target_directory / "architecture.png",
                    )
            except Exception as e:
                # Best-effort: never fail the save because of the diagram.
                print(f"Warning: could not save architecture diagram: {e}")
        print(f"Model successfully saved to {target_directory}")

    @staticmethod
    def load(model_directory: Path) -> typing.Union[DLMatchingModel, NSMatchingModel]:
        """
        Load a model from a specified directory.

        Parameters
        ----------
        model_directory : Path
            The directory containing the saved model.

        Returns
        -------
        DLMatchingModel or NSMatchingModel
            The loaded model.

        Raises
        ------
        FileNotFoundError
            If the model directory does not exist.
        ValueError
            If neither model type's weight file is found in the directory.
        """
        model_directory = Path(model_directory) / "model"
        if not model_directory.exists():
            raise FileNotFoundError(f"Model directory '{model_directory}' does not exist.")

        # --- Load composite model info (similarity info and model parameters) ---
        # SECURITY: pickle.load executes arbitrary code during unpickling;
        # only load model directories from trusted sources.
        with open(model_directory / "model_info.pkl", "rb") as f:
            composite_save = pickle.load(f)
        composite_similarity_info = composite_save["similarity_info"]
        model_params = composite_save["model_params"]

        # Reconstruct the original similarity map (a plain dict mapping each
        # field to its list of metric names) and wrap it in a SimilarityMap,
        # as expected by the model constructors.
        original_similarity_map = {
            field: info["metrics"] for field, info in composite_similarity_info.items()
        }
        similarity_map_instance = SimilarityMap(original_similarity_map)

        # Aggregated sizes in field order, used for input shapes below.
        fields = list(composite_similarity_info.keys())
        aggregated_sizes = [
            composite_similarity_info[field]["aggregated_size"] for field in fields
        ]

        if (model_directory / "model.weights.h5").exists():
            # DLMatchingModel: rebuild with the stored hyperparameters.
            model = DLMatchingModel(
                similarity_map=similarity_map_instance,
                initial_feature_width_scales=model_params["initial_feature_width_scales"],
                feature_depths=model_params["feature_depths"],
                initial_record_width_scale=model_params["initial_record_width_scale"],
                record_depth=model_params["record_depth"],
            )
            input_shapes = [tf.TensorShape([None, s]) for s in aggregated_sizes]
            model.build(input_shapes=input_shapes)
            # Forward pass with one dummy tensor per field, each of shape
            # (1, aggregated_size), so every sublayer is instantiated before
            # the weights are restored.
            dummy_tensors = [
                tf.zeros((1, composite_similarity_info[field]["aggregated_size"]))
                for field in fields
            ]
            _ = model(dummy_tensors)
            model.load_weights(model_directory / "model.weights.h5")
            if (model_directory / "optimizer.pkl").exists():
                with open(model_directory / "optimizer.pkl", "rb") as f:
                    optimizer_config = pickle.load(f)
                optimizer_class = getattr(tf.keras.optimizers, optimizer_config["class_name"])
                model.optimizer = optimizer_class.from_config(optimizer_config["config"])
        elif (model_directory / "record_pair_network.weights.h5").exists():
            model = NSMatchingModel(similarity_map_instance)
            model.compile()
            model.record_pair_network.load_weights(
                model_directory / "record_pair_network.weights.h5"
            )
            if (model_directory / "optimizer.pkl").exists():
                with open(model_directory / "optimizer.pkl", "rb") as f:
                    optimizer_config = pickle.load(f)
                optimizer_class = getattr(tf.keras.optimizers, optimizer_config["class_name"])
                model.optimizer = optimizer_class.from_config(optimizer_config["config"])
        else:
            raise ValueError("Invalid model directory: neither DLMatchingModel nor NSMatchingModel was detected.")
        return model
def _extract_two_stage_arch(dl_model: "DLMatchingModel"):
    """
    Extract architecture of field networks and record-pair network.

    Returns
    -------
    field_sizes : dict[str, list[int]]
        {field_name: [in_dim, h1, ..., out_dim]} for each field network.
        The key is a human-readable label like "current_name ~ alternative_name".
    record_sizes : list[int]
        [in_dim, h1, ..., out_dim] for the record-pair network.
    """

    def _label(assoc_key: str) -> str:
        # "a~b" -> "a ~ b"; keys without "~" are used verbatim.
        pieces = [part.strip() for part in assoc_key.split("~")]
        return pieces[0] if len(pieces) == 1 else f"{pieces[0]} ~ {pieces[1]}"

    # Work on the inner RecordPairNetwork if present.
    net = getattr(dl_model, "record_pair_network", dl_model)
    field_sizes: Dict[str, List[int]] = {}

    field_nets = getattr(net, "field_networks", None)
    sim_map = getattr(net, "similarity_map", None)

    # Preferred path: read widths straight off the instantiated field
    # networks, pairing them with association keys in similarity-map order.
    if field_nets is not None and sim_map is not None:
        for assoc_key, field_net in zip(sim_map.instructions.keys(), field_nets):
            widths: List[int] = []
            # Input dimension: FieldPairNetwork.size if available.
            input_width = getattr(field_net, "size", None)
            if input_width is not None:
                widths.append(int(input_width))
            # Hidden + output units from field_layers if present, otherwise
            # from the Keras Dense layers directly.
            dense_layers = getattr(field_net, "field_layers", None)
            if dense_layers is None:
                dense_layers = [
                    layer for layer in getattr(field_net, "layers", [])
                    if isinstance(layer, _layers.Dense)
                ]
            widths.extend(
                int(layer.units) for layer in dense_layers
                if isinstance(layer, _layers.Dense)
            )
            if widths:
                field_sizes[_label(assoc_key)] = widths

    # Fallback: approximate the architecture from the similarity map plus
    # the stored hyperparameters.
    if not field_sizes and sim_map is not None:
        scale = getattr(net, "initial_feature_width_scales", None)
        depth = getattr(net, "feature_depths", None)
        for assoc_key, input_width in zip(
            sim_map.instructions.keys(), sim_map.association_sizes()
        ):
            widths = [int(input_width)]
            if scale is not None and depth is not None:
                widths.extend([int(input_width * scale)] * int(depth))
            widths.append(1)  # scalar field prediction
            field_sizes[_label(assoc_key)] = widths

    # Record-pair (record) network widths.
    record_sizes: List[int] = []
    record_layers = getattr(net, "record_layers", None)
    if record_layers:
        # Input width comes from the first Dense layer's kernel shape.
        head = next(
            (layer for layer in record_layers if isinstance(layer, _layers.Dense)),
            None,
        )
        if head is not None:
            record_sizes.append(int(head.kernel.shape[0]))
        record_sizes.extend(
            int(layer.units) for layer in record_layers
            if isinstance(layer, _layers.Dense)
        )
    return field_sizes, record_sizes
def _draw_highlevel_architecture(field_sizes: Dict[str, List[int]],
                                 record_sizes: List[int],
                                 out_path: Path) -> None:
    """
    Draw a high-level two-stage architecture diagram using graphviz.

    Parameters
    ----------
    field_sizes : dict[str, list[int]]
        {label: [in_dim, h1, ..., out_dim]} per field network.
    record_sizes : list[int]
        [in_dim, h1, ..., out_dim] of the record network (may be empty).
    out_path : Path
        Target PNG path; graphviz itself appends the ".png" suffix.
    """
    g = Digraph("DLMatchingModel", format="png")
    g.attr(
        rankdir="TB",
        fontsize="12",
        labelloc="t",
        label="DLMatchingModel",
        fontname="Helvetica-Bold",
    )
    # ----------------------------
    # Field Networks cluster
    # ----------------------------
    field_pred_nodes = []
    with g.subgraph(name="cluster_fields") as c:
        c.attr(
            label="Field Networks",
            style="rounded",
            color="black",
            labelloc="t",
            fontsize="12",
            fontname="Helvetica",
        )
        for i, (field, sizes) in enumerate(field_sizes.items()):
            # FIX: labels previously mixed literal newlines with graphviz's
            # "\n" escape; use the escape (backslash-n in the DOT source)
            # consistently for line breaks in every node label.
            net_label = f"Field Network\\n({field})\\n" + " → ".join(str(s) for s in sizes)
            net_node = f"field_{i}_net"
            c.node(
                net_node,
                label=net_label,
                shape="box",
                style="rounded,filled",
                fillcolor="#f4cccc",  # light pink
                fontname="Helvetica",
                fontsize="10",
            )
            # Field prediction node (yellow)
            pred_node = f"field_{i}_pred"
            c.node(
                pred_node,
                label="Field\\nPrediction",
                shape="box",
                style="rounded,filled",
                fillcolor="#fff2cc",  # light yellow
                fontname="Helvetica",
                fontsize="10",
            )
            # Edge: field network -> field prediction
            c.edge(net_node, pred_node)
            field_pred_nodes.append(pred_node)
    # ----------------------------
    # Record Network
    # ----------------------------
    if record_sizes:
        rec_label = "Record Network\\n" + " → ".join(str(s) for s in record_sizes)
    else:
        rec_label = "Record Network"
    g.node(
        "record_net",
        rec_label,
        shape="box",
        style="rounded,filled",
        fillcolor="#f4cccc",
        fontname="Helvetica",
        fontsize="11",
    )
    # ----------------------------
    # Record Prediction
    # ----------------------------
    g.node(
        "record_pred",
        "Record\\nPrediction",
        shape="box",
        style="rounded,filled",
        fillcolor="#fff2cc",
        fontname="Helvetica",
        fontsize="11",
    )
    # ----------------------------
    # Arrows from all field predictions to record net
    # ----------------------------
    for pn in field_pred_nodes:
        g.edge(pn, "record_net")
    g.edge("record_net", "record_pred")
    # ----------------------------
    # Render (strip the suffix: Graphviz appends .png itself)
    # ----------------------------
    out_stem = out_path.with_suffix("")
    g.render(filename=str(out_stem), format="png", cleanup=True)
class EpochEndSaver(tf.keras.callbacks.Callback):
    """
    Keras callback that persists the model at the end of every epoch by
    delegating to the ``Model.save(...)`` static method.
    """

    def __init__(self, base_dir: Path, model_name: str):
        """
        Parameters
        ----------
        base_dir : Path
            The root directory under which the model subdirectories will be
            created, e.g. ``Path(__file__).resolve().parent / MODEL_NAME``.
        model_name : str
            A short identifier for the model. Each epoch's directory will be
            ``base_dir / model_name / "epoch_<NN>"``.
        """
        super().__init__()
        self.base_dir = Path(base_dir)
        self.model_name = model_name

    def on_epoch_end(self, epoch: int, logs=None):
        """
        Persist weights and optimizer state via ``Model.save`` at epoch end.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index supplied by Keras; directories are named
            from the one-based index ("epoch_01", "epoch_02", ...).
        logs : dict, optional
            Metric values passed by Keras; unused here.
        """
        # Parent directory for all per-epoch checkpoints; Model.save(...)
        # creates the exact "model" subdirectory underneath it.
        checkpoint_root = self.base_dir / self.model_name / "checkpoints"
        checkpoint_root.mkdir(parents=True, exist_ok=True)
        # Keras exposes the trained model as `self.model`; only the two
        # supported matching-model classes can be saved.
        if not isinstance(self.model, (DLMatchingModel, NSMatchingModel)):
            raise ValueError(
                f"`EpochEndSaver` expected DLMatchingModel or NSMatchingModel, got {type(self.model)}"
            )
        Model.save(
            model=self.model,
            target_directory=checkpoint_root,
            name=f"epoch_{epoch + 1:02d}",
        )