Given the complexity of the TorchScript language today and the dependencies it creates for its users, further evolution of TorchScript may need to take a more disciplined approach. Internally we have started an effort to revamp the TorchScript language specification. The first step is to capture faithfully the language as it is implemented today (issue #50434), document the rough edges identified during the process (issue #50444), and, based on these understandings, propose new features for the language.
In this thread we will focus on the discussion of the TorchScript type system. We would love to hear what you think.
TorchScript Types
The TorchScript type system consists of TSType and TSInstanceClass.
TSAllType := TSType | TSInstanceClass
TSType := TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSInstanceClass := TSModuleType
TSType represents the majority of TorchScript types. These types are composable and can be used in TorchScript type annotations. TSType can be further classified into:
- meta types, e.g., Any
- primitive types, e.g., int, float, str
- structural types, e.g., Optional[int] or List[MyClass]
- nominal types (Python classes), e.g., MyClass (user-defined) or torch.tensor (builtin)
TSInstanceClass represents user-defined class types that are inferred partly from the object instance and partly from the class definition, so instances of the same TSInstanceClass may not follow the same static type schema. For type-safety reasons, TSInstanceClass therefore cannot be used in TorchScript type annotations or composed with TSType. Currently TSInstanceClass consists of only TSModuleType, which represents torch.nn.Module and its subclasses.
TSInstanceClass represents user-defined subclasses of torch.nn.Module. It is put into a separate category because TSInstanceClass cannot be used in TorchScript type annotations (will be explained later). This special constraint of TSInstanceClass is a surprise to me and also the root cause of issue #49650.
Meta Types
Meta types are so abstract that they are more like type constraints than concrete types. Currently TorchScript defines one meta type, Any, which represents any TorchScript type.
TSMetaType := "Any"
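To make this concrete, here is a small sketch I made up (the describe function is not from the spec): Any is accepted in annotations, but an Any-typed value generally has to be refined (e.g., via isinstance) before its attributes or operators can be used.

from typing import Any

import torch

# Illustrative sketch only: Any is accepted in annotations, but a value typed
# as Any generally has to be refined (e.g., via isinstance) before its
# attributes or operators can be resolved.
@torch.jit.script
def describe(x: Any) -> str:
    if isinstance(x, int):
        return "an int: " + str(x)
    return "something else"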
Primitive Types
Each primitive type represents a single kind of value and goes with a single pre-defined type name.
TSPrimitiveType := "int" | "float" | "double" | "bool" | "str" | "None"
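A small illustrative sketch (the function is made up) of primitive types used in annotations:

import torch

# Illustrative sketch: primitive types used as parameter and return
# annotations in a scripted function.
@torch.jit.script
def repeat_label(label: str, times: int, sep: str = ", ") -> str:
    out = label
    for _ in range(times - 1):
        out = out + sep + label
    return out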
Structural Types
Structural types are types that are structurally defined without a user-defined name (unlike nominal types), such as Future[int]. Structural types are composable with any TSType.
TSStructuralType := TSTuple | TSNamedTuple | TSList | TSDict |
TSOptional | TSFuture | TSRRef
TSTuple := "Tuple" "[" (TSType ",")+ "]"
TSNamedTuple := "namedtuple" "(" (TSType ",")+ ")"
TSList := "List" "[" TSType "]"
TSOptional := "Optional" "[" TSType "]"
TSFuture := "Future" "[" TSType "]"
TSRRef := "RRef" "[" TSType "]"
TSDict := "Dict" "[" KeyType "," TSType "]"
KeyType := "str" | "int" | "float" | "bool" | TensorType | "Any"
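An illustrative sketch (the function and its types are made up) of structural types composed in annotations:

from typing import Dict, List, Optional, Tuple

import torch

# Illustrative sketch: List, Tuple, Dict, and Optional composed freely in
# the parameter and return annotations of a scripted function.
@torch.jit.script
def tally(pairs: List[Tuple[str, int]], default: Optional[int] = None) -> Dict[str, int]:
    out: Dict[str, int] = {}
    for name, value in pairs:
        out[name] = value
    if default is not None:
        out["default"] = default
    return out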
Nominal Types
Nominal TorchScript types are Python classes. They are called nominal because these types are declared with a custom name and are compared using class names. Nominal classes are further classified into the following categories:
TSNominalType := TSBuiltinClass | TSCustomClass | TSEnum | TSModuleInterface
Among them, TSCustomClass and TSEnum must be compilable to TorchScript IR and are therefore the most common cause of scripting failures. TSBuiltinClass does not require compilation, but its behavior is only an (under-specified) subset of its Python counterpart, which is also a common cause of scripting failures.
Nominal types are the most complex types in the type system. They are hard to design correctly, easy to misuse, and difficult to explain to users, thus should be the focus of future extensions.
Builtin Class
Builtin nominal types are Python classes whose semantics are built into the TorchScript system, such as tensor types. TorchScript defines the semantics of these builtin nominal types and typically supports only a subset of the methods of their Python counterparts.
TSBuiltinClass := TSTensor | "torch.device" | "torch.stream" | "torch.dtype" |
"torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor := "torch.tensor" and subclasses
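An illustrative sketch (the module is made up) of using a builtin class inside a scripted module, here torch.nn.ModuleList:

import torch

class Stack(torch.nn.Module):
    def __init__(self, n: int):
        super().__init__()
        # ModuleList is a builtin TorchScript class; iterating over it in
        # forward() is one of the supported behaviors.
        self.layers = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(n)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

scripted = torch.jit.script(Stack(3))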
Custom Class
Unlike built-in classes, semantics of custom classes are user-defined and the entire class definition must be compilable into TorchScript IR and subject to TorchScript type-checking rules.
TSClassDef := [ "@torch.jit.script" ]
"class" ClassName "(" [TSInterfaceName] ":"
ClassBodyDefinition
where
- TSInterfaceName is the name of a TorchScript interface class
- Classes must be new-style classes, as we use __new__() to construct them with pybind11
- Classes and instance attributes are statically typed, and instance attributes must be declared by assignments inside __init__()
- Method overloading is not supported (i.e., cannot have multiple methods with the same method name)
- ClassBodyDefinition must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules, i.e., all methods must be valid TorchScript functions and class attribute definitions are valid TorchScript statements (see the sketch after this list)
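A minimal sketch of a custom class following these rules (the Counter class and use_counter function are made up):

import torch

# Illustrative sketch of a custom class: the entire definition must compile
# to TorchScript IR, and instance attributes are declared in __init__().
@torch.jit.script
class Counter(object):
    def __init__(self, start: int):
        self.count = start

    def increment(self, by: int) -> int:
        self.count = self.count + by
        return self.count

@torch.jit.script
def use_counter() -> int:
    c = Counter(10)
    return c.increment(5)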
Enum Type
Like custom classes, the semantics of enum types are user-defined, and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules.
TSEnumDef := "class" Identifier "(enum.Enum | TSEnumType)" ":"
( MemberIdentifier "=" Value )+
( MethodDefinition )*
where
- Value must be a TorchScript literal of type int, float, or str, and all members must be of the same TorchScript type
- TSEnumType is the name of a TorchScript enum type. Similar to Python enums, TorchScript allows restricted Enum subclassing, that is, subclassing an enum is allowed only if the enum does not define any members
- MethodDefinition must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules (see the sketch after this list)
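An illustrative sketch of a TorchScript enum (the Color class is made up):

import enum

import torch

# Illustrative sketch: all member values share one TorchScript type (int here).
class Color(enum.Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def is_red(c: Color) -> bool:
    return c == Color.RED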
Module Interface Type
Module interface types are abstract classes that define only method signatures (or methods with empty bodies) and no instance or class attributes. As such, interface types do not define constructors (i.e., __init__()). They are designed to support some form of polymorphic binding (e.g., to allow subclasses of an interface to bind to variables/parameters of the interface type and resolve method attributes at runtime by name).
Currently only module interface types (i.e., interfaces of torch.nn.Module) can be defined.
TSInterface := "@torch.jit.interface"
"class" Identifier "(torch.nn.Module)" ":"
(
"def" MethodIdentifier "(" (MethodArgument)* ")" ["->" TSType]
"pass"
)*
where
- MethodArgument must be type annotated and pass TorchScript’s type-checking rules (see the sketch after this list)
- Method overloading is not supported (i.e., cannot have multiple methods with the same method name)
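An illustrative sketch of a module interface and a conforming module (all class names are made up):

import torch

# Illustrative sketch: the interface declares only a method signature.
@torch.jit.interface
class HasForward(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

class Doubler(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

class Wrapper(torch.nn.Module):
    # The interface type can be used in an annotation; the bound submodule's
    # forward() is resolved at runtime.
    inner: HasForward

    def __init__(self):
        super().__init__()
        self.inner = Doubler()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner.forward(x)

scripted = torch.jit.script(Wrapper())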
Note that, judging from the implementation (try_ann_to_type() in annotations.py), interface types do not seem to require compilation. Thus interface types should probably be considered builtin nominal types.
TorchScript Instance Class
TSInstanceClass types are class types inferred from object instances created outside TorchScript. Although a TSInstanceClass is named by the Python class of the object instance, the __init__() methods of the Python class are not considered TorchScript methods, and thus do not have to comply with TorchScript’s static type-checking rules.
Since the type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript), rather than inferred from __init__() as for custom classes, it is possible that two objects of the same instance class type follow two different type schemas.
In this sense, TSInstanceClass is not really a static type. Therefore, for type-safety considerations, TSInstanceClass cannot be used in TorchScript type annotations or composed with TSType.
Module Instance Class
A TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., passed in as an argument to torch.jit.script()), and the Python module class is treated as a module instance class, so __init__() of the Python module class is not subject to the type-checking rules of TorchScript.
TSModuleType := "class" Identifier "(torch.nn.Module)" ":"
ClassBodyDefinition
where
- forward() and any other methods decorated with @torch.jit.export must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules
Unlike custom classes, only the forward method and other methods decorated with @torch.jit.export of a module type need to be compilable to TorchScript IR. Most notably, __init__() is not considered a TorchScript method, so module type constructors should not be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into torch.jit.script(ModuleObj).
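A minimal sketch of this workflow (the Scale module is made up):

import torch

class Scale(torch.nn.Module):
    def __init__(self, factor: float):
        super().__init__()
        # __init__ runs as ordinary Python; it is not compiled to TorchScript IR.
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only forward() (and @torch.jit.export methods) must be compilable.
        return x * self.factor

# The module instance is constructed outside TorchScript and then scripted.
scripted = torch.jit.script(Scale(0.5))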
Since TorchScript intentionally does not compile the entire class definition of a module type, a module type is not considered a complete TorchScript type and, for type-safety purposes, is not allowed in type annotations.
Using TorchScript Types
Static types are used for type annotation and type checking in TorchScript.
Type Checking
(Disclaimer: I did not get the chance to read the type checking codes in detail. So the following is not carefully validated against the implementation.)
From the static type checker’s perspective, this translates into checking the following conditions:
- Types of any static expression within the scope of a TorchScript program can be determined statically;
- Variables, parameters, class and instance attributes must have the same static type throughout their life-time;
- A static type schema (incl. types of class/instance.method attributes) can be built for TorchScript classes so that instances of a TorchScript classes are guaranteed to follow the same schema; and
- All operators, methods, class/instance attributes are resolved to a concrete implementation in TorchScript.
As such, type-checking is applied to
- type annotation constructs: to check that types used in annotations are TorchScript types
- control-flow joins: to check that the type associated with a symbol remains the same after a control-flow join (see the sketch after this list)
- constructors: to check that all __init__() methods produce the same static schema for instance attributes
- parameter binding & assignment: to check that the type of an argument or right-hand side matches that of the formal parameter or left-hand side
- attribute resolution: to check that the attribute (incl. operators) is part of the static schema of the class/type.
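As an illustrative sketch of the control-flow-join rule (the function is made up), a variable whose type differs across branches is expected to be rejected when scripted:

import torch

# Illustrative sketch: x has type int on one branch and str on the other,
# so its type is not the same after the control-flow join and scripting
# is expected to fail with a type error.
def inconsistent(flag: bool):
    if flag:
        x = 1
    else:
        x = "one"
    return x

# torch.jit.script(inconsistent)  # expected to raise a type-checking error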
Note that the type Any must be used carefully. Since Any is not a concrete type, accessing any attribute, operator, or method of an Any value will likely violate Condition #4. Likewise, method attributes of interface types must be handled specially to resolve to concrete implementations.
Type Annotation
To satisfy the above requirements, type annotations are needed to seed static types into an otherwise dynamically typed Python program. TorchScript may obtain types through user annotation, auto-inference, or, if all else fails, by assuming the default type of TensorType.
In general:
- Parameters are type annotated; otherwise they are given the default type of Tensor;
- Return types can be annotated or automatically inferred;
- Class attributes are type annotated, auto-inferred if initialized, and otherwise given the default type of Tensor;
- Local variables are usually auto-inferred, or in special cases can be type annotated via torch.jit.annotate() (see the sketch after this list);
- Instance attributes are always initialized inside __init__(), therefore their types are always auto-inferred;
- self cannot be type annotated and its type is inferred
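An illustrative sketch (the function is made up) of annotating a local variable via torch.jit.annotate():

from typing import Dict, List

import torch

@torch.jit.script
def build_index(names: List[str]) -> Dict[str, int]:
    # The empty dict's element types cannot be inferred, so the local
    # variable is annotated explicitly via torch.jit.annotate().
    index = torch.jit.annotate(Dict[str, int], {})
    for i, name in enumerate(names):
        index[name] = i
    return index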
How to Improve TorchScript Type System
I hope I have captured the gist of the current state of the TorchScript type system. Now I would like to share my observations and some thoughts on how to improve it.
In the TorchScript type system, nominal types are both the most complex and the least well-defined. As such, they are the most likely to cause scripting failures, user confusion, or type-safety loopholes. Builtin and custom nominal types seem to suffer from completely opposite problems:
- Builtin nominal types rely on an internalized representation of a subset of the semantics of selected Python classes. The gaps between builtin TorchScript types and their Python counterparts are many and often unspecified, leading to many type-checking failures during attribute resolution. There are also ad hoc rules (such as module types not being usable in type annotations) and lax checking (e.g., all subclasses of torch.tensor are considered tensor types, even user-defined tensor subclasses);
- Custom nominal types rely on compilation to be converted into TorchScript IR. But the requirement of compiling the entire class definition can be too restrictive and lead to confusing type-checking errors at type annotation sites.
Custom nominal types must be fully compiled, really?
Let’s start with a concrete example (issue #49650). This code fails to be scripted because "sub: TestSubModule" uses a TorchScript module type in a type annotation and thus fails the type checker (see our previous explanation of why a module type is not considered a complete TorchScript type).
import torch

class TestSubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.v = torch.rand((2, 3))

    def forward(self):
        return self.v

class TestModule(torch.nn.Module):
    sub: TestSubModule  # this line fails annotation checking

    def __init__(self):
        super().__init__()
        self.sub = TestSubModule()

    def forward(self):
        return self.sub()

scripted_m = torch.jit.script(TestModule())
We previously explained that a module type is not considered a complete TorchScript type because __init__() of a TorchScript module type is not considered a TorchScript method. However, the fact is that TestModule.forward() does not use __init__() of sub at all, but only forward() of sub, which can be compiled to TorchScript IR.
The requirement of compiling the entire class definition of custom TorchScript nominal types is evident in the type-checking rules for annotation expressions in try_ann_to_type() from annotations.py:
if inspect.isclass(ann):
    qualified_name = _qualified_name(ann)
    if _get_script_class(qualified_name) is not None:
        return ClassType(qualified_name)
    ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
    if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes):
        torch.jit._script._recursive_compile_class(ann, loc)
        return ClassType(qualified_name)
But do they have to be? The answer is no. We already have plenty of partially internalized TorchScript types in the current type system. In fact, most builtin TorchScript (nominal) types are only partially supported compared to their Python counterparts (see the TorchScript unsupported PyTorch constructs documentation).
I would like to propose relaxing the type-checking at the definition of custom nominal types: we simply consider any class a TorchScript class type. When compiling the class definition, we keep track of which methods and attributes are compilable to TorchScript IR (as is done for builtin classes). The real type-checking happens at attribute resolution, i.e., all methods and attributes of a TorchScript class that are used within the scope of a TorchScript program must resolve to concrete TorchScript implementations.
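A hypothetical sketch of what the relaxed rule could accept (this is not current behavior; the Helper class is made up):

import torch

# Hypothetical under the proposal: the class as a whole is accepted as a
# TorchScript class type even though debug_dump() uses features TorchScript
# cannot compile; only the members actually used from TorchScript (e.g.,
# scale()) would need to resolve to compiled implementations.
class Helper(object):
    def __init__(self, factor: float):
        self.factor = factor

    def scale(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor

    def debug_dump(self):
        import json  # not scriptable, but never used from TorchScript
        print(json.dumps({"factor": self.factor}))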
Since __init__ methods also serve the purpose of creating a static schema for the data attributes of a TorchScript class, regardless of whether __init__ is invoked inside TorchScript or not, a special design (to be worked out) is needed for the case where __init__ cannot be compiled to TorchScript IR. Such a design may also work for TorchScript module types (simplifying the type system).
Miscellaneous
Just want to briefly mention a couple of half-baked ideas.
- Design annotations to further prune non-executed paths in a TorchScript program based on its execution context. In general, we want to tighten the compilation surface of TorchScript as much as possible, because the more code is compiled, the more likely type-checking failures become. For instance, torch.jit.trace() is a way to tighten the TorchScript compilation surface by runtime specialization (although it is done in a way that may be hard to reason about for correctness).
- More refined design of Any-like meta types or user-defined interface types to increase the flexibility of the language.
- Soften the edge between builtin and custom types. Builtin and custom types seem to be internalized in completely different ways, and it is not clear how well TorchScript supports the in-between points, such as overriding builtin types with user-defined methods or subclassing builtin types in custom classes.
- It seems that objects created outside the scope of TorchScript need to be type-checked at runtime.
Finally, I want to say that TorchScript needs to maintain a working language specification as part of the development process (if not for the users). With the current complexity and an existing user base, it will become increasingly difficult to make changes to the type system without having a model of the system that matches the implementation.