Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/multimm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,16 @@ def clean_fields(cls, data: Any) -> Any:
COB_EA: float = Field(default=1.0, description="Energy strength for A compartment.")
COB_EB: float = Field(default=2.0, description="Energy strength for B compartment.")
SCB_USE_SUBCOMPARTMENT_BLOCKS: Boolean = Field(default=False, description="Use Subcompartment Blocks.")
SCB_DISTANCE: Optional[OpenMMQuantity] = Field(default=None, description="Block copolymer equilibrium distance for chromosomal blocks.")
SCB_DISTANCE: Optional[OpenMMQuantity] = Field(
default=None, description="Block copolymer equilibrium distance for chromosomal blocks."
)
SCB_EA1: float = Field(default=1.0, description="Energy strength for A1 compartment.")
SCB_EA2: float = Field(default=1.33, description="Energy strength for A2 compartment.")
SCB_EB1: float = Field(default=1.66, description="Energy strength for B1 compartment.")
SCB_EB2: float = Field(default=2.0, description="Energy strength for B2 compartment.")
IBL_USE_B_LAMINA_INTERACTION: Boolean = Field(default=False, description="Interactions of B compartment with lamina.")
IBL_USE_B_LAMINA_INTERACTION: Boolean = Field(
default=False, description="Interactions of B compartment with lamina."
)
IBL_SCALE: float = Field(default=400.0, description="Scaling factor for B comoartment interaction with lamina.")
CF_USE_CENTRAL_FORCE: Boolean = Field(default=False, description="Attraction of smaller chromosomes.")
CF_STRENGTH: float = Field(default=20.0, description="Strength of Interaction")
Expand Down Expand Up @@ -298,7 +302,8 @@ def clean_fields(cls, data: Any) -> Any:
"Options: polynomial (default, handcrafted potential), "
"gaussian (soft collapse kernel), "
"saturating (soft-core bounded attraction)."
))
),
)

CENTRAL_FORCE_TYPE: str = Field(
default="harmonic",
Expand All @@ -309,4 +314,5 @@ def clean_fields(cls, data: Any) -> Any:
"harmonic (default, quadratic confinement around R1), "
"gaussian (soft nucleolar enrichment field), "
"logistic (soft-core radial partitioning with smooth boundary)."
))
),
)
22 changes: 9 additions & 13 deletions src/multimm/logger.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import logging
import sys


class _ColorFormatter(logging.Formatter):
"""Pretty colored formatter for terminal logs."""

COLORS = {
"DEBUG": "\033[1;34m", # blue
"INFO": "\033[1;32m", # green
"DEBUG": "\033[1;34m", # blue
"INFO": "\033[1;32m", # green
"WARNING": "\033[1;33m", # yellow
"ERROR": "\033[1;31m", # red
"CRITICAL": "\033[1;41m", # red background
"ERROR": "\033[1;31m", # red
"CRITICAL": "\033[1;41m", # red background
}

RESET = "\033[0m"
BOLD = "\033[1m"

def format(self, record):
level_color = self.COLORS.get(record.levelname, "")

record.levelname = f"{level_color}{record.levelname:<8}{self.RESET}"
record.name = f"\033[1;35m{record.name}{self.RESET}"
record.msg = str(record.msg)
Expand All @@ -26,10 +27,7 @@ def format(self, record):


def setup_logger(level=logging.INFO, debug=False):
"""
Clean, colored logger for simulation pipelines.
"""

"""Clean, colored logger for simulation pipelines."""
root = logging.getLogger()

# Avoid duplicate handlers
Expand All @@ -39,13 +37,11 @@ def setup_logger(level=logging.INFO, debug=False):
handler = logging.StreamHandler(sys.stdout)

formatter = _ColorFormatter(
"\033[1;36m[%(asctime)s]\033[0m "
"%(levelname)s "
"%(name)s: %(message)s",
"\033[1;36m[%(asctime)s]\033[0m " "%(levelname)s " "%(name)s: %(message)s",
datefmt="%H:%M:%S",
)

handler.setFormatter(formatter)

root.setLevel(logging.DEBUG if debug else level)
root.addHandler(handler)
root.addHandler(handler)
97 changes: 26 additions & 71 deletions src/multimm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

from .initial_structure_tools import build_init_mmcif, write_cmm, write_mmcif_chrom
from .nucleosome_interpolation import NucleosomeInterpolation
from .utils import *
from .plots import *

from .utils import *

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -413,10 +412,7 @@ def add_chromosomal_blocks(self):

logger.info("Using polynomial chromosomal self-attraction model")

self.chrom_block_force.setEnergyFunction(
"E*(k_C*r^4 - r^3 + r^2); "
"E = dE*delta(chrom1-chrom2)"
)
self.chrom_block_force.setEnergyFunction("E*(k_C*r^4 - r^3 + r^2); " "E = dE*delta(chrom1-chrom2)")

logger.info(f"k_C={self.args.CHB_KC}, dE={self.args.CHB_DE}")

Expand All @@ -425,10 +421,7 @@ def add_chromosomal_blocks(self):

logger.info("Using Gaussian chromosomal collapse kernel")

self.chrom_block_force.setEnergyFunction(
"-E * exp(-k_C*r^2); "
"E = dE*delta(chrom1-chrom2)"
)
self.chrom_block_force.setEnergyFunction("-E * exp(-k_C*r^2); " "E = dE*delta(chrom1-chrom2)")

logger.info(f"k_C={self.args.CHB_KC}, dE={self.args.CHB_DE}")

Expand All @@ -437,10 +430,7 @@ def add_chromosomal_blocks(self):

logger.info("Using saturating chromosomal interaction model")

self.chrom_block_force.setEnergyFunction(
"-E / (1 + k_C*r^2); "
"E = dE*delta(chrom1-chrom2)"
)
self.chrom_block_force.setEnergyFunction("-E / (1 + k_C*r^2); " "E = dE*delta(chrom1-chrom2)")

logger.info(f"k_C={self.args.CHB_KC}, dE={self.args.CHB_DE}")

Expand Down Expand Up @@ -497,12 +487,10 @@ def add_Blamina_interaction(self):

# 1. DEFAULT: sinusoidal shell (original)
if mode == "sin":

logger.info("Using sinusoidal lamina shell model")

self.Blamina_force.setEnergyFunction(
"B*(sin(pi*(r-R1)/(R2-R1))^8 - 1)*(delta(s+1)+delta(s+2)); " + r_expr
)
self.Blamina_force.setEnergyFunction("B*(sin(pi*(r-R1)/(R2-R1))^8 - 1)*(delta(s+1)+delta(s+2)); " + r_expr)
self.Blamina_force.addGlobalParameter("pi", np.pi)
logger.info(f"Shell radii: R1={self.radius1}, R2={self.radius2}")

Expand All @@ -511,8 +499,7 @@ def add_Blamina_interaction(self):

logger.info("Using Gaussian lamina shell model (two-layer attraction)")
self.Blamina_force.setEnergyFunction(
"-B*(exp(-(r-R1)^2/(2*sigma^2)) + exp(-(r-R2)^2/(2*sigma^2)))"
"*(delta(s+1)+delta(s+2)); " + r_expr
"-B*(exp(-(r-R1)^2/(2*sigma^2)) + exp(-(r-R2)^2/(2*sigma^2)))" "*(delta(s+1)+delta(s+2)); " + r_expr
)
sigma = 0.1 * (self.radius2 - self.radius1)
self.Blamina_force.addGlobalParameter("sigma", sigma)
Expand All @@ -521,9 +508,7 @@ def add_Blamina_interaction(self):
# 3. HARMONIC SHELL (pull to mid-shell)
elif mode == "harmonic_shell":
logger.info("Using harmonic lamina shell model (mid-shell attraction)")
self.Blamina_force.setEnergyFunction(
"B*(r - r0)^2*(delta(s+1)+delta(s+2)); " + r_expr
)
self.Blamina_force.setEnergyFunction("B*(r - r0)^2*(delta(s+1)+delta(s+2)); " + r_expr)
r0 = 0.5 * (self.radius1 + self.radius2)
self.Blamina_force.addGlobalParameter("r0", r0)
logger.info(f"r0 (mid-shell) = {r0}")
Expand All @@ -532,8 +517,7 @@ def add_Blamina_interaction(self):
elif mode == "logistic_shell":
logger.info("Using logistic lamina shell model (soft boundaries)")
self.Blamina_force.setEnergyFunction(
"-B*(1/(1+exp((r-R2)/lambda)) + 1/(1+exp(-(r-R1)/lambda)))"
"*(delta(s+1)+delta(s+2)); " + r_expr
"-B*(1/(1+exp((r-R2)/lambda)) + 1/(1+exp(-(r-R1)/lambda)))" "*(delta(s+1)+delta(s+2)); " + r_expr
)
lam = 0.05 * (self.radius2 - self.radius1)
self.Blamina_force.addGlobalParameter("lambda", lam)
Expand All @@ -550,10 +534,7 @@ def add_Blamina_interaction(self):
self.system.addForce(self.Blamina_force)

def add_central_force(self):
"""
Central nucleolar attraction with chromosome-size bias.
"""

"""Central nucleolar attraction with chromosome-size bias."""
mode = getattr(self.args, "CENTRAL_FORCE_TYPE", "harmonic")

self.central_force = mm.CustomExternalForce("0")
Expand Down Expand Up @@ -581,9 +562,7 @@ def add_central_force(self):
if mode == "harmonic":
logger.info("Using harmonic central attraction")

self.central_force.setEnergyFunction(
f"G*chrom_s*({r}-R1)*({r}-R1)"
)
self.central_force.setEnergyFunction(f"G*chrom_s*({r}-R1)*({r}-R1)")

# ============================================================
# 2. GAUSSIAN CENTER
Expand All @@ -594,9 +573,7 @@ def add_central_force(self):
sigma = 0.5 * self.radius1
self.central_force.addGlobalParameter("sigma", sigma)

self.central_force.setEnergyFunction(
f"-G*chrom_s*exp(-({r}*{r})/(2*sigma*sigma))"
)
self.central_force.setEnergyFunction(f"-G*chrom_s*exp(-({r}*{r})/(2*sigma*sigma))")

# ============================================================
# 3. LOGISTIC CORE
Expand All @@ -607,9 +584,7 @@ def add_central_force(self):
lam = 0.2 * self.radius1
self.central_force.addGlobalParameter("lambda", lam)

self.central_force.setEnergyFunction(
f"-G*chrom_s*(1/(1+exp(({r}-R1)/lambda)))"
)
self.central_force.setEnergyFunction(f"-G*chrom_s*(1/(1+exp(({r}-R1)/lambda)))")

else:
raise ValueError(f"Unknown CENTRAL_FORCE_TYPE: {mode}")
Expand All @@ -636,15 +611,13 @@ def add_harmonic_bonds(self):
self.system.addForce(self.bond_force)

def add_loops(self):
"""
Loop constraints using stable polymer bond models.
"""Loop constraints using stable polymer bond models.

Supported modes:
- harmonic (default)
- fene_safe (bounded FENE-like)
- gaussian_tether (fully smooth bounded well)
"""

mode = getattr(self.args, "LE_LOOP_FORCE_TYPE", "harmonic")

# 1. HARMONIC (unchanged baseline)
Expand All @@ -661,9 +634,7 @@ def add_loops(self):
# 2. SAFE FENE-LIKE (bounded, no singularity)
elif mode == "fene_soft":

self.loop_force = mm.CustomBondForce(
"k * (r - r0)^2 / (1 + alpha * (r - r0)^2)"
)
self.loop_force = mm.CustomBondForce("k * (r - r0)^2 / (1 + alpha * (r - r0)^2)")

self.loop_force.addPerBondParameter("r0")
self.loop_force.addPerBondParameter("k")
Expand All @@ -675,16 +646,14 @@ def add_loops(self):
r0 = self.args.LE_HARMONIC_BOND_R0 if self.args.LE_FIXED_DISTANCES else self.ds[i]

k = self.args.LE_HARMONIC_BOND_K
alpha = 1.0 / (r0 ** 2)
alpha = 1.0 / (r0**2)

self.loop_force.addBond(m, n, [r0, k, alpha])

# 3. GAUSSIAN TETHER (fully smooth bounded interaction)
elif mode == "gaussian_tether":

self.loop_force = mm.CustomBondForce(
"k * (1 - exp(-(r - r0)^2 / sigma^2))"
)
self.loop_force = mm.CustomBondForce("k * (1 - exp(-(r - r0)^2 / sigma^2))")

self.loop_force.addPerBondParameter("r0")
self.loop_force.addPerBondParameter("k")
Expand Down Expand Up @@ -724,9 +693,7 @@ def initialize_simulation(self):
logger.info("Creating initial structure...")

structure_type = (
"compartments"
if (self.Cs is not None and len(np.unique(self.Cs)) <= 3)
else "subcompartments"
"compartments" if (self.Cs is not None and len(np.unique(self.Cs)) <= 3) else "subcompartments"
)
logger.info(f"Detected structure type: {structure_type}")

Expand Down Expand Up @@ -811,7 +778,6 @@ def initialize_simulation(self):

def add_forcefield(self):
"""Here we define the forcefield of MultiMM."""

logger.info("Importing forcefield...")

if self.args.EV_USE_EXCLUDED_VOLUME:
Expand Down Expand Up @@ -929,11 +895,7 @@ def run_md(self):

self.simulation.step(self.args.SIM_SAMPLING_STEP)

state = self.simulation.context.getState(
getPositions=True,
getEnergy=True,
getVelocities=True
)
state = self.simulation.context.getState(getPositions=True, getEnergy=True, getVelocities=True)

# STEP (always safe)
step = state.getStepCount()
Expand Down Expand Up @@ -986,10 +948,7 @@ def run_md(self):
self.state.getPositions(),
open(self.save_path + "model/MultiMM_afterMD.cif", "w"),
)
plot_md_thermo(
self.md_history,
self.save_path
)
plot_md_thermo(self.md_history, self.save_path)
logger.info(
f"Everything is done! Simulation finished succesfully!\nMD finished in {elapsed//3600:.0f} hours, {elapsed%3600//60:.0f} minutes and {elapsed%60:.0f} seconds. ---\n"
)
Expand Down Expand Up @@ -1079,7 +1038,8 @@ def make_plots(self):
is_comp = self.Cs is not None and len(self.Cs) > 0

def _viz_and_heat(cif_path, out_name, colors=None):
"""Unified structure + heatmap pipeline (single source of truth)."""
"""Unified structure + heatmap pipeline (single source of
truth)."""
V = get_coordinates_cif(cif_path)

# 3D structure
Expand All @@ -1092,14 +1052,8 @@ def _viz_and_heat(cif_path, out_name, colors=None):
)

# heatmap (always)
if self.args.N_BEADS<50000:
get_heatmap(
cif_path,
viz=True,
save=True,
save_path=self.save_path + f"plots",
name=out_name
)
if self.args.N_BEADS < 50000:
get_heatmap(cif_path, viz=True, save=True, save_path=self.save_path + f"plots", name=out_name)
else:
logger.warning("\033[93mHeatmap creation skipped because system is too large for visualization.\033[0m")

Expand All @@ -1110,7 +1064,8 @@ def _viz_and_heat(cif_path, out_name, colors=None):
name=out_name,
)

plot_projection(
if is_comp:
plot_projection(
get_coordinates_mm(self.state.getPositions()),
self.Cs,
save_path=self.save_path,
Expand Down
Loading