Skip to content

Commit e6c2f0b

Browse files
committed
add additional stage to optimize combination of 3 bsdfs that have been mixed with 2 nodes.
1 parent c61991f commit e6c2f0b

File tree

2 files changed

+241
-7
lines changed

2 files changed

+241
-7
lines changed

source/MaterialXGenShader/ShaderGraph.cpp

Lines changed: 237 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,22 @@ ShaderNode* ShaderGraph::getNode(const string& name)
796796
return it != _nodeMap.end() ? it->second.get() : nullptr;
797797
}
798798

799+
bool ShaderGraph::removeNode(ShaderNode* node)
800+
{
801+
auto mapIt = _nodeMap.find(node->getName());
802+
auto vecIt = std::find(_nodeOrder.begin(), _nodeOrder.end(), node);
803+
if (mapIt == _nodeMap.end() || vecIt == _nodeOrder.end())
804+
return false;
805+
806+
for (ShaderInput* input : node->getInputs())
807+
{
808+
input->breakConnection();
809+
}
810+
_nodeMap.erase(mapIt);
811+
_nodeOrder.erase(vecIt);
812+
return true;
813+
}
814+
799815
const ShaderNode* ShaderGraph::getNode(const string& name) const
800816
{
801817
return const_cast<ShaderGraph*>(this)->getNode(name);
@@ -976,9 +992,17 @@ void ShaderGraph::optimize(GenContext& context)
976992
const vector<ShaderNode*> nodeList = getNodes();
977993
for (ShaderNode* node : nodeList)
978994
{
995+
// first check the node is still in the graph, and hasn't been removed by a
996+
// prior optimization
997+
if (!getNode(node->getName()))
998+
continue;
999+
9791000
if (node->hasClassification(ShaderNode::Classification::MIX_BSDF))
9801001
{
981-
optimizeMixBsdf(node, context);
1002+
if (!optimizeMixMixBsdf(node, context))
1003+
{
1004+
optimizeMixBsdf(node, context);
1005+
}
9821006
}
9831007
}
9841008
}
@@ -1018,7 +1042,7 @@ void ShaderGraph::optimize(GenContext& context)
10181042
//
10191043
// Motivation - the new graph is more efficient for shader backends to optimize
10201044
// away possible expensive BSDF nodes that might not be used at all.
1021-
void ShaderGraph::optimizeMixBsdf(ShaderNode* mixNode, GenContext& context)
1045+
bool ShaderGraph::optimizeMixBsdf(ShaderNode* mixNode, GenContext& context)
10221046
{
10231047
// criteria for optimization...
10241048
// * upstream nodes for MixBsdf node both have `weight` input ports
@@ -1033,29 +1057,29 @@ void ShaderGraph::optimizeMixBsdf(ShaderNode* mixNode, GenContext& context)
10331057
// the standard data library ND_mix_bsdf should always have "fg" and "bg" inputs
10341058
// but we check here anyway to ensure we're not using a custom data library that doesn't follow that convention
10351059
if (!mixFgInput || !mixBgInput)
1036-
return;
1060+
return false;
10371061

10381062
auto fgNode = mixFgInput->getConnectedSibling();
10391063
auto bgNode = mixBgInput->getConnectedSibling();
10401064

10411065
// We check to see we have two upstream nodes - there are almost certainly other optimizations possible
10421066
// if this isn't true, but we will leave those for a later PR.
10431067
if (!fgNode && !bgNode)
1044-
return;
1068+
return false;
10451069

10461070
auto fgNodeWeightInput = fgNode->getInput("weight");
10471071
auto bgNodeWeightInput = bgNode->getInput("weight");
10481072

10491073
// We require both upstream nodes to have a "weight" input for this optimization to work.
10501074
if (!fgNodeWeightInput || !bgNodeWeightInput)
1051-
return;
1075+
return false;
10521076

10531077
// we also require the following list of node definitions
10541078
auto addBsdfNodeDef = _document->getNodeDef("ND_add_bsdf");
10551079
auto floatInvertNodeDef = _document->getNodeDef("ND_invert_float");
10561080
auto floatMultNodeDef = _document->getNodeDef("ND_multiply_float");
10571081
if (!addBsdfNodeDef || !floatInvertNodeDef || !floatMultNodeDef)
1058-
return;
1082+
return false;
10591083

10601084
// We meet the requirements for the optimization.
10611085
// We can now create the new nodes and connect them up.
@@ -1127,6 +1151,213 @@ void ShaderGraph::optimizeMixBsdf(ShaderNode* mixNode, GenContext& context)
11271151
{
11281152
addNode->getOutput("out")->makeConnection(conn);
11291153
}
1154+
1155+
removeNode(mixNode);
1156+
1157+
return true;
1158+
}
1159+
1160+
1161+
// Optimize the combination of two MixBsdf nodes by replacing it with a new node graph.
1162+
//
1163+
// The current nodegraph
1164+
// ┌────────┐┌────────┐┌────────┐
1165+
// │A.weight││B.weight││Mix1.mix│
1166+
// └┬───────┘└┬───────┘└┬───────┘
1167+
// ┌▽───────┐┌▽───────┐ │
1168+
// │A (BSDF)││B (BSDF)│ │
1169+
// └┬───────┘└┬───────┘ │
1170+
// ┌▽─────────▽─────────▽┐┌────────┐┌────────┐
1171+
// │Mix_1 ││C.weight││Mix2.mix│
1172+
// └┬────────────────────┘└┬───────┘└┬───────┘
1173+
// | ┌▽───────┐ |
1174+
// | │C (BSDF)│ |
1175+
// | └┬───────┘ |
1176+
// ┌▽──────────────────────▽─────────▽───────┐
1177+
// |Mix_2 |
1178+
// └─────────────────────────────────────────┘
1179+
//
1180+
// New nodegraph //TODO
1181+
// ┌────────┐
1182+
// │Mix.mix │
1183+
// └┬──────┬┘
1184+
// ┌▽─────┐│┌────────┐┌────────┐
1185+
// │Invert│││A.weight││B.weight│
1186+
// └─────┬┘│└┬───────┘└──┬─────┘
1187+
// └─│─│──┐ │
1188+
// ┌───────▽─▽┐┌▽────────▽┐
1189+
// │Multiply A││Multiply B│
1190+
// └┬─────────┘└┬─────────┘
1191+
// ┌▽───────┐┌──▽─────┐
1192+
// │A (BSDF)││B (BSDF)│
1193+
// └┬───────┘└┬───────┘
1194+
// ┌▽─────────▽┐
1195+
// │Add │
1196+
// └───────────┘
1197+
//
1198+
// Motivation - the new graph is more efficient for shader backends to optimize
1199+
// away possible expensive BSDF nodes that might not be used at all.
1200+
bool ShaderGraph::optimizeMixMixBsdf(ShaderNode* mixNode_x, GenContext& context)
1201+
{
1202+
// criteria for optimization...
1203+
// * upstream nodes for MixBsdf node both have `weight` input ports
1204+
// * We have the following node definitions available in the library
1205+
// * ND_add_bsdf
1206+
// * ND_invert_float
1207+
// * ND_multiply_float
1208+
1209+
ShaderNode* mix2Node = mixNode_x;
1210+
1211+
auto mix2FgInput = mix2Node->getInput("fg");
1212+
auto mix2BgInput = mix2Node->getInput("bg");
1213+
1214+
// the standard data library ND_mix_bsdf should always have "fg" and "bg" inputs
1215+
// but we check here anyway to ensure we're not using a custom data library that doesn't follow that convention
1216+
if (!mix2FgInput || !mix2BgInput)
1217+
return false;
1218+
1219+
auto mix2FgNode = mix2FgInput->getConnectedSibling();
1220+
auto mix2BgNode = mix2BgInput->getConnectedSibling();
1221+
1222+
// We check to see we have two upstream nodes - there are almost certainly other optimizations possible
1223+
// if this isn't true, but we will leave those for a later PR.
1224+
if (!mix2FgNode && !mix2BgNode)
1225+
return false;
1226+
1227+
// we require the node connected to the "fg" input to also be a mixBsdf node
1228+
if (!mix2FgNode->hasClassification(ShaderNode::Classification::MIX_BSDF))
1229+
return false;
1230+
1231+
ShaderNode* mix1Node = mix2FgNode;
1232+
1233+
1234+
auto mix1FgInput = mix1Node->getInput("fg");
1235+
auto mix1BgInput = mix1Node->getInput("bg");
1236+
1237+
// the standard data library ND_mix_bsdf should always have "fg" and "bg" inputs
1238+
// but we check here anyway to ensure we're not using a custom data library that doesn't follow that convention
1239+
if (!mix1FgInput || !mix1BgInput)
1240+
return false;
1241+
1242+
auto mix1FgNode = mix1FgInput->getConnectedSibling();
1243+
auto mix1BgNode = mix1BgInput->getConnectedSibling();
1244+
1245+
// We check to see we have two upstream nodes - there are almost certainly other optimizations possible
1246+
// if this isn't true, but we will leave those for a later PR.
1247+
if (!mix1FgNode && !mix1BgNode)
1248+
return false;
1249+
1250+
auto mix1FgNodeWeightInput = mix1FgNode->getInput("weight");
1251+
auto mix1BgNodeWeightInput = mix1BgNode->getInput("weight");
1252+
auto mix2BgNodeWeightInput = mix2BgNode->getInput("weight");
1253+
1254+
// We require both upstream nodes to have a "weight" input for this optimization to work.
1255+
if (!mix1FgNodeWeightInput || !mix1BgNodeWeightInput || !mix2BgNodeWeightInput)
1256+
return false;
1257+
1258+
1259+
// we also require the following list of node definitions
1260+
auto addBsdfNodeDef = _document->getNodeDef("ND_add_bsdf");
1261+
auto floatInvertNodeDef = _document->getNodeDef("ND_invert_float");
1262+
auto floatMultNodeDef = _document->getNodeDef("ND_multiply_float");
1263+
if (!addBsdfNodeDef || !floatInvertNodeDef || !floatMultNodeDef)
1264+
return false;
1265+
1266+
// We meet the requirements for the optimization.
1267+
// We can now create the new nodes and connect them up.
1268+
1269+
// Helper function to redirect the incoming connection to from one input port
1270+
// to another.
1271+
// We intentionally skip error checking here, as we're doing it below.
1272+
// If this proves useful we should make it a method somewhere, and add
1273+
// more robust error checking.
1274+
auto redirectInput = [](ShaderInput* fromPort, ShaderInput* toPort) -> void
1275+
{
1276+
auto connection = fromPort->getConnection();
1277+
if (connection)
1278+
{
1279+
// we have a connection - so transfer it
1280+
toPort->makeConnection(connection);
1281+
}
1282+
else
1283+
{
1284+
// we just remap the value.
1285+
toPort->setValue(fromPort->getValue());
1286+
}
1287+
};
1288+
1289+
// Helper function to connect two nodes together, consolidating the valiation of the ports existance.
1290+
// If this proves useful we should make it a method somewhere.
1291+
auto connectNodes = [](ShaderNode* fromNode, const string& fromPortName, ShaderNode* toNode, const string& toPortName) -> void
1292+
{
1293+
auto fromPort = fromNode->getOutput(fromPortName);
1294+
auto toPort = toNode->getInput(toPortName);
1295+
if (!fromPort || !toPort)
1296+
return;
1297+
1298+
fromPort->makeConnection(toPort);
1299+
};
1300+
1301+
auto mix1WeightInput = mix1Node->getInput("mix");
1302+
auto mix2WeightInput = mix2Node->getInput("mix");
1303+
1304+
// create nodes that represents the inverted mix values, ie. 1.0-mix
1305+
// to be used for the "bg" side of the mix
1306+
auto invertMix1Node = this->createNode(mix1Node->getName()+"_INV", floatInvertNodeDef, context);
1307+
redirectInput(mix1WeightInput, invertMix1Node->getInput("in"));
1308+
auto invertMix2Node = this->createNode(mix2Node->getName()+"_INV", floatInvertNodeDef, context);
1309+
redirectInput(mix2WeightInput, invertMix2Node->getInput("in"));
1310+
1311+
1312+
// create a multiply node to calculate the new weight value, weighted by the mix value.
1313+
auto multFg1WeightNode_intermediate = this->createNode(mix1Node->getName()+"_MULT_FG1_INTERMEDIATE", floatMultNodeDef, context);
1314+
redirectInput(mix1FgNodeWeightInput, multFg1WeightNode_intermediate->getInput("in1"));
1315+
redirectInput(mix1WeightInput, multFg1WeightNode_intermediate->getInput("in2"));
1316+
auto multFg1WeightNode = this->createNode(mix1Node->getName()+"_MULT_FG1", floatMultNodeDef, context);
1317+
redirectInput(mix2WeightInput, multFg1WeightNode->getInput("in1"));
1318+
connectNodes(multFg1WeightNode_intermediate, "out", multFg1WeightNode, "in2");
1319+
1320+
auto multBg1WeightNode_intermediate = this->createNode(mix1Node->getName()+"_MULT_BG1_INTERMEDIATE", floatMultNodeDef, context);
1321+
redirectInput(mix1BgNodeWeightInput, multBg1WeightNode_intermediate->getInput("in1"));
1322+
connectNodes(invertMix1Node, "out", multBg1WeightNode_intermediate, "in2");
1323+
auto multBg1WeightNode = this->createNode(mix1Node->getName()+"_MULT_BG1", floatMultNodeDef, context);
1324+
redirectInput(mix2WeightInput, multBg1WeightNode->getInput("in1"));
1325+
connectNodes(multBg1WeightNode_intermediate, "out", multBg1WeightNode, "in2");
1326+
1327+
auto multBg2WeightNode = this->createNode(mix2Node->getName()+"_MULT_BG2", floatMultNodeDef, context);
1328+
redirectInput(mix2BgNodeWeightInput, multBg2WeightNode->getInput("in1"));
1329+
connectNodes(invertMix2Node, "out", multBg2WeightNode, "in2");
1330+
1331+
1332+
// connect the two newly created weights to the fg and bg BSDF nodes.
1333+
connectNodes(multFg1WeightNode, "out", mix1FgNode, "weight");
1334+
connectNodes(multBg1WeightNode, "out", mix1BgNode, "weight");
1335+
connectNodes(multBg2WeightNode, "out", mix2BgNode, "weight");
1336+
1337+
1338+
// Create the ND_add_bsdf nodes that will add the three BSDF nodes with the modified weights
1339+
// this replaces the original mix nodes.
1340+
auto addNode_intermediate = this->createNode(mix2Node->getName()+"_ADD_INTERMEDIATE", addBsdfNodeDef, context);
1341+
connectNodes(mix1BgNode, "out", addNode_intermediate, "in1");
1342+
connectNodes(mix1FgNode, "out", addNode_intermediate, "in2");
1343+
1344+
auto addNode = this->createNode(mix2Node->getName()+"_ADD", addBsdfNodeDef, context);
1345+
connectNodes(addNode_intermediate, "out", addNode, "in1");
1346+
connectNodes(mix2BgNode, "out", addNode, "in2");
1347+
1348+
// Finally for all the previous outgoing connections from the original mix node
1349+
// replace those with the outgoing connection from the new add node.
1350+
auto mixNodeOutput = mix2Node->getOutput("out");
1351+
auto mixNodeOutputConns = mixNodeOutput->getConnections();
1352+
for (auto conn : mixNodeOutputConns)
1353+
{
1354+
addNode->getOutput("out")->makeConnection(conn);
1355+
}
1356+
1357+
removeNode(mix1Node);
1358+
removeNode(mix2Node);
1359+
1360+
return true;
11301361
}
11311362

11321363

source/MaterialXGenShader/ShaderGraph.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
137137
/// Add a node to the graph
138138
void addNode(ShaderNodePtr node);
139139

140+
bool removeNode(ShaderNode* node);
141+
140142
/// Add input sockets from an interface element (nodedef, nodegraph or node)
141143
void addInputSockets(const InterfaceElement& elem, GenContext& context);
142144

@@ -166,7 +168,8 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
166168
/// Optimize the graph, removing redundant paths.
167169
void optimize(GenContext& context);
168170

169-
void optimizeMixBsdf(ShaderNode* node, GenContext& context);
171+
bool optimizeMixBsdf(ShaderNode* node, GenContext& context);
172+
bool optimizeMixMixBsdf(ShaderNode* node, GenContext& context);
170173

171174
/// Bypass a node for a particular input and output,
172175
/// effectively connecting the input's upstream connection

0 commit comments

Comments
 (0)