linux/tools/net/sunrpc/xdrgen/xdr_ast.py
Chuck Lever ce5a75d993 xdrgen: XDR width for union types
Not yet complete.

The tool doesn't do any math yet. Thus, even though the maximum XDR
width of a union is the width of the union enumerator plus the width
of its largest arm, we're using the sum of all the elements of the
union for the moment.

This means that buffer size requirements are overestimated, and that
the generated maxsize macro cannot yet be used for determining data
element alignment in the XDR buffer.

Signed-off-by: Chuck Lever <chuck.lever@oracle.com>
2024-11-11 13:42:05 -05:00

753 lines
20 KiB
Python

#!/usr/bin/env python3
# ex: set filetype=python:
"""Define and implement the Abstract Syntax Tree for the XDR language."""
import sys
from typing import List
from dataclasses import dataclass
from lark import ast_utils, Transformer
from lark.tree import Meta
this_module = sys.modules[__name__]
big_endian = []
excluded_apis = []
header_name = "none"
public_apis = []
structs = set()
pass_by_reference = set()
constants = {}
def xdr_quadlen(val: str) -> int:
"""Return integer XDR width of an XDR type"""
if val in constants:
octets = constants[val]
else:
octets = int(val)
return int((octets + 3) / 4)
symbolic_widths = {
"void": ["XDR_void"],
"bool": ["XDR_bool"],
"int": ["XDR_int"],
"unsigned_int": ["XDR_unsigned_int"],
"long": ["XDR_long"],
"unsigned_long": ["XDR_unsigned_long"],
"hyper": ["XDR_hyper"],
"unsigned_hyper": ["XDR_unsigned_hyper"],
}
# Numeric XDR widths are tracked in a dictionary that is keyed
# by type_name because sometimes a caller has nothing more than
# the type_name to use to figure out the numeric width.
max_widths = {
"void": 0,
"bool": 1,
"int": 1,
"unsigned_int": 1,
"long": 1,
"unsigned_long": 1,
"hyper": 2,
"unsigned_hyper": 2,
}
@dataclass
class _XdrAst(ast_utils.Ast):
"""Base class for the XDR abstract syntax tree"""
@dataclass
class _XdrIdentifier(_XdrAst):
"""Corresponds to 'identifier' in the XDR language grammar"""
symbol: str
@dataclass
class _XdrValue(_XdrAst):
"""Corresponds to 'value' in the XDR language grammar"""
value: str
@dataclass
class _XdrConstantValue(_XdrAst):
"""Corresponds to 'constant' in the XDR language grammar"""
value: int
@dataclass
class _XdrTypeSpecifier(_XdrAst):
"""Corresponds to 'type_specifier' in the XDR language grammar"""
type_name: str
c_classifier: str = ""
@dataclass
class _XdrDefinedType(_XdrTypeSpecifier):
"""Corresponds to a type defined by the input specification"""
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
return [get_header_name().upper() + "_" + self.type_name + "_sz"]
def __post_init__(self):
if self.type_name in structs:
self.c_classifier = "struct "
symbolic_widths[self.type_name] = self.symbolic_width()
@dataclass
class _XdrBuiltInType(_XdrTypeSpecifier):
"""Corresponds to a built-in XDR type"""
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
return symbolic_widths[self.type_name]
@dataclass
class _XdrDeclaration(_XdrAst):
"""Base class of XDR type declarations"""
@dataclass
class _XdrFixedLengthOpaque(_XdrDeclaration):
"""A fixed-length opaque declaration"""
name: str
size: str
template: str = "fixed_length_opaque"
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return xdr_quadlen(self.size)
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
return ["XDR_QUADLEN(" + self.size + ")"]
def __post_init__(self):
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrVariableLengthOpaque(_XdrDeclaration):
"""A variable-length opaque declaration"""
name: str
maxsize: str
template: str = "variable_length_opaque"
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return 1 + xdr_quadlen(self.maxsize)
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
widths = ["XDR_unsigned_int"]
if self.maxsize != "0":
widths.append("XDR_QUADLEN(" + self.maxsize + ")")
return widths
def __post_init__(self):
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrString(_XdrDeclaration):
"""A (NUL-terminated) variable-length string declaration"""
name: str
maxsize: str
template: str = "string"
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return 1 + xdr_quadlen(self.maxsize)
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
widths = ["XDR_unsigned_int"]
if self.maxsize != "0":
widths.append("XDR_QUADLEN(" + self.maxsize + ")")
return widths
def __post_init__(self):
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrFixedLengthArray(_XdrDeclaration):
"""A fixed-length array declaration"""
name: str
spec: _XdrTypeSpecifier
size: str
template: str = "fixed_length_array"
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return xdr_quadlen(self.size) * max_widths[self.spec.type_name]
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
item_width = " + ".join(symbolic_widths[self.spec.type_name])
return ["(" + self.size + " * (" + item_width + "))"]
def __post_init__(self):
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrVariableLengthArray(_XdrDeclaration):
"""A variable-length array declaration"""
name: str
spec: _XdrTypeSpecifier
maxsize: str
template: str = "variable_length_array"
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return 1 + (xdr_quadlen(self.maxsize) * max_widths[self.spec.type_name])
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
widths = ["XDR_unsigned_int"]
if self.maxsize != "0":
item_width = " + ".join(symbolic_widths[self.spec.type_name])
widths.append("(" + self.maxsize + " * (" + item_width + "))")
return widths
def __post_init__(self):
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrOptionalData(_XdrDeclaration):
"""An 'optional_data' declaration"""
name: str
spec: _XdrTypeSpecifier
template: str = "optional_data"
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return 1
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
return ["XDR_bool"]
def __post_init__(self):
structs.add(self.name)
pass_by_reference.add(self.name)
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrBasic(_XdrDeclaration):
"""A 'basic' declaration"""
name: str
spec: _XdrTypeSpecifier
template: str = "basic"
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return max_widths[self.spec.type_name]
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
return symbolic_widths[self.spec.type_name]
def __post_init__(self):
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrVoid(_XdrDeclaration):
"""A void declaration"""
name: str = "void"
template: str = "void"
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return 0
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
return []
@dataclass
class _XdrConstant(_XdrAst):
"""Corresponds to 'constant_def' in the grammar"""
name: str
value: str
def __post_init__(self):
if self.value not in constants:
constants[self.name] = int(self.value, 0)
@dataclass
class _XdrEnumerator(_XdrAst):
"""An 'identifier = value' enumerator"""
name: str
value: str
def __post_init__(self):
if self.value not in constants:
constants[self.name] = int(self.value, 0)
@dataclass
class _XdrEnum(_XdrAst):
"""An XDR enum definition"""
name: str
minimum: int
maximum: int
enumerators: List[_XdrEnumerator]
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return 1
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
return ["XDR_int"]
def __post_init__(self):
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrStruct(_XdrAst):
"""An XDR struct definition"""
name: str
fields: List[_XdrDeclaration]
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
width = 0
for field in self.fields:
width += field.max_width()
return width
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
widths = []
for field in self.fields:
widths += field.symbolic_width()
return widths
def __post_init__(self):
structs.add(self.name)
pass_by_reference.add(self.name)
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrPointer(_XdrAst):
"""An XDR pointer definition"""
name: str
fields: List[_XdrDeclaration]
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
width = 1
for field in self.fields[0:-1]:
width += field.max_width()
return width
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
widths = []
widths += ["XDR_bool"]
for field in self.fields[0:-1]:
widths += field.symbolic_width()
return widths
def __post_init__(self):
structs.add(self.name)
pass_by_reference.add(self.name)
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _XdrTypedef(_XdrAst):
"""An XDR typedef"""
declaration: _XdrDeclaration
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
return self.declaration.max_width()
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
return self.declaration.symbolic_width()
def __post_init__(self):
if isinstance(self.declaration, _XdrBasic):
new_type = self.declaration
if isinstance(new_type.spec, _XdrDefinedType):
if new_type.spec.type_name in pass_by_reference:
pass_by_reference.add(new_type.name)
max_widths[new_type.name] = self.max_width()
symbolic_widths[new_type.name] = self.symbolic_width()
@dataclass
class _XdrCaseSpec(_XdrAst):
"""One case in an XDR union"""
values: List[str]
arm: _XdrDeclaration
template: str = "case_spec"
@dataclass
class _XdrDefaultSpec(_XdrAst):
"""Default case in an XDR union"""
arm: _XdrDeclaration
template: str = "default_spec"
@dataclass
class _XdrUnion(_XdrAst):
"""An XDR union"""
name: str
discriminant: _XdrDeclaration
cases: List[_XdrCaseSpec]
default: _XdrDeclaration
def max_width(self) -> int:
"""Return width of type in XDR_UNITS"""
max_width = 0
for case in self.cases:
if case.arm.max_width() > max_width:
max_width = case.arm.max_width()
if self.default:
if self.default.arm.max_width() > max_width:
max_width = self.default.arm.max_width()
return 1 + max_width
def symbolic_width(self) -> List:
"""Return list containing XDR width of type's components"""
max_width = 0
for case in self.cases:
if case.arm.max_width() > max_width:
max_width = case.arm.max_width()
width = case.arm.symbolic_width()
if self.default:
if self.default.arm.max_width() > max_width:
max_width = self.default.arm.max_width()
width = self.default.arm.symbolic_width()
return symbolic_widths[self.discriminant.name] + width
def __post_init__(self):
structs.add(self.name)
pass_by_reference.add(self.name)
max_widths[self.name] = self.max_width()
symbolic_widths[self.name] = self.symbolic_width()
@dataclass
class _RpcProcedure(_XdrAst):
"""RPC procedure definition"""
name: str
number: str
argument: _XdrTypeSpecifier
result: _XdrTypeSpecifier
@dataclass
class _RpcVersion(_XdrAst):
"""RPC version definition"""
name: str
number: str
procedures: List[_RpcProcedure]
@dataclass
class _RpcProgram(_XdrAst):
"""RPC program definition"""
name: str
number: str
versions: List[_RpcVersion]
@dataclass
class _Pragma(_XdrAst):
"""Empty class for pragma directives"""
@dataclass
class Definition(_XdrAst, ast_utils.WithMeta):
"""Corresponds to 'definition' in the grammar"""
meta: Meta
value: _XdrAst
@dataclass
class Specification(_XdrAst, ast_utils.AsList):
"""Corresponds to 'specification' in the grammar"""
definitions: List[Definition]
class ParseToAst(Transformer):
"""Functions that transform productions into AST nodes"""
def identifier(self, children):
"""Instantiate one _XdrIdentifier object"""
return _XdrIdentifier(children[0].value)
def value(self, children):
"""Instantiate one _XdrValue object"""
if isinstance(children[0], _XdrIdentifier):
return _XdrValue(children[0].symbol)
return _XdrValue(children[0].children[0].value)
def constant(self, children):
"""Instantiate one _XdrConstantValue object"""
match children[0].data:
case "decimal_constant":
value = int(children[0].children[0].value, base=10)
case "hexadecimal_constant":
value = int(children[0].children[0].value, base=16)
case "octal_constant":
value = int(children[0].children[0].value, base=8)
return _XdrConstantValue(value)
def type_specifier(self, children):
"""Instantiate one _XdrTypeSpecifier object"""
if isinstance(children[0], _XdrIdentifier):
name = children[0].symbol
return _XdrDefinedType(type_name=name)
name = children[0].data.value
return _XdrBuiltInType(type_name=name)
def constant_def(self, children):
"""Instantiate one _XdrConstant object"""
name = children[0].symbol
value = children[1].value
return _XdrConstant(name, value)
# cel: Python can compute a min() and max() for the enumerator values
# so that the generated code can perform proper range checking.
def enum(self, children):
"""Instantiate one _XdrEnum object"""
enum_name = children[0].symbol
i = 0
enumerators = []
body = children[1]
while i < len(body.children):
name = body.children[i].symbol
value = body.children[i + 1].value
enumerators.append(_XdrEnumerator(name, value))
i = i + 2
return _XdrEnum(enum_name, 0, 0, enumerators)
def fixed_length_opaque(self, children):
"""Instantiate one _XdrFixedLengthOpaque declaration object"""
name = children[0].symbol
size = children[1].value
return _XdrFixedLengthOpaque(name, size)
def variable_length_opaque(self, children):
"""Instantiate one _XdrVariableLengthOpaque declaration object"""
name = children[0].symbol
if children[1] is not None:
maxsize = children[1].value
else:
maxsize = "0"
return _XdrVariableLengthOpaque(name, maxsize)
def string(self, children):
"""Instantiate one _XdrString declaration object"""
name = children[0].symbol
if children[1] is not None:
maxsize = children[1].value
else:
maxsize = "0"
return _XdrString(name, maxsize)
def fixed_length_array(self, children):
"""Instantiate one _XdrFixedLengthArray declaration object"""
spec = children[0]
name = children[1].symbol
size = children[2].value
return _XdrFixedLengthArray(name, spec, size)
def variable_length_array(self, children):
"""Instantiate one _XdrVariableLengthArray declaration object"""
spec = children[0]
name = children[1].symbol
if children[2] is not None:
maxsize = children[2].value
else:
maxsize = "0"
return _XdrVariableLengthArray(name, spec, maxsize)
def optional_data(self, children):
"""Instantiate one _XdrOptionalData declaration object"""
spec = children[0]
name = children[1].symbol
return _XdrOptionalData(name, spec)
def basic(self, children):
"""Instantiate one _XdrBasic object"""
spec = children[0]
name = children[1].symbol
return _XdrBasic(name, spec)
def void(self, children):
"""Instantiate one _XdrVoid declaration object"""
return _XdrVoid()
def struct(self, children):
"""Instantiate one _XdrStruct object"""
name = children[0].symbol
fields = children[1].children
last_field = fields[-1]
if (
isinstance(last_field, _XdrOptionalData)
and name == last_field.spec.type_name
):
return _XdrPointer(name, fields)
return _XdrStruct(name, fields)
def typedef(self, children):
"""Instantiate one _XdrTypedef object"""
new_type = children[0]
return _XdrTypedef(new_type)
def case_spec(self, children):
"""Instantiate one _XdrCaseSpec object"""
values = []
for item in children[0:-1]:
values.append(item.value)
arm = children[-1]
return _XdrCaseSpec(values, arm)
def default_spec(self, children):
"""Instantiate one _XdrDefaultSpec object"""
arm = children[0]
return _XdrDefaultSpec(arm)
def union(self, children):
"""Instantiate one _XdrUnion object"""
name = children[0].symbol
body = children[1]
discriminant = body.children[0].children[0]
cases = body.children[1:-1]
default = body.children[-1]
return _XdrUnion(name, discriminant, cases, default)
def procedure_def(self, children):
"""Instantiate one _RpcProcedure object"""
result = children[0]
name = children[1].symbol
argument = children[2]
number = children[3].value
return _RpcProcedure(name, number, argument, result)
def version_def(self, children):
"""Instantiate one _RpcVersion object"""
name = children[0].symbol
number = children[-1].value
procedures = children[1:-1]
return _RpcVersion(name, number, procedures)
def program_def(self, children):
"""Instantiate one _RpcProgram object"""
name = children[0].symbol
number = children[-1].value
versions = children[1:-1]
return _RpcProgram(name, number, versions)
def pragma_def(self, children):
"""Instantiate one _Pragma object"""
directive = children[0].children[0].data
match directive:
case "big_endian_directive":
big_endian.append(children[1].symbol)
case "exclude_directive":
excluded_apis.append(children[1].symbol)
case "header_directive":
global header_name
header_name = children[1].symbol
case "public_directive":
public_apis.append(children[1].symbol)
case _:
raise NotImplementedError("Directive not supported")
return _Pragma()
transformer = ast_utils.create_transformer(this_module, ParseToAst())
def transform_parse_tree(parse_tree):
"""Transform productions into an abstract syntax tree"""
return transformer.transform(parse_tree)
def get_header_name() -> str:
"""Return header name set by pragma header directive"""
return header_name