Understanding TorchScript Type System

Given the complexity of the TorchScript language today and the dependencies it creates for its users, further evolution of TorchScript may need to take a more disciplined approach. So internally we started an effort to revamp the TorchScript language specification. The first step is to capture faithfully the language as it is implemented today (issue #50434), document the rough edges identified during the process (issue #50444), and, based on these understandings, propose new features for the language.

In this thread we will focus on the discussion of the TorchScript type system. We would love to hear what you think.

TorchScript Types

The TorchScript type system consists of TSType and TSInstanceClass .

TSAllType := TSType | TSInstanceClass
TSType := TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType
TSInstanceClass := TSModuleType

TSType represents the majority of TorchScript types. These types are composable and can be used in TorchScript type annotations. TSType can be further classified into

  • meta types, e.g., Any
  • primitive types, e.g., int , float , str
  • structural types, e.g., Optional[int] or List[MyClass]
  • nominal types (Python classes), e.g., MyClass (user-defined), torch.tensor (builtin)

TSInstanceClass represents user-defined class types that are inferred partly from the object instance and partly from the class definition, such that instances of the same TSInstanceClass may not follow the same static type schema. Therefore, for type-safety reasons, TSInstanceClass cannot be used in TorchScript type annotations or be composed with TSType. Currently TSInstanceClass consists only of TSModuleType, which represents torch.nn.Module and its subclasses.

Put differently, TSInstanceClass today represents user-defined subclasses of torch.nn.Module. It is put into a separate category because it cannot be used in TorchScript type annotations (this will be explained later). This special constraint of TSInstanceClass came as a surprise to me and is also the root cause of issue #49650.

Meta Types

Meta types are so abstract that they are more like type constraints than concrete types. Currently TorchScript defines one meta-type, Any , that represents any TorchScript type.

TSMetaType := "Any"

Primitive Types

Primitive types represent a single kind of value and each has a single pre-defined type name.

TSPrimitiveType := "int" | "float" | "double" | "bool" | "str" | "None"
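
For illustration, here is a small sketch of mine (not from the original post) of a scripted function whose annotations use only primitive types; the function name is just an example.

import torch

# a scripted function annotated only with primitive types
@torch.jit.script
def repeat_name(name: str, times: int, upper: bool) -> str:
    out = ""
    if upper:
        name = name.upper()
    for _ in range(times):
        out += name
    return out

print(repeat_name("ab", 3, True))  # "ABABAB"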

Structural Types

Structural types are types that are structurally defined without a user-defined name (unlike nominal types), such as Future[int] . Structural types are composable with any TSType .

TSStructuralType := TSTuple | TSNamedTuple | TSList | TSDict | 
                    TSOptional | TSFuture | TSRRef

TSTuple := "Tuple" "[" (TSType ",")+ "]"
TSNamedTuple := "namedtuple" "(" (TSType ",")+ ")"
TSList := "List" "[" TSType "]"
TSOptional := "Optional" "[" TSType "]"
TSFuture := "Future" "[" TSType "]"
TSRRef := "RRef" "[" TSType "]"
TSDict := "Dict" "[" KeyType "," TSType "]"
KeyType := "str" | "int" | "float" | "bool" | TensorType | "Any"
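
As a quick illustration (a sketch of my own, not from the post), structural types compose freely in the annotations of a scripted function; the names here are illustrative only.

import torch
from typing import Dict, List, Optional, Tuple

@torch.jit.script
def summarize(scores: Dict[str, float],
              keep: Optional[List[str]] = None) -> Tuple[int, float]:
    # count and sum the entries whose keys are in `keep` (or all entries)
    total = 0.0
    count = 0
    for key, value in scores.items():
        if keep is not None:
            if key not in keep:
                continue
        total += value
        count += 1
    return count, total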

Nominal Types

Nominal TorchScript types are Python classes. They are called nominal because these types are declared with a custom name and are compared using class names. Nominal classes are further classified into the following categories:

TSNominalType := TSBuiltinClass | TSCustomClass | TSEnum | TSModuleInterface

Among them, TSCustomClass and TSEnum must be compilable to TorchScript IR and are thus the most common cause of scripting failures. TSBuiltinClass does not require compilation, but its behavior is only an (under-specified) subset of its Python counterpart, which is also a common cause of scripting failures.

Nominal types are the most complex types in the type system. They are hard to design correctly, easy to misuse, and difficult to explain to users, thus should be the focus of future extensions.

Builtin Class

Builtin nominal types are Python classes whose semantics are built into the TorchScript system, such as tensor types. TorchScript defines the semantics of these builtin nominal types, and often supports only a subset of the methods of their Python counterparts.

TSBuiltinClass := TSTensor | "torch.device" | "torch.stream" | "torch.dtype" | 
                  "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ...
TSTensor := "torch.tensor" and subclasses
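
For example (a sketch of my own, not from the post), builtin classes such as torch.Tensor, torch.device, and torch.dtype can be used directly in annotations:

import torch

@torch.jit.script
def move(x: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    # tensor methods like .to() are among the builtin behaviors TorchScript models
    return x.to(device=device, dtype=dtype)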

Custom Class

Unlike built-in classes, the semantics of custom classes are user-defined, and the entire class definition must be compilable into TorchScript IR and subject to TorchScript type-checking rules (a small example follows the list of rules below).

TSClassDef := [ "@torch.jit.script" ]
              "class" ClassName "(" [TSInterfaceName] ":" 
                    ClassBodyDefinition

where

  • TSInterfaceName is the name of a TorchScript interface class
  • Classes must be new-style classes, as we use __new__() to construct them with pybind11.
  • Class and instance attributes are statically typed, and instance attributes must be declared by assignments inside __init__()
  • Method overloading is not supported (i.e., cannot have multiple methods with the same method name)
  • ClassBodyDefinition must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules, i.e., all methods must be valid TorchScript functions and class attribute definitions are valid TorchScript statements
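
The following is a small sketch of my own (illustrative names, not from the post) of a custom class that satisfies these rules: the whole class, including __init__(), compiles to TorchScript IR, and its instances can be created inside TorchScript.

import torch

@torch.jit.script
class Accumulator(object):
    def __init__(self, start: float):
        # instance attributes must be declared by assignment inside __init__()
        self.total = start

    def add(self, value: float) -> float:
        self.total += value
        return self.total


@torch.jit.script
def run() -> float:
    acc = Accumulator(1.0)   # constructed inside TorchScript, unlike module types
    return acc.add(2.5)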

Enum Type

Like custom classes, semantics of enum type are user-defined and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules.

TSEnumDef := "class" Identifier "(enum.Enum | TSEnumType)" ":"
                 ( MemberIdentifier "=" Value )+
                 ( MethodDefinition )*

where

  • Value must be TorchScript literals of type int, float, or str, and must be of the same TorchScript type
  • TSEnumType is the name of a TorchScript enum type. Similar to Python enum, TorchScript allows restricted Enum subclassing, that is, subclassing an enum is allowed only if the enum does not define any members.
  • MethodDefinition must be compilable to TorchScript IR and subject to TorchScript’s type checking rules
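
A minimal sketch (my own example, not from the post) of a TorchScript enum whose members share a single value type:

import torch
from enum import Enum

class Color(Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def is_red(c: Color) -> bool:
    # the enum class is compiled recursively when referenced from TorchScript
    return c == Color.RED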

Module Interface Type

Module interface types are abstract classes that define only the signatures of methods (i.e., methods with empty bodies) and no instance or class attributes. As such, interface types do not define constructors (i.e., __init__()). They are designed to support a form of polymorphic binding (e.g., to allow subclasses of interfaces to bind to variables/parameters of interface types and resolve method attributes at runtime by name).

Currently we only allow defining module interface types (i.e., interfaces of torch.nn.Module ).

TSInterface := "@torch.jit.interface"
               "class" Identifier "(torch.nn.Module)" ":"
               (
                    "def" MethodIdentifier "(" (MethodArgument)* ")" ["->" TSType]
                          "pass"
               )*

where

  • MethodArgument must be type annotated and pass TorchScript’s type checking rules
  • Method overloading is not supported (i.e., cannot have multiple methods with the same method name)
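
A small sketch of mine (assuming the current @torch.jit.interface behavior; the class names are illustrative) of a module interface and a submodule bound through it:

import torch

@torch.jit.interface
class Stage(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

class Double(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

class Pipeline(torch.nn.Module):
    stage: Stage   # interface types, unlike module types, may appear in annotations

    def __init__(self):
        super().__init__()
        self.stage = Double()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the concrete forward() is resolved at runtime by name
        return self.stage.forward(x)

scripted = torch.jit.script(Pipeline())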

Note that from the implementation ( try_ann_to_type() in annotations.py ), it does not seem that interface types require compilation. Thus interface types should probably be considered built-in nominal types.

TorchScript Instance Class

TSInstanceClass comprises class types that are inferred from object instances created outside TorchScript. Although a TSInstanceClass is named after the Python class of the object instance, the __init__() method of that Python class is not considered a TorchScript method, so it does not have to comply with TorchScript's static type-checking rules.

Since the type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript), rather than inferred from __init__() as with custom classes, it is possible that two objects of the same instance class type follow two different type schemas.

In this sense, TSInstanceClass is not really a static type. Therefore, for type safety considerations, TSInstanceClass cannot be used in TorchScript type annotation or be composed with TSType .

Module Instance Class

TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., it is instantiated in Python and then passed to torch.jit.script()), so the Python module class is treated as a module instance class and __init__() of the Python module class is not subject to TorchScript's type-checking rules.

TSModuleType := "class" Identifier "(torch.nn.Module)" ":"
                       ClassBodyDefinition

where

  • forward() or other methods decorated with @torch.jit.export must be compilable to TorchScript IR and subject to TorchScript’s type checking rules

Unlike custom classes, only the forward method of a module type and any other methods decorated with @torch.jit.export need to be compilable to TorchScript IR. Most notably, __init__() is not considered a TorchScript method, so module type constructors should not be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into torch.jit.script(ModuleObj).

Since TorchScript intentionally does not compile the entire class definition of a module type, a module type is not considered a complete TorchScript type and, for type-safety purposes, is not allowed in type annotations.
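
To illustrate (a sketch of my own, not the author's): only forward() and @torch.jit.export methods are compiled, while __init__() runs as ordinary Python when the instance is constructed outside TorchScript.

import torch

class Scale(torch.nn.Module):
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor          # plain Python; not compiled to TorchScript IR

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor

    @torch.jit.export
    def halve(self, x: torch.Tensor) -> torch.Tensor:
        return x * (self.factor / 2.0)

scripted = torch.jit.script(Scale(3.0))   # instance is created outside TorchScript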

Using TorchScript Types

Static types are used for type annotation and type checking in TorchScript.

Type Checking

(Disclaimer: I did not get the chance to read the type-checking code in detail, so the following has not been carefully validated against the implementation.)

From the static type checker's perspective, this translates into checking the following conditions:

  1. Types of any static expression within the scope of a TorchScript program can be determined statically;
  2. Variables, parameters, and class and instance attributes must have the same static type throughout their lifetime;
  3. A static type schema (incl. types of class/instance/method attributes) can be built for TorchScript classes so that instances of a TorchScript class are guaranteed to follow the same schema; and
  4. All operators, methods, class/instance attributes are resolved to a concrete implementation in TorchScript.

As such, type-checking is applied to

  • type annotation constructs : to check that types used in annotations are TorchScript types
  • control-flow joins : to check that the types associated with symbols remain the same after a control-flow join
  • constructors : to check that all __init__() produce the same static schema for instance attributes
  • parameter binding & assignment : to check that the type of argument or right-hand-side matches that of the formal parameter or left-hand-side
  • attribute resolution : to check that the attribute (incl. operators) is part of the static schema of the class/type.

Note that type Any must be used carefully. Since Any is not a concrete type, accessing any attribute, operator, or method of an Any-typed value will likely violate condition #4. Likewise, method attributes of interface types must be handled specially to resolve to concrete implementations.
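
For example (a hedged sketch of my own), an Any-typed value can flow through containers and return values, but member access on it cannot be resolved:

import torch
from typing import Any, List

@torch.jit.script
def box(x: Any) -> List[Any]:
    out: List[Any] = [x]
    # calling e.g. x.size() here would fail: no attribute can be resolved
    # on a value whose static type is Any (condition #4 above)
    return out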

Type Annotation

To satisfy the above requirements, type annotations are needed to seed static types into an otherwise dynamically typed Python program. TorchScript may obtain types through user annotation or auto-inference, or, if all else fails, assume the default type of TensorType (a small example follows the list below).

In general

  • Parameters are type annotated; otherwise they are given the default type of Tensor;
  • Return types can be annotated or automatically inferred;
  • Class attributes are type annotated, or auto-inferred if initialized; otherwise they are given the default type of Tensor;
  • Local variables are often auto-inferred, or in special cases can be type annotated via torch.jit.annotate()
  • Instance attributes are always initialized inside __init__(), therefore their types are always auto-inferred
  • self cannot be type annotated and its type is inferred
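
Here is a small sketch (my own example) combining these annotation forms: a class attribute annotation for an empty container, an unannotated parameter that defaults to Tensor, and torch.jit.annotate() for a local.

import torch
from typing import List

class Collector(torch.nn.Module):
    names: List[str]                              # class attribute annotation

    def __init__(self):
        super().__init__()
        self.names = []                           # type taken from the annotation above

    def forward(self, x):                         # unannotated parameter: defaults to Tensor
        sizes = torch.jit.annotate(List[int], []) # explicitly annotated local
        sizes.append(x.dim())
        self.names.append("seen")
        return x

scripted = torch.jit.script(Collector())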

How to Improve TorchScript Type System

I hope I have captured the gist of the current state of the TorchScript type system. Now I would like to share my observations and some thoughts on how to improve it.

In TorchScript type system, nominal types are both the most complex and the least well-defined. As such, they are most likely to cause scripting failures, user confusion, or type-safety loopholes. Builtin and custom nominal types seem to suffer from completely opposite problems:

  • Builtin nominal types rely on an internalized representation of a subset of the semantics of selected Python classes. The gaps between builtin TorchScript types and their Python counterparts are many and often unspecified, leading to many type-checking failures during attribute resolution. There are also ad hoc rules (such as module types not being usable in type annotations) and lax checking (e.g., all subclasses of torch.tensor are considered tensor types, even user-defined tensor subclasses);
  • Custom nominal types rely on compilation to be converted into TorchScript IR. But the requirement of compiling the entire class definition can be too restrictive and can lead to confusing type-checking errors at the type annotation site.

Custom nominal types must be fully compiled, really?

Let’s start with a concrete example (issue #49650). This code fails to be scripted because "sub: TestSubModule" uses a TorchScript module type in a type annotation and thus fails the type checker (see our previous explanation of why module type is not considered a complete TorchScript type).

class TestSubModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.v = torch.rand((2, 3))

    def forward(self):
        return self.v

class TestModule(torch.nn.Module):
    sub: TestSubModule # this line fails annotation checking

    def __init__(self):
        super().__init__()
        self.sub = TestSubModule()

    def forward(self):
        return self.sub()

scripted_m = torch.jit.script(TestModule())

We previously explained that module type is not considered a complete TorchScript type because __init__() of a TorchScript module type is not considered a TorchScript method. However, the fact is that TestModule.forward() does not use __init__() of sub at all, but only forward() of sub, which can be compiled to TorchScript IR.

The requirement of compiling the entire class definition of custom TorchScript nominal types is evident in the type checking rules for annotation expressions in try_ann_to_type() from annotations.py:

if inspect.isclass(ann):
    qualified_name = _qualified_name(ann)
    if _get_script_class(qualified_name) is not None:
        return ClassType(qualified_name)
    ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
    if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes):
        torch.jit._script._recursive_compile_class(ann, loc)
        return ClassType(qualified_name) 

But do they have to be? The answer is no. We already have plenty of partially internalized TorchScript types in the current type system. In fact, most builtin TorchScript (nominal) types are only partially supported compared to their Python counterparts (see TorchScript unsupported PyTorch constructs).

I would like to propose relaxing the type-checking at the definition of custom nominal types: simply consider any class a TorchScript class type. When compiling the class definition, we keep track of whether each method or attribute is compilable to TorchScript IR (as builtin classes do). The real type-checking happens at attribute resolution, i.e., all methods or attributes of a TorchScript class used within the scope of a TorchScript program must resolve to concrete TorchScript implementations.
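
Purely for illustration, the annotation rule quoted above might then look roughly like the following. This is a hypothetical sketch in the spirit of the proposal, and _record_compilable_members is an invented placeholder, not an existing PyTorch API.

# hypothetical sketch only -- not actual PyTorch code
if inspect.isclass(ann):
    qualified_name = _qualified_name(ann)
    if _get_script_class(qualified_name) is not None:
        return ClassType(qualified_name)
    # instead of requiring the whole class to compile, record which members
    # compiled and which did not, and accept the annotation; members that
    # failed to compile are rejected only when they are actually accessed
    # inside TorchScript (i.e., at attribute resolution)
    _record_compilable_members(ann, loc)          # invented placeholder helper
    return ClassType(qualified_name)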

Since __init__ also serves the purpose of creating a static schema for the data attributes of a TorchScript class, regardless of whether __init__ is invoked inside TorchScript or not, a special design (yet to be worked out) is needed for the case where __init__ cannot be compiled to TorchScript IR. Such a design may also work for TorchScript module types (simplifying the type system).

Miscellaneous

Just want to briefly mention a couple of half-baked ideas.

  • Design annotations to further prune non-executed paths in a TorchScript program based on the program's execution context. In general, we want to tighten the compilation surface of TorchScript as much as possible, because the more code is compiled, the more likely type-checking failures are to arise. For instance, torch.jit.trace() is a way to tighten the TorchScript compilation surface via runtime specialization (although it is done in a way that may be hard to reason about for correctness).
  • More refined design of Any -like meta types or user-defined interface types to increase the flexibility of the language.
  • Soften the edge between builtin and custom types. Builtin and custom types seem to be internalized in completely different ways, and it is not clear how well TorchScript supports the in-between points, such as overriding builtin types with user-defined methods or subclassing builtin types in custom classes.
  • It seems that objects created outside the scope of TorchScript need to be type-checked at runtime.

Finally I want to say that TorchScript needs to maintain a working language specification as part of the development process (if not for the users). With the current complexity and an existing user-base, it will become increasingly difficult to make changes to the type system without having a model of the system that matches the implementation.


Awesome discussion of the type system. I’m learning a lot, thank you!
Two quick thoughts:

  • As someone who tries to do things to the PyTorch JIT sometimes, I’d probably like to get a clearer idea of where things live in PyTorch. I know this is difficult to achieve without cluttering the definition for everyone else (maybe sidenotes or something like that, potentially even toggleable in visibility). Also I think that “where the Type is valid/may occur” (expanding on the Using TorchScript Types section, maybe) would be interesting (e.g., it used to be the case, though I would not know if it still is, that we always typed None, so NoneType would not show up in a .graph you see from Python; a very minor thing: “Function” might show up in a graph, but, as you discuss for the nominal types, every function is technically of its own type, so for user-facing documentation it might be nice to point that out).
  • The JIT has a PyObject type holding Python objects in TorchScript. I didn’t know that and I’m not entirely sure I know where it is intended to be used, but I’ve been looking at using it for implementing a Python fallback mechanism of sorts.

Best regards

Thomas

I’d probably like to get a clearer idea of where things live in PyTorch. I know this is difficult to achieve without cluttering the definition for everyone else (maybe sidenotes or something like that, potentially even toggleable in visibility)

Totally understand your frustration. This post is an initial draft of part of the revamped language spec. Since the language spec should not be tangled with implementation, it is probably not the place for where-things-are-implemented kinds of notes. We do plan to include more design notes, comparisons against Python, and good and bad examples in the spec.

At the same time, we are also building a “spec” for the TorchScript IR, which would help create a very high-level (visualized) mental map of the different components of the JIT through the lens of IR transformations and lowering/conversion across IRs (it turns out there are a couple of IRs inside the JIT). That might help a bit. We hope to write a similar discussion note on IRs soon.

“Function” might show up in a graph, but, as you discuss for the nominal types, every function is technically of its own type, so for user-facing documentation it might be nice to point that out

That’s a very good point. Our original thought was to describe “external types” as those that can be used in type annotations, but perhaps the criterion should be broader and include types that are exposed to users (e.g., via graph). There are quite a few types defined in C10_FORALL_TYPES from jit_type_base.h that we did not describe in the type system (e.g., NumberType and Function); we need to check whether they are “exposed” in graph.

Types that can be used in type annotations should definitely be described in the language spec. Descriptions of internal types are useful information to help users understand Graph, but they may not belong in the language spec, because we don’t want to create an unnecessary dependency on a specific implementation. So we need to find a place for such information.

The JIT has a PyObject type holding Python objects in TorchScript.

Do you mean module instance types?