-
Notifications
You must be signed in to change notification settings - Fork 34
support NvNMD-train and NvNMD-explore #298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
73b4a48
7f60686
5255d2e
91a6620
48aac9c
8a6918c
ff5da8d
2d4e32d
c1dd741
6b3a0e1
af591a4
2c0ea10
6c69e4e
77a0e3a
afecaad
e542b0a
ca17fe1
694889d
daae459
26ab174
c4e3f46
a89bd70
2aa102b
c1e46ac
3791618
e5a16ec
8064b74
228f620
b73ea11
ba7fc22
e329de9
fcb7747
b8bdce6
35804c5
26db68f
849ffeb
d7210b9
05ba10f
38c50ea
d33ce44
d7eac9a
ed4cbbe
179fa31
40d9c82
3e83044
a232ae1
7a5924d
ede835a
86b2a66
2b14439
2a25efe
3c8982c
aa2e96b
c50b411
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,6 +111,8 @@ | |
RunDPTrain, | ||
RunLmp, | ||
RunLmpHDF5, | ||
RunNvNMD, | ||
RunNvNMDTrain, | ||
RunRelax, | ||
RunRelaxHDF5, | ||
SelectConfs, | ||
|
@@ -182,6 +184,17 @@ def make_concurrent_learning_op( | |
valid_data=valid_data, | ||
optional_files=train_optional_files, | ||
) | ||
elif train_style == "dp-nvnmd": | ||
prep_run_train_op = PrepRunDPTrain( | ||
"prep-run-nvnmd-train", | ||
PrepDPTrain, | ||
RunNvNMDTrain, | ||
prep_config=prep_train_config, | ||
run_config=run_train_config, | ||
upload_python_packages=upload_python_packages, | ||
valid_data=valid_data, | ||
optional_files=train_optional_files, | ||
) | ||
else: | ||
raise RuntimeError(f"unknown train_style {train_style}") | ||
if explore_style == "lmp": | ||
|
@@ -193,6 +206,15 @@ def make_concurrent_learning_op( | |
run_config=run_explore_config, | ||
upload_python_packages=upload_python_packages, | ||
) | ||
elif "lmp-nvnmd" in explore_style: | ||
prep_run_explore_op = PrepRunLmp( | ||
"prep-run-nvnmd", | ||
PrepLmp, | ||
RunNvNMD, | ||
prep_config=prep_explore_config, | ||
run_config=run_explore_config, | ||
upload_python_packages=upload_python_packages, | ||
) | ||
elif "calypso" in explore_style: | ||
expl_mode = explore_style.split(":")[-1] if ":" in explore_style else "default" | ||
if expl_mode == "merge": | ||
|
@@ -286,7 +308,7 @@ def make_naive_exploration_scheduler( | |
# use npt task group | ||
explore_style = config["explore"]["type"] | ||
|
||
if explore_style == "lmp": | ||
if explore_style in ("lmp", "lmp-nvnmd"): | ||
return make_lmp_naive_exploration_scheduler(config) | ||
elif "calypso" in explore_style or explore_style == "diffcsp": | ||
return make_naive_exploration_scheduler_without_conf(config, explore_style) | ||
|
@@ -374,6 +396,7 @@ def make_lmp_naive_exploration_scheduler(config): | |
output_nopbc = config["explore"]["output_nopbc"] | ||
conf_filters = get_conf_filters(config["explore"]["filters"]) | ||
use_ele_temp = config["inputs"]["use_ele_temp"] | ||
config["explore"]["type"] | ||
scheduler = ExplorationScheduler() | ||
# report | ||
conv_style = convergence.pop("type") | ||
|
@@ -506,6 +529,16 @@ def workflow_concurrent_learning( | |
else None | ||
) | ||
config["train"]["numb_models"] = 1 | ||
|
||
elif train_style == "dp-nvnmd": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not reuse the logic for train_style == "dp"? |
||
init_models_paths = config["train"].get("init_models_paths", None) | ||
numb_models = config["train"]["numb_models"] | ||
if init_models_paths is not None and len(init_models_paths) != numb_models: | ||
raise RuntimeError( | ||
f"{len(init_models_paths)} init models provided, which does " | ||
"not match numb_models={numb_models}" | ||
) | ||
|
||
else: | ||
raise RuntimeError(f"unknown params, train_style: {train_style}") | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -109,6 +109,10 @@ def get_confs( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
conf_filters: Optional["ConfFilters"] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
optional_outputs: Optional[List[Path]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> dpdata.MultiSystems: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from ase.io import ( # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
read, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ntraj = len(trajs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ele_temp = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if optional_outputs: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -123,12 +127,16 @@ def get_confs( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
traj = StringIO(trajs[ii].get_data()) # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
traj = trajs[ii] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ss = dpdata.System(traj, fmt=traj_fmt, type_map=type_map) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ss.nopbc = self.nopbc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ele_temp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.set_ele_temp(ss, ele_temp[ii]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ss = ss.sub_system(id_selected[ii]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ms.append(ss) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ss = dpdata.System(traj, fmt=traj_fmt, type_map=type_map) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ss = read( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
str(traj), format="lammps-dump-text", index=":", specorder=type_map | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for jj in id_selected[ii]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
s = dpdata.System(ss[jj], fmt="ase/structure", type_map=type_map) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
127
to
+135
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. StringIO handling is broken –
- ss = read(
- str(traj), format="lammps-dump-text", index=":", specorder=type_map
- )
+ ss = read(
+ traj, # pass the file-object directly
+ format="lammps-dump-text",
+ index=":",
+ specorder=type_map,
+ ) This keeps the code path for on-disk files unchanged while restoring support 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
s.nopbc = self.nopbc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if ele_temp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.set_ele_temp(s, ele_temp[ii]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ms.append(s) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if conf_filters is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ms = conf_filters.check(ms) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return ms |
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -50,6 +50,7 @@ def make_lmp_input( | |||||||||||||
nopbc: bool = False, | ||||||||||||||
max_seed: int = 1000000, | ||||||||||||||
deepmd_version="2.0", | ||||||||||||||
nvnmd_version=None, | ||||||||||||||
trj_seperate_files=True, | ||||||||||||||
pimd_bead: Optional[str] = None, | ||||||||||||||
): | ||||||||||||||
|
@@ -69,9 +70,9 @@ def make_lmp_input( | |||||||||||||
ret += "variable THERMO_FREQ equal %d\n" % trj_freq | ||||||||||||||
ret += "variable DUMP_FREQ equal %d\n" % trj_freq | ||||||||||||||
ret += "variable TEMP equal %f\n" % temp | ||||||||||||||
if ele_temp_f is not None: | ||||||||||||||
if ele_temp_f is not None and nvnmd_version is None: | ||||||||||||||
ret += "variable ELE_TEMP equal %f\n" % ele_temp_f | ||||||||||||||
if ele_temp_a is not None: | ||||||||||||||
if ele_temp_a is not None and nvnmd_version is None: | ||||||||||||||
ret += "variable ELE_TEMP equal %f\n" % ele_temp_a | ||||||||||||||
if pres is not None: | ||||||||||||||
ret += "variable PRES equal %f\n" % pres | ||||||||||||||
|
@@ -106,12 +107,14 @@ def make_lmp_input( | |||||||||||||
if pimd_bead is not None | ||||||||||||||
else lmp_model_devi_name | ||||||||||||||
) | ||||||||||||||
if Version(deepmd_version) < Version("1"): | ||||||||||||||
if Version(deepmd_version) < Version("1") and nvnmd_version is None: | ||||||||||||||
# 0.x | ||||||||||||||
ret += "pair_style deepmd %s ${THERMO_FREQ} %s\n" % ( | ||||||||||||||
graph_list, | ||||||||||||||
model_devi_file_name, | ||||||||||||||
) | ||||||||||||||
elif nvnmd_version is not None: | ||||||||||||||
ret += "pair_style nvnmd %s\n" % ("model.pb") | ||||||||||||||
else: | ||||||||||||||
Comment on lines
+116
to
118
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using a configurable model name instead of hardcoding "model.pb". The NVNMD pair style uses a hardcoded model filename, while the DeepMD branches use the Apply this diff to use the first model from the graphs list: elif nvnmd_version is not None:
- ret += "pair_style nvnmd %s\n" % ("model.pb")
+ ret += "pair_style nvnmd %s\n" % (graphs[0] if graphs else "model.pb") 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
Comment on lines
+116
to
118
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainConsider making the model filename configurable. The nvnmd pair style uses a hardcoded 🏁 Script executed: #!/bin/bash
# Check if there are other places in the codebase that assume "model.pb" for nvnmd
rg -n "model\.pb" --type py | grep -i nvnmd Length of output: 1986 🏁 Script executed: #!/bin/bash
# Display the LAMMPS input generator around the nvnmd section for context
sed -n '1,200p' dpgen2/exploration/task/lmp/lmp_input.py Length of output: 7322 Make nvnmd model filenames configurable The • File - elif nvnmd_version is not None:
- ret += "pair_style nvnmd %s\n" % ("model.pb")
+ elif nvnmd_version is not None:
+ # allow passing specific model filenames per graph
+ # e.g. ["nvnmd_cnn/frozen_model.pb", "nvnmd_qnn/model.pb"]
+ model_files = " ".join(nvnmd_model_files)
+ ret += "pair_style nvnmd %s %s\n" % (graph_list, model_files) • Add a new parameter (e.g.
🤖 Prompt for AI Agents
|
||||||||||||||
# 1.x | ||||||||||||||
keywords = "" | ||||||||||||||
|
@@ -135,17 +138,28 @@ def make_lmp_input( | |||||||||||||
ret += "thermo_style custom step temp pe ke etotal press vol lx ly lz xy xz yz\n" | ||||||||||||||
ret += "thermo ${THERMO_FREQ}\n" | ||||||||||||||
if trj_seperate_files: | ||||||||||||||
ret += "dump 1 all custom ${DUMP_FREQ} traj/*.lammpstrj id type x y z fx fy fz\n" | ||||||||||||||
if nvnmd_version is None: | ||||||||||||||
ret += "dump 1 all custom ${DUMP_FREQ} traj/*.lammpstrj id type x y z fx fy fz\n" | ||||||||||||||
else: | ||||||||||||||
ret += "dump 1 all custom ${DUMP_FREQ} ${rerun}_traj/*.lammpstrj id type x y z fx fy fz\n" | ||||||||||||||
else: | ||||||||||||||
lmp_traj_file_name = ( | ||||||||||||||
lmp_pimd_traj_name % pimd_bead if pimd_bead is not None else lmp_traj_name | ||||||||||||||
) | ||||||||||||||
ret += ( | ||||||||||||||
"dump 1 all custom ${DUMP_FREQ} %s id type x y z fx fy fz\n" | ||||||||||||||
% lmp_traj_file_name | ||||||||||||||
) | ||||||||||||||
if nvnmd_version is None: | ||||||||||||||
ret += ( | ||||||||||||||
"dump 1 all custom ${DUMP_FREQ} %s id type x y z fx fy fz\n" | ||||||||||||||
% lmp_traj_file_name | ||||||||||||||
) | ||||||||||||||
else: | ||||||||||||||
ret += ( | ||||||||||||||
"dump 1 all custom ${DUMP_FREQ} ${rerun}_%s id type x y z fx fy fz\n" | ||||||||||||||
% lmp_traj_file_name | ||||||||||||||
) | ||||||||||||||
ret += "restart 10000 dpgen.restart\n" | ||||||||||||||
ret += "\n" | ||||||||||||||
if nvnmd_version is not None: | ||||||||||||||
ret += 'if "${rerun} > 0" then "jump SELF rerun"\n' | ||||||||||||||
if pka_e is None: | ||||||||||||||
ret += 'if "${restart} == 0" then "velocity all create ${TEMP} %d"' % ( | ||||||||||||||
Comment on lines
+161
to
164
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion
The jump predicates on +# declare default
+variable RERUN equal 0 or document that callers must always pass
🤖 Prompt for AI Agents
|
||||||||||||||
random.randrange(max_seed - 1) + 1 | ||||||||||||||
|
@@ -193,4 +207,12 @@ def make_lmp_input( | |||||||||||||
ret += "\n" | ||||||||||||||
ret += "timestep %f\n" % dt | ||||||||||||||
ret += "run ${NSTEPS} upto\n" | ||||||||||||||
if nvnmd_version is not None: | ||||||||||||||
ret += "jump SELF end\n" | ||||||||||||||
ret += "label rerun\n" | ||||||||||||||
if trj_seperate_files: | ||||||||||||||
ret += "rerun 0_traj/*.lammpstrj dump x y z fx fy fz add yes\n" | ||||||||||||||
else: | ||||||||||||||
ret += "rerun 0_%s dump x y z fx fy fz add yes\n" % lmp_traj_name | ||||||||||||||
ret += "label end\n" | ||||||||||||||
return ret |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unused variable extraction.
The explore type is extracted from the config but never used.
- config["explore"]["type"]
📝 Committable suggestion
🤖 Prompt for AI Agents