MaterialXWeb 0.0.2
Utilities for using MaterialX Packages with Web clients
Loading...
Searching...
No Matches
usdmtlx.py
1import MaterialX as mx
2from pxr import Usd, UsdShade, Sdf, UsdGeom, Gf
3
def mapMtlxToUsdShaderNotation(name):
    '''
    Utility to map from a MaterialX shader notation to Usd.
    It would be easier if the same notation was used.

    Parameters
    ------------
    name: string
        MaterialX name. The shader type names 'surfaceshader',
        'displacementshader' and 'volumeshader' are mapped to
        'surface', 'displacement' and 'volume' respectively.

    Returns
    ------------
    The mapped Usd name. Any other name is returned unchanged.
    '''
    mtlxToUsdNames = {
        'surfaceshader': 'surface',
        'displacementshader': 'displacement',
        # Previously misspelled as 'volumshader', so volume shaders were
        # never mapped.
        'volumeshader': 'volume',
    }
    return mtlxToUsdNames.get(name, name)
16
def _wrapUsdShadePrim(prim):
    """
    Wrap a Usd prim in the UsdShade schema class matching its type.

    Returns a UsdShade.Material, UsdShade.NodeGraph or UsdShade.Shader for
    the prim, or None if the prim is none of these.
    """
    # Material is tested before NodeGraph, keeping the order of the
    # original type checks (most specific schema first).
    if prim.IsA(UsdShade.Material):
        return UsdShade.Material(prim)
    if prim.IsA(UsdShade.NodeGraph):
        return UsdShade.NodeGraph(prim)
    if prim.IsA(UsdShade.Shader):
        return UsdShade.Shader(prim)
    return None


def emitUsdConnections(node, stage, rootPath):
    """
    Emit connections between MaterialX elements as Usd connections for
    a given MaterialX node.

    Every active input and output of the node is examined. If it references
    an upstream node ('nodename'), nodegraph ('nodegraph') or, for inputs
    only, an interface input ('interfacename'), the matching Usd prims are
    looked up on the stage and the Usd port is connected to its source.
    Failures are reported with print() and the element is skipped.

    Parameters
    ------------
    node:
        MaterialX node to examine
    stage:
        Usd stage to write connection to
    rootPath: string
        Usd path prefix, ending in '/', which is prepended to MaterialX
        name paths to form Usd prim paths.
    """
    if not node:
        return

    materialPath = None
    if node.getType() == 'material':
        materialPath = node.getName()

    for valueElement in node.getActiveValueElements():
        isInput = valueElement.isA(mx.Input)
        isOutput = valueElement.isA(mx.Output)
        if not (isInput or isOutput):
            continue

        # Find out what type of element is connected to upstream:
        # node, nodegraph, or (for inputs only) interface input.
        interfacename = ''
        mtlxConnection = valueElement.getAttribute('nodename')
        if not mtlxConnection:
            mtlxConnection = valueElement.getAttribute('nodegraph')
        if not mtlxConnection and not isOutput:
            mtlxConnection = valueElement.getAttribute('interfacename')
            interfacename = mtlxConnection
        if not mtlxConnection:
            # Nothing is connected to this element.
            continue

        # Build the Usd path of the upstream prim.
        if isOutput and node.isA(mx.NodeGraph):
            # Output of a nodegraph: connection is to a child of this graph.
            connectionPath = rootPath + node.getNamePath() + '/' + mtlxConnection
        else:
            # getNamePath() is empty when the parent is the root document.
            parentPath = node.getParent().getNamePath()
            if interfacename:
                # An interface input lives on the parent nodegraph itself.
                connectionPath = rootPath + parentPath
            elif parentPath:
                # Connection is to a sibling under the same parent.
                connectionPath = rootPath + parentPath + '/' + mtlxConnection
            else:
                # Connection is to a prim at the root document level.
                connectionPath = rootPath + mtlxConnection

        # Find the source prim.
        # Assumes that the source is either a nodegraph, a material or a shader
        connectionPath = connectionPath.removesuffix('/')
        source = stage.GetPrimAtPath(connectionPath)
        if not source and materialPath:
            # Retry with the path prefixed by the material, and finally
            # fall back to the material prim itself.
            connectionPath = '/' + materialPath + connectionPath
            source = stage.GetPrimAtPath(connectionPath)
            if not source:
                source = stage.GetPrimAtPath('/' + materialPath)
        if not source:
            print('> Failed to find source at path:', connectionPath)
            continue

        sourcePrim = _wrapUsdShadePrim(source)
        if not sourcePrim:
            continue

        # Special case handle interface input vs an output
        if interfacename:
            sourcePort = interfacename
        else:
            sourcePort = valueElement.getAttribute('output') or 'out'

        # Find destination prim and port and make the appropriate connection.
        # Assumes that the destination is either a nodegraph, a material or a shader
        destPath = rootPath + node.getNamePath()
        dest = stage.GetPrimAtPath(destPath)
        if not dest:
            print('> Failed to find dest at path:', destPath)
            continue

        portName = valueElement.getName()
        destPort = None
        destNode = _wrapUsdShadePrim(dest)
        if not destNode:
            print('> Encountered unsupport destinion type')
        elif isInput:
            if dest.IsA(UsdShade.Material):
                # Map from MaterialX to Usd connection syntax. Material
                # inputs are looked up as 'mtlx:'-prefixed outputs.
                portName = 'mtlx:' + mapMtlxToUsdShaderNotation(portName)
                destPort = destNode.GetOutput(portName)
            else:
                destPort = destNode.GetInput(portName)
        else:
            destPort = destNode.GetOutput(portName)

        if not destPort:
            print('> Failed to find destination port:', portName)
            continue

        # Make connection to interface input, or node/nodegraph output
        if interfacename:
            interfaceInput = sourcePrim.GetInput(sourcePort)
            if not interfaceInput:
                # Previously skipped silently; report it like the other failures.
                print('> Failed to find interface input:', sourcePort, 'on', source.GetPrimPath())
            elif not destPort.ConnectToSource(interfaceInput):
                print('> Failed to connect: ', source.GetPrimPath(), '-->', destPort.GetFullName())
        elif not destPort.ConnectToSource(sourcePrim.ConnectableAPI(), sourcePort):
            print('> Failed to connect: ', source.GetPrimPath(), '-->', destPort.GetFullName())
169
170
def mapMtxToUsdType(mtlxType):
    """
    Look up the Usd Sdf value type name for a MaterialX type.

    Parameters
    ------------
    mtlxType: string
        MaterialX type

    Returns
    ------------
    The matching Sdf.ValueTypeNames entry. Types without an entry in the
    table fall back to Sdf.ValueTypeNames.Token.
    """
    typeNames = Sdf.ValueTypeNames
    usdTypeByMtlxType = {
        'filename': typeNames.Asset,
        'string': typeNames.String,
        'boolean': typeNames.Bool,
        'integer': typeNames.Int,
        'float': typeNames.Float,
        'color3': typeNames.Color3f,
        'color4': typeNames.Color4f,
        'vector2': typeNames.Float2,
        'vector3': typeNames.Vector3f,
        'vector4': typeNames.Float4,
        'surfaceshader': typeNames.Token,
    }
    return usdTypeByMtlxType.get(mtlxType, typeNames.Token)
196
def mapMtxToUsdValue(mtlxType, mtlxValue):
    """
    Convert a MaterialX value of a given type to the value to set in Usd.
    Note: Not all types are included here.

    Parameters
    ------------
    mtlxType: string
        MaterialX type
    mtlxValue:
        MaterialX value. For vector and color types it is read by index.

    Returns
    ------------
    The value itself for scalar, string and filename types, a Gf vector
    for vector and color types, or the marker string '__' when the type
    is not handled.
    """
    # Types whose value is passed through unchanged.
    if mtlxType in ('float', 'integer', 'boolean', 'string', 'filename'):
        return mtlxValue

    if mtlxType == 'vector2':
        return Gf.Vec2f( mtlxValue[0], mtlxValue[1] )

    if mtlxType in ('color3', 'vector3'):
        return Gf.Vec3f( mtlxValue[0], mtlxValue[1], mtlxValue[2] )

    if mtlxType in ('color4', 'vector4'):
        return Gf.Vec4f( mtlxValue[0], mtlxValue[1], mtlxValue[2], mtlxValue[3] )

    # Marker telling callers that there is no value to set.
    return '__'
221
def _setUsdPortValue(valueElement, usdPort):
    """
    Copy the value of a MaterialX value element to a Usd input or output.

    Nothing is set if the element has no value string, or if its type is
    not handled by mapMtxToUsdValue() (which then returns '__').
    """
    # Check the length of the value string instead of getValue(), as a
    # 0 value would otherwise be skipped.
    if len(valueElement.getValueString()) > 0:
        usdValue = mapMtxToUsdValue(valueElement.getType(), valueElement.getValue())
        if usdValue != '__':
            usdPort.Set(usdValue)


def emitUsdValueElements(node, usdNode, emitAllValueElements):
    """
    Emit MaterialX value elements in Usd.

    Ports are first created from the node's definition (skipped for
    materials), then from the node instance itself, so that values on the
    instance override defaults from the definition. Inputs of a material
    are created as Usd outputs named 'mtlx:<name>'.

    Parameters
    ------------
    node:
        MaterialX node with value elements to scan
    usdNode:
        UsdShade node to create value elements on.
    emitAllValueElements: bool
        Emit value elements based on node definition, even if not specified on node instance.
    """
    if not node:
        return

    isMaterial = node.getType() == 'material'

    # Instantiate with all the nodedef inputs (if emitAllValueElements is True).
    # Note that outputs are always created.
    nodedef = node.getNodeDef()
    if nodedef and not isMaterial:
        for valueElement in nodedef.getActiveValueElements():
            if valueElement.isA(mx.Input):
                if emitAllValueElements:
                    usdType = mapMtxToUsdType(valueElement.getType())
                    usdInput = usdNode.CreateInput(valueElement.getName(), usdType)
                    _setUsdPortValue(valueElement, usdInput)

            elif valueElement.isA(mx.Output):
                usdNode.CreateOutput(valueElement.getName(), mapMtxToUsdType(valueElement.getType()))

            else:
                print('- Skip mapping of definition element: ', valueElement.getName(), '. Type: ', valueElement.getCategory())

    # From the given instance add inputs and outputs and set values.
    # This may override the default value specified on the definition.
    for valueElement in node.getActiveValueElements():
        if valueElement.isA(mx.Input):
            usdType = mapMtxToUsdType(valueElement.getType())
            portName = valueElement.getName()
            if isMaterial:
                # Map from Materials to Usd notation
                portName = mapMtlxToUsdShaderNotation(portName)
                usdInput = usdNode.CreateOutput('mtlx:' + portName, usdType)
            else:
                usdInput = usdNode.CreateInput(portName, usdType)

            _setUsdPortValue(valueElement, usdInput)

        elif not isMaterial and valueElement.isA(mx.Output):
            # Only create the output if it does not exist yet. This lookup
            # previously used GetInput(), which never finds an output.
            portName = valueElement.getName()
            if not usdNode.GetOutput(portName):
                usdNode.CreateOutput(portName, mapMtxToUsdType(valueElement.getType()))

        else:
            print('- Skip mapping of element: ', valueElement.getNamePath(), '. Type: ', valueElement.getCategory())
294
295
def moveChild(newParent, child):
    """
    Move a MaterialX element to a new parent.

    A child with the same category and name is added to newParent, the
    content of the original child is copied into it, and the original is
    then removed from its old parent. The name path of the new child is
    printed.

    Parameters
    ------------
    newParent:
        MaterialX element to add the child to
    child:
        MaterialX element to move
    """
    category = child.getCategory()
    childName = child.getName()
    formerParent = child.getParent()

    movedChild = newParent.addChildOfCategory(category, childName)
    print(movedChild.getNamePath())
    movedChild.copyContentFrom(child)

    formerParent.removeChild(childName)
302
def _isMaterialNode(elem):
    """
    Return True if a MaterialX element is a node of type 'material'.

    The isA() test comes first so that getType() is only called on nodes.
    """
    return elem.isA(mx.Node) and elem.getType() == 'material'


def emitUsdShaderGraph(doc, stage, mxnodes, emitAllValueElements):
    """
    Emit Usd shader graph to a given stage from a list of MaterialX nodes.

    All prims are defined first and connections are made in a second pass,
    so that both ends of every connection exist when it is made. If one of
    the elements is a material, shaders and nodegraphs are placed under the
    first material found; otherwise they are placed at the stage root.

    Parameters
    ------------
    doc:
        MaterialX source document
    stage:
        Usd target stage
    mxnodes:
        List of MaterialX element name paths in doc to emit. Paths which
        cannot be found in doc are skipped.
    emitAllValueElements: bool
        Emit value elements based on node definition, even if not specified on node instance.
    """
    print('Stage:', stage)
    if not stage:
        return

    # Resolve each path once, rather than once per pass.
    elems = []
    for namePath in mxnodes:
        elem = doc.getDescendant(namePath)
        if elem is not None:
            elems.append(elem)

    # The first material found is the parent of all shaders and nodegraphs.
    materialPath = None
    for elem in elems:
        if _isMaterialNode(elem):
            materialPath = elem.getName()
            break
    parentPath = '/' + materialPath if materialPath else ''

    # Emit Usd nodes
    for elem in elems:
        # Note that MaterialX does not use absolute path notation while Usd
        # does, so a leading '/' is added.
        usdPath = '/' + elem.getNamePath()

        nodeDef = None
        usdNode = None
        # Materials are tested first, as they are nodes as well.
        if _isMaterialNode(elem):
            usdNode = UsdShade.Material.Define(stage, usdPath)
        elif elem.isA(mx.Node):
            nodeDef = elem.getNodeDef()
            usdNode = UsdShade.Shader.Define(stage, parentPath + usdPath)
        elif elem.isA(mx.NodeGraph):
            usdNode = UsdShade.NodeGraph.Define(stage, parentPath + usdPath)

        if usdNode:
            if nodeDef:
                usdNode.SetShaderId(nodeDef.getName())
            emitUsdValueElements(elem, usdNode, emitAllValueElements)

    # Emit connections between Usd nodes
    for elem in elems:
        if _isMaterialNode(elem):
            emitUsdConnections(elem, stage, '/')
        elif elem.isA(mx.Node) or elem.isA(mx.NodeGraph):
            # Use the same root the prims were defined under. Previously
            # connections were skipped entirely when there was no material,
            # although the prims themselves were still defined.
            emitUsdConnections(elem, stage, parentPath + '/')
373
def findMaterialXNodes(doc):
    """
    Find the name paths of all elements in a MaterialX document.

    Parameters
    ------------
    doc:
        MaterialX document to traverse

    Returns
    ------------
    List of the name paths of every element visited by doc.traverseTree(),
    in traversal order and without duplicates. Note that this covers all
    elements visited, not only nodes.
    """
    # dict.fromkeys() removes duplicates while keeping first-seen order,
    # avoiding a linear search of a list for every element.
    return list(dict.fromkeys(elem.getNamePath() for elem in doc.traverseTree()))
386
def convertMtlxToUsd(doc, emitAllValueElements):
    """
    Emit a MaterialX document to a new in-memory Usd stage and return the
    stage contents as a string.

    Note that the MaterialX definition libraries are loaded and imported
    into doc, so the document passed in is modified.

    Parameters
    ------------
    doc:
        MaterialX document to convert
    emitAllValueElements: bool
        Emit value elements based on node definition, even if not specified on node instance.

    Returns
    ------------
    The root layer of the new Usd stage exported to a string.
    """
    stage = Usd.Stage.CreateInMemory()

    # Find nodes to transform before importing the definition library,
    # so that elements from the library are not translated.
    mxnodes = findMaterialXNodes(doc)

    # Import the definition library so that node definitions can be found.
    stdlib = mx.createDocument()
    searchPath = mx.getDefaultDataSearchPath()
    mx.loadLibraries(mx.getDefaultDataLibraryFolders(), searchPath, stdlib)
    doc.importLibrary(stdlib)

    # Translate
    emitUsdShaderGraph(doc, stage, mxnodes, emitAllValueElements)

    return stage.GetRootLayer().ExportToString()
427