Source code for jittor

# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved. 
# Maintainers:
#   Dun Liang <randonlang@gmail.com>.
#   Meng-Hao Guo <guomenghao1997@gmail.com>
#
# 
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.9.0'
from jittor_utils import lock
with lock.lock_scope():
    ori_int = int
    ori_float = float
    ori_bool = bool
    from . import compiler
    from .compiler import LOG, has_cuda
    from .compiler import compile_custom_ops, compile_custom_op
    import jittor_core
    import jittor_core as core
    from jittor_core import *
    from jittor_core.ops import *
    from . import compile_extern
    from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size
    if core.get_device_count() == 0:
        has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
    from .compile_extern import cudnn, curand, cublas, cufft
    from .init_cupy import numpy2cupy

from typing import List, Tuple
import contextlib
import numpy as np
from collections import OrderedDict
from collections.abc import Sequence, Mapping
import types
import pickle
import hashlib
import sys, os
import traceback

if "SKEY" in os.environ:
    import jittor_utils.student_queue

def dfs_to_numpy(x):
    if isinstance(x, list):
        for i in range(len(x)):
            x[i] = dfs_to_numpy(x[i])
    elif isinstance(x, dict):
        for k in x:
            x[k] = dfs_to_numpy(x[k])
    elif isinstance(x, Var):
        return x.numpy()
    return x

def safepickle(obj, path):
    if path.endswith(".pth") or path.endswith(".pt") or path.endswith(".bin"):
        from jittor_utils.save_pytorch import save_pytorch
        save_pytorch(path, obj)
        return
    # Protocol version 4 was added in Python 3.4. It adds support for very large
    # objects, pickling more kinds of objects, and some data format optimizations.
    # ref: <https://docs.python.org/3/library/pickle.html>
    obj = dfs_to_numpy(obj)
    s = pickle.dumps(obj, 4)
    checksum = hashlib.sha1(s).digest()
    s += bytes(checksum)
    s += b"HCAJSLHD"
    with open(path, 'wb') as f:
        f.write(s)

def _load_pkl(s, path):
    try:
        return pickle.loads(s)
    except Exception as e:
        msg = str(e)
        msg += f"\nPath: \"{path}\""
        if "trunc" in msg:
            msg += "\nThis file maybe corrupted, please consider remove it" \
                " and re-download."
        raise RuntimeError(msg)

def _upload(path, url, jk, tdir=""):
    tdir = tdir + '/' if tdir != "" else ""
    prefix = f"https://cg.cs.tsinghua.edu.cn/jittor/{tdir}assets"
    if url.startswith("jittorhub://"):
        url = url.replace("jittorhub://", prefix+"/build/checkpoints/")
    assert url.startswith(prefix)
    suffix = url[len(prefix):]
    dir_suffix = "/".join(suffix.split("/")[:-1])
    jkey = flags.cache_path+"/_jkey"
    with open(jkey, 'w') as f:
        f.write(jk)
    assert os.system(f"chmod 600 \"{jkey}\"") == 0
    print(dir_suffix)
    assert os.system(f"s""s""h"f" -i \"{jkey}\" jittor" "@" "166" f".111.68.30 mkdir -p Documents/jittor-blog/{tdir}assets{dir_suffix}") == 0
    assert os.system(f"s""c""p"+f" -i \"{jkey}\" \"{path}\" jittor" "@" "166" f".111.68.30:Documents/jittor-blog/{tdir}assets{suffix}") == 0
    assert os.system(f"s""s""h"f" -i \"{jkey}\" jittor" "@" "166" ".111.68.30 Documents/jittor-blog.git/hooks/post-update") == 0

def safeunpickle(path):
    if path.startswith("jittorhub://"):
        path = path.replace("jittorhub://", f"https://cg.cs.tsinghua.edu.cn/jittor/assets/build/checkpoints/")
    if path.startswith("https:") or path.startswith("http:"):
        base = path.split("/")[-1]
        fname = os.path.join(compiler.ck_path, base)
        from jittor_utils.misc import download_url_to_local
        download_url_to_local(path, base, compiler.ck_path, None)
        path = fname
        if not (path.endswith(".pth") or path.endswith(".pkl") or path.endswith(".pt")):
            return path
    if path.endswith(".pth") or path.endswith(".pt") or path.endswith(".bin"):
        from jittor_utils.load_pytorch import load_pytorch
        model_dict = load_pytorch(path)
        return model_dict
    with open(path, "rb") as f:
        s = f.read()
    if not s.endswith(b"HCAJSLHD"):
        return _load_pkl(s, path)
    checksum = s[-28:-8]
    s = s[:-28]
    if hashlib.sha1(s).digest() != checksum:
        raise ValueError("Pickle checksum does not match! path: "+path,
            " This file maybe corrupted, please consider remove it"
            " and re-download.")
    return _load_pkl(s, path)
class _call_no_record_scope:
    def __enter__(self): pass
    def __exit__(self, *exc): pass
    def __call__(self, func):
        def inner(*args, **kw):
            with self:
                ret = func(*args, **kw)
            return ret
        return inner

class flag_scope(_call_no_record_scope):
    def __init__(self, **jt_flags):
        self.jt_flags = jt_flags

    def __enter__(self):
        flags_bk = self.flags_bk = {}
        try:
            for k,v in self.jt_flags.items():
                origin = getattr(flags, k)
                flags_bk[k] = origin
                # merge dict attrs
                if isinstance(origin, dict):
                    for ok, ov in origin.items():
                        if ok not in v:
                            v[ok] = ov
                setattr(flags, k, v)
        except:
            self.__exit__()
            raise

    def __exit__(self, *exc):
        for k,v in self.flags_bk.items():
            setattr(flags, k, v)
class no_grad(flag_scope):
    ''' no_grad scope, all variables created inside this scope will stop grad.

    Example::

        import jittor as jt

        with jt.no_grad():
            ...
    '''
    def __init__(self, **jt_flags):
        self.jt_flags = jt_flags
        jt_flags["no_grad"] = 1

class enable_grad(flag_scope):
    ''' enable_grad scope, all variables created inside this scope will start grad.

    Example::

        import jittor as jt

        with jt.enable_grad():
            ...
    '''
    def __init__(self, **jt_flags):
        self.jt_flags = jt_flags
        jt_flags["no_grad"] = 0
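# Illustrative sketch (not part of the original source): how the scopes above are
# typically combined. `jt` refers to the fully initialized jittor package and the
# values below are only for demonstration; this helper is never called here.
def _example_grad_scopes():
    x = jt.array([1.0, 2.0, 3.0])
    with jt.no_grad():                  # gradients disabled for everything below
        y = x * 2
        with jt.enable_grad():          # re-enable gradients for a nested region
            z = x * 3
    with jt.flag_scope(auto_mixed_precision_level=0):   # generic flag override
        w = x + 1
    return y, z, w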
single_log_capture = None
class log_capture_scope(_call_no_record_scope):
    """log capture scope

    Example::

        with jt.log_capture_scope(log_v=0) as logs:
            LOG.v("...")
        print(logs)
    """
    def __init__(self, **jt_flags):
        jt_flags["use_parallel_op_compiler"] = 0
        self.fs = flag_scope(**jt_flags)

    def __enter__(self):
        global single_log_capture
        assert not single_log_capture
        single_log_capture = True
        self.logs = []
        LOG.log_capture_start()
        try:
            self.fs.__enter__()
            if "log_v" in self.fs.jt_flags:
                LOG.log_v = self.fs.jt_flags["log_v"]
            return self.logs
        except:
            LOG.log_capture_stop()
            single_log_capture = None
            raise

    def __exit__(self, *exc):
        global single_log_capture
        self.fs.__exit__(*exc)
        if "log_v" in self.fs.jt_flags:
            LOG.log_v = flags.log_v
        LOG.log_capture_stop()
        self.logs.extend(LOG.log_capture_read())
        single_log_capture = None

class profile_scope(_call_no_record_scope):
    """ profile scope

    Example::

        with jt.profile_scope() as report:
            ......
        print(report)
    """
    def __init__(self, warmup=0, rerun=0, **jt_flags):
        self.fs = flag_scope(**jt_flags)
        self.warmup = warmup
        self.rerun = rerun

    def __enter__(self):
        assert not flags.profiler_enable
        self.report = []
        try:
            self.fs.__enter__()
            profiler.start(self.warmup, self.rerun)
            return self.report
        except:
            profiler.stop()
            raise

    def __exit__(self, *exc):
        profiler.stop()
        self.report.extend(profiler.report())
        self.fs.__exit__(*exc)

class profile_mark(_call_no_record_scope):
    def __init__(self, mark_name: str):
        ''' profiler mark is used for profiling part of the code.

        Example::

            a = jt.rand(1000,1000)
            b = jt.rand(1000,1000)
            jt.sync_all()
            results = []
            with jt.profile_scope() as rep:
                results.append(jt.matmul(a, b))
                with jt.profile_mark("mark1"):
                    results.append(jt.matmul(a, b))
                    with jt.profile_mark("mark2"):
                        results.append(jt.matmul(a, b))
                with jt.profile_mark("mark3"):
                    results.append(jt.matmul(a, b))
                results.append(jt.matmul(a, b))

        Output::

            Total time: 46.8ms
            Total Memory Access: 57.2MB
            [Mark mark3] time: 9ms
            [Mark mark2] time: 8.28ms
            [Mark mark1] time: 17.7ms
        '''
        self.mark_name = mark_name

    def __enter__(self):
        self.options = flags.compile_options
        new_options = flags.compile_options
        prev_marks = "_marks:"
        for x in self.options:
            if x.startswith(prev_marks):
                prev_marks = x
                del new_options[x]
        new_marks = prev_marks + self.mark_name + ','
        new_options[new_marks] = 1
        flags.compile_options = new_options

    def __exit__(self, *exc):
        flags.compile_options = self.options
class __single_process_scope:
    def __init__(self, rank=0):
        self.rank = rank

    def __enter__(self):
        global in_mpi
        self.bk_in_mpi = in_mpi
        if mpi:
            self.bk_mpi_state = mpi.get_state()
        if not in_mpi:
            return True
        ret = self.rank == mpi.world_rank()
        in_mpi = compile_extern.in_mpi = False
        mpi.set_state(False)
        return ret

    def __exit__(self, *exc):
        global in_mpi
        in_mpi = compile_extern.in_mpi = self.bk_in_mpi
        if mpi:
            mpi.set_state(self.bk_mpi_state)

def single_process_scope(rank=0):
    """ single_process_scope

    Code in this scope will only be executed by a single process.
    All the mpi code inside this scope will have no effect:
    mpi.world_rank() and mpi.local_rank() will return 0, and world_size() will return 1.

    example::

        @jt.single_process_scope(rank=0)
        def xxx():
            ...
    """
    def outer(func):
        def inner(*args, **kw):
            ret = None
            sync_all()
            with __single_process_scope(rank) as flag:
                if flag:
                    ret = func(*args, **kw)
            return ret
        return inner
    return outer

def clean():
    import gc
    # make sure python does a full collection
    gc.collect()
    core.gc()
cast = unary
Var.cast = Var.cast
def array(data, dtype=None):
    ''' Constructs a jittor Var from a number, List, numpy array or another jittor Var.

    :param data: The data to initialize the Var.
    :type data: number, list, numpy.ndarray, or jittor.Var.
    :param dtype: The data type of the Var. If None, the data type will be inferred from the data.
    :type dtype: str, jittor type-cast function, or None.

    ----------------

    Example::

        >>> jt.array(1)
        jt.Var([1], dtype=int32)
        >>> jt.array([0, 2.71, 3.14])
        jt.Var([0.   2.71 3.14], dtype=float32)
        >>> jt.array(np.arange(4, dtype=np.uint8))
        jt.Var([0 1 2 3], dtype=uint8)
    '''
    if isinstance(data, core.Var):
        if dtype is None:
            ret = data.clone()
        else:
            ret = cast(data, dtype)
    elif dtype is not None:
        if isinstance(dtype, NanoString):
            dtype = str(dtype)
        elif callable(dtype):
            dtype = dtype.__name__
        with jt.flag_scope(auto_convert_64_to_32=0):
            ret = ops.array(np.array(data, dtype))
    else:
        ret = ops.array(data)
    # TODO: move those code to core
    amp_reg = jt.flags.amp_reg
    if amp_reg and ret.numel() != 1 and ret.dtype.is_float():
        if amp_reg & 16:
            if amp_reg & 1:
                if ret.dtype != "float32":
                    return ret.float32()
            elif amp_reg & 2:
                if ret.dtype != "float16":
                    return ret.float16()
    return ret

def random(shape, dtype="float32", type="uniform"):
    ''' Constructs a random jittor Var.

    :param shape: The shape of the random Var.
    :type shape: list or tuple.
    :param dtype: The data type of the random Var.
    :type dtype: str, jittor type-cast function, or None.
    :param type: The random distribution, can be 'uniform' or 'normal'.
    :type type: str

    ----------------

    Example::

        >>> jt.random((2, 3))
        jt.Var([[0.96788853 0.28334728 0.30482838]
         [0.46107793 0.62798643 0.03457401]], dtype=float32)
    '''
    # TODO: move those code to core
    if dtype in ["float16", "bfloat16"]:
        # TODO: make curand support fp16
        ret = ops.random(shape, "float32", type).cast(dtype)
    else:
        ret = ops.random(shape, dtype, type)
    amp_reg = jt.flags.amp_reg
    if amp_reg:
        if amp_reg & 16:
            if amp_reg & 1:
                if ret.dtype != "float32":
                    return ret.float32()
            elif amp_reg & 2:
                if ret.dtype != "float16":
                    return ret.float16()
    return ret

def float_auto(x):
    if jt.flags.amp_reg & 2:
        return x.float16()
    return x.float32()
Var.float_auto = float_auto
def array64(data, dtype=None):
    with jt.flag_scope(auto_convert_64_to_32=0):
        return array(data, dtype)

def grad(loss, targets, retain_graph=True):
    if type(targets) == core.Var:
        return core.grad(loss, [targets], retain_graph)[0]
    return core.grad(loss, targets, retain_graph)

def liveness_info():
    return {
        "hold_vars": core.number_of_hold_vars(),
        "lived_vars": core.number_of_lived_vars(),
        "lived_ops": core.number_of_lived_ops(),
    }
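# Illustrative sketch (not part of the original source): jt.grad as defined above
# returns d(loss)/d(target); a single Var target yields a single Var gradient.
# This helper is only for demonstration and is never called here.
def _example_grad():
    x = jt.array([2.0, 3.0])
    y = (x * x).sum()          # y = sum(x^2)
    dx = jt.grad(y, x)         # dy/dx = 2*x -> [4.0, 6.0]
    return dx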
def ones(*shape, dtype="float32"):
    ''' Constructs a jittor Var with all elements set to 1.

    :param shape: The shape of the output Var.
    :type shape: list or tuple.
    :param dtype: The data type of the output Var.
    :type dtype: str, jittor type-cast function, or None.
    :return: The output Var.
    :rtype: jittor.Var
    '''
    if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)):
        dtype = shape[-1]
        shape = shape[:-1]
    if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
        shape = shape[0]
    return unary(1, dtype).broadcast(shape)

def new_ones(x, size):
    return ones(size, x.dtype)
Var.new_ones = new_ones
def ones_like(x):
    ''' Constructs a jittor Var with all elements set to 1 and the same shape as x.

    :param x: The reference jittor Var.
    :type x: jt.Var
    :return: The output Var.
    :rtype: jittor.Var
    '''
    return ones(x.shape, x.dtype)

def zeros(*shape, dtype="float32"):
    ''' Constructs a jittor Var with all elements set to 0.

    :param shape: The shape of the output Var.
    :type shape: list or tuple.
    :param dtype: The data type of the output Var.
    :type dtype: str, jittor type-cast function, or None.
    :return: The output Var.
    :rtype: jittor.Var
    '''
    if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)):
        dtype = shape[-1]
        shape = shape[:-1]
    if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
        shape = shape[0]
    return unary(0, dtype).broadcast(shape)

def new_zeros(x, size):
    return zeros(size, x.dtype)
Var.new_zeros = new_zeros
def empty(*shape, dtype="float32"):
    if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)):
        dtype = shape[-1]
        shape = shape[:-1]
    if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
        shape = shape[0]
    return ops.empty(shape, dtype)

def new_empty(x, size):
    return empty(size, x.dtype)
Var.new_empty = new_empty
def full(shape, val, dtype="float32"):
    ''' Constructs a jittor Var with all elements set to val.

    :param shape: The shape of the output Var.
    :type shape: list or tuple.
    :param val: The value of the output Var.
    :type val: number.
    :param dtype: The data type of the output Var. Defaults to jt.float32.
    :type dtype: str, jittor type-cast function, or None.
    :return: The output Var.
    :rtype: jittor.Var
    '''
    if not isinstance(shape, (NanoVector, Sequence)):
        shape = (shape,)
    return unary(val, dtype).broadcast(shape)

def new_full(x, size, val):
    return full(size, val, x.dtype)
Var.new_full = new_full
def ne(x, y):
    return x != y
Var.ne = ne
def full_like(x, val, dtype=None) -> Var:
    ''' Constructs a jittor Var with all elements set to val and the same shape as x.

    :param x: The reference jittor Var.
    :type x: jt.Var.
    :param val: The value of the output Var.
    :type val: number.
    :param dtype: if None, the dtype of the output is the same as x.
        Otherwise, use the specified dtype. Defaults to None.
    :type dtype: str, optional
    :return: The output Var.
    :rtype: jittor.Var
    '''
    if dtype is None:
        dtype = x.dtype
    return full(x.shape, val, dtype)

def zeros_like(x, dtype=None) -> Var:
    ''' Constructs a jittor Var with all elements set to 0 and the same shape as x.

    :param x: The reference jittor Var.
    :type x: jt.Var
    :param dtype: if None, the dtype of the output is the same as x.
        Otherwise, use the specified dtype. Defaults to None.
    :type dtype: str, optional
    :return: The output Var.
    :rtype: jittor.Var
    '''
    if dtype is None:
        dtype = x.dtype
    return zeros(x.shape, dtype)
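# Illustrative sketch (not part of the original source): the constructors above accept
# either a shape sequence or unpacked integers, with an optional dtype. This helper is
# only for demonstration and is never called here.
def _example_constructors():
    a = jt.ones(2, 3)                    # 2x3 of float32 ones
    b = jt.zeros((2, 3), dtype="int32")  # same shape, int32 zeros
    c = jt.full((2, 3), 7)               # 2x3 filled with 7
    d = jt.ones_like(c)                  # same shape and dtype as c
    e = jt.full_like(a, 0.5)
    return a, b, c, d, e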
flags = core.Flags()
def var(x, dim=None, dims=None, unbiased=False, keepdims=False):
    """ Returns the sample variance. If unbiased is True, Bessel's correction will be used.

    :param x: the input jittor Var.
    :type x: jt.Var.
    :param dim: the dimension to compute the variance. If both dim and dims are None,
        the variance of the whole tensor will be computed.
    :type dim: int.
    :param dims: the dimensions to compute the variance. If both dim and dims are None,
        the variance of the whole tensor will be computed.
    :type dims: tuple of int.
    :param unbiased: if True, Bessel's correction will be used.
    :type unbiased: bool.
    :param keepdims: if True, the output shape is the same as the input shape except for
        the reduced dimensions.
    :type keepdims: bool.

    Example::

        >>> a = jt.rand(3)
        >>> a
        jt.Var([0.79613626 0.29322362 0.19785859], dtype=float32)
        >>> a.var()
        jt.Var([0.06888353], dtype=float32)
        >>> a.var(unbiased=True)
        jt.Var([0.10332529], dtype=float32)
    """
    shape = x.shape
    new_shape = list(x.shape)
    assert dim is None or dims is None, "dim and dims can not be both set"
    if dim is None and dims is None:
        dims = list(range(len(shape)))
    elif dim is not None:
        dims = [dim]

    mean = jt.mean(x, dims, keepdims=True)
    mean = jt.broadcast(mean, shape)

    n = 1
    for d in dims:
        n *= shape[d]
        new_shape[d] = 1

    sqr = (x - mean) ** 2
    sqr = jt.sum(sqr, dims=dims, keepdims=False)
    if unbiased:
        n -= 1
    sqr /= n
    if keepdims:
        sqr = sqr.view(new_shape)
    return sqr
Var.var = var
def std(x):
    matsize = 1
    for i in x.shape:
        matsize *= i
    out = (x - x.mean()).sqr().sum()
    out = out / (matsize - 1)
    out = out.maximum(1e-6).sqrt()
    return out
Var.std = std
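# Illustrative sketch (not part of the original source): relation between the var() and
# std() helpers above. std() always divides by n - 1, so it corresponds to the square
# root of var(unbiased=True), up to the 1e-6 clamp. Never called here.
def _example_var_std():
    a = jt.rand(5)
    v_biased   = a.var()                # divides by n
    v_unbiased = a.var(unbiased=True)   # divides by n - 1
    s = a.std()                         # sqrt of the unbiased variance (clamped at 1e-6)
    return v_biased, v_unbiased, s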
def norm(x, p=2, dim=-1, keepdims=False, eps=1e-30, keepdim=False):
    keepdim = keepdim or keepdims
    assert p == 1 or p == 2
    if p == 1:
        return x.abs().sum(dim, keepdim)
    if p == 2:
        return (x.sqr()).sum(dim, keepdim).maximum(eps).sqrt()
Var.norm = norm

origin_reshape = reshape
def reshape(x, *shape):
    if len(shape) == 1 and isinstance(shape[0], (Sequence, NanoVector)):
        shape = shape[0]
    return origin_reshape(x, shape)
reshape.__doc__ = origin_reshape.__doc__
Var.view = Var.reshape = view = reshape

origin_transpose = transpose
def transpose(x, *dim):
    if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)):
        dim = dim[0]
    elif len(dim) == 2:
        axes = list(range(x.ndim))
        a, b = dim
        axes[a], axes[b] = axes[b], axes[a]
        dim = axes
    return origin_transpose(x, dim)
transpose.__doc__ = origin_transpose.__doc__
Var.transpose = Var.permute = permute = transpose
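# Illustrative sketch (not part of the original source): the wrappers above accept both
# a shape/axes sequence and unpacked integers, and transpose(a, b) with exactly two
# integers swaps those dimensions. Never called here.
def _example_reshape_transpose():
    x = jt.rand(2, 3, 4)
    y = x.reshape(6, 4)            # same as x.reshape((6, 4)) or x.view(6, 4)
    z = x.transpose(0, 2)          # swap dims 0 and 2 -> shape [4, 3, 2]
    w = x.permute((2, 0, 1))       # full permutation  -> shape [4, 2, 3]
    return y, z, w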
def flatten(input, start_dim=0, end_dim=-1):
    '''flatten dimensions by reshape'''
    in_shape = input.shape
    start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim
    end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim
    assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function"
    out_shape = []
    for i in range(0, start_dim, 1):
        out_shape.append(in_shape[i])
    dims = 1
    for i in range(start_dim, end_dim + 1, 1):
        dims *= in_shape[i]
    out_shape.append(dims)
    for i in range(end_dim + 1, len(in_shape), 1):
        out_shape.append(in_shape[i])
    return input.reshape(out_shape)
Var.flatten = flatten

Var.detach_inplace = Var.start_grad
def detach(x):
    return x.detach()

def unsqueeze(x, dim):
    shape = list(x.shape)
    if dim < 0:
        dim += len(shape) + 1
    assert dim <= len(shape)
    return x.reshape(shape[:dim] + [1] + shape[dim:])
Var.unsqueeze = unsqueeze
def squeeze(x, dim=None):
    shape = list(x.shape)
    if dim is None:
        new_shape = [s for s in shape if s > 1]
        return x.reshape(new_shape)
    else:
        if dim < 0:
            dim += len(shape)
        assert dim < len(shape) and dim >= 0
        assert shape[dim] == 1
        return x.reshape(shape[:dim] + shape[dim+1:])
Var.squeeze = squeeze
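# Illustrative sketch (not part of the original source): unsqueeze inserts a length-1
# dimension at `dim`; squeeze removes a given length-1 dimension, or all of them when
# dim is None. Never called here.
def _example_squeeze_unsqueeze():
    x = jt.rand(3, 4)
    a = x.unsqueeze(0)                  # shape [1, 3, 4]
    b = a.squeeze(0)                    # back to [3, 4]
    c = jt.rand(1, 3, 1, 4).squeeze()   # all size-1 dims removed -> [3, 4]
    return a, b, c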
def clamp(x, min_v=None, max_v=None):
    if x.shape[0] == 0:
        return x
    if min_v is not None and max_v is not None:
        assert min_v <= max_v
    if min_v is not None:
        x = x.maximum(min_v)
    if max_v is not None:
        x = x.minimum(max_v)
    return x
Var.clamp = clamp
def clamp_(x, min_v=None, max_v=None):
    ''' In-place version of clamp().

    Args:
        x (Jittor Var): the input var
        min_v (Number or Var, optional): lower-bound of clamp range
        max_v (Number or Var, optional): upper-bound of clamp range

    Return:
        x itself after clamp.
    '''
    return x.assign(x.clamp(min_v=min_v, max_v=max_v))
Var.clamp_ = clamp_
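# Illustrative sketch (not part of the original source): clamp() returns a new Var while
# clamp_() assigns the result back into its input. Never called here.
def _example_clamp():
    x = jt.array([-2.0, 0.5, 3.0])
    y = x.clamp(min_v=-1.0, max_v=1.0)   # new Var: [-1.0, 0.5, 1.0]
    z = jt.array([-2.0, 0.5, 3.0])
    z.clamp_(min_v=0.0)                  # in place: z becomes [0.0, 0.5, 3.0]
    return y, z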
def outer(x, y):
    ''' Returns the outer product of two 1-D vectors.

    :param x: the input Var.
    :type x: jt.Var, numpy array, or python sequence.
    :param y: the input Var.
    :type y: jt.Var, numpy array, or python sequence.

    Example::

        >>> x = jt.arange(3)
        >>> y = jt.arange(4)
        >>> jt.outer(x, y)
        jt.Var([[0 0 0 0]
         [0 1 2 3]
         [0 2 4 6]], dtype=int32)
        >>> x.outer(y)
        jt.Var([[0 0 0 0]
         [0 1 2 3]
         [0 2 4 6]], dtype=int32)
    '''
    return jt.multiply(x.unsqueeze(1), y.unsqueeze(0))
Var.outer = outer
def erfinv_(x):
    ''' In-place version of erfinv(). '''
    return x.assign(x.erfinv())
Var.erfinv_ = erfinv_

def erf_(x):
    ''' In-place version of erf(). '''
    return x.assign(x.erf())
Var.erf_ = erf_

def abs_(x):
    ''' In-place version of abs(). '''
    return x.assign(x.abs())
Var.abs_ = abs_

def sigmoid_(x):
    ''' In-place version of sigmoid(). '''
    return x.assign(x.sigmoid())
Var.sigmoid_ = sigmoid_

def sqrt_(x):
    ''' In-place version of sqrt(). '''
    return x.assign(x.sqrt())
Var.sqrt_ = sqrt_

def add_(x, y):
    ''' In-place version of add(). '''
    return x.assign(x.add(y))
Var.add_ = add_

def multiply_(x, y):
    ''' In-place version of multiply(). '''
    return x.assign(x.multiply(y))
Var.multiply_ = multiply_

def type_as(a, b):
    return a.unary(op=b.dtype)
Var.type_as = type_as
Var.astype = Var.cast

def masked_fill(x, mask, value):
    return jt.ternary(mask, value, x)
Var.masked_fill = masked_fill

def sqr(x):
    return x * x
Var.sqr = sqr

def pow(x, y):
    ''' computes x^y, element-wise. This operation is equivalent to ``x ** y``.

    :param x: the first input.
    :type x: a python number or jt.Var.
    :param y: the second input.
    :type y: a python number or jt.Var.
    '''
    if isinstance(x, Var) and isinstance(y, (ori_int, ori_float)) and y == 2:
        return x.sqr()
    return core.ops.pow(x, y)
Var.pow = Var.__pow__ = pow
def argmax(x: Var, dim: int, keepdims: bool = False):
    ''' Returns the indices and values of the maximum elements along the specified dimension.

    :param x: the input Var.
    :type x: jt.Var, numpy array, or python sequence.
    :param dim: the dimension to reduce.
    :type dim: int.
    :param keepdims: whether the output Var has dim retained or not. Defaults to False
    :type keepdims: bool, optional

    Example::

        >>> a = jt.randn((2, 4))
        >>> a
        jt.Var([[-0.33272865 -0.4951588   1.4128606   0.13734372]
         [-1.633469    0.19593953 -0.7803732  -0.5260756 ]], dtype=float32)
        >>> a.argmax(dim=0)
        (jt.Var([0 1 0 0], dtype=int32), jt.Var([-0.33272865  0.19593953  1.4128606  0.13734372], dtype=float32))
        >>> a.argmax(dim=1)
        (jt.Var([2 1], dtype=int32), jt.Var([1.4128606  0.19593953], dtype=float32))
    '''
    if dim is None:
        dim = 0
        x = x.flatten()
    return jt.arg_reduce(x, "max", dim, keepdims)
Var.argmax = argmax

def argmin(x, dim: int, keepdims: bool = False):
    ''' Returns the indices and values of the minimum elements along the specified dimension.

    :param x: the input Var.
    :type x: jt.Var, numpy array, or python sequence.
    :param dim: the dimension to reduce.
    :type dim: int.
    :param keepdims: whether the output Var has dim retained or not. Defaults to False
    :type keepdims: bool, optional

    Example::

        >>> a = jt.randn((2, 4))
        >>> a
        jt.Var([[-0.33272865 -0.4951588   1.4128606   0.13734372]
         [-1.633469    0.19593953 -0.7803732  -0.5260756 ]], dtype=float32)
        >>> a.argmin(dim=0)
        (jt.Var([1 0 1 1], dtype=int32), jt.Var([-1.633469  -0.4951588 -0.7803732 -0.5260756], dtype=float32))
        >>> a.argmin(dim=1)
        (jt.Var([1 0], dtype=int32), jt.Var([-0.4951588 -1.633469 ], dtype=float32))
    '''
    return jt.arg_reduce(x, "min", dim, keepdims)
Var.argmin = argmin
def randn(*size, dtype="float32", requires_grad=True) -> Var:
    ''' samples random numbers from a standard normal distribution.

    :param size: shape of the output.
    :type size: int or a sequence of int
    :param dtype: data type, defaults to "float32".
    :type dtype: str, optional
    :param requires_grad: whether to enable gradient back-propagation, defaults to True.
    :type requires_grad: bool, optional

    Example::

        >>> jt.randn(3)
        jt.Var([-1.019889   -0.30377278 -1.4948598 ], dtype=float32)
        >>> jt.randn(2, 3)
        jt.Var([[-0.15989183 -1.5010914   0.5476955 ]
         [-0.612632   -1.1471151  -1.1879086 ]], dtype=float32)
    '''
    if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)):
        size = size[0]
    arr = jt.random(size, dtype, "normal")
    if not requires_grad:
        return arr.stop_grad()
    return arr

def rand(*size, dtype="float32", requires_grad=True) -> Var:
    ''' samples random numbers from a uniform distribution on the interval [0, 1).

    :param size: shape of the output.
    :type size: int or a sequence of int
    :param dtype: data type, defaults to "float32".
    :type dtype: str, optional
    :param requires_grad: whether to enable gradient back-propagation, defaults to True.
    :type requires_grad: bool, optional

    Example::

        >>> jt.rand(3)
        jt.Var([0.31005102 0.02765604 0.8150749 ], dtype=float32)
        >>> jt.rand(2, 3)
        jt.Var([[0.96414304 0.3519264  0.8268017 ]
         [0.05658621 0.04449705 0.86190987]], dtype=float32)
    '''
    if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)):
        size = size[0]
    arr = jt.random(size, dtype)
    if not requires_grad:
        return arr.stop_grad()
    return arr

def rand_like(x, dtype=None) -> Var:
    ''' samples random values from a standard uniform distribution with the same shape as x.

    :param x: reference variable.
    :type x: jt.Var
    :param dtype: if None, the dtype of the output is the same as x.
        Otherwise, use the specified dtype. Defaults to None.
    :type dtype: str, optional

    Example::

        >>> x = jt.zeros((2, 3))
        >>> jt.rand_like(x)
        jt.Var([[0.6164821  0.21476883 0.61959815]
         [0.58626485 0.35345772 0.5638483 ]], dtype=float32)
    '''
    if dtype is None:
        dtype = x.dtype
    return jt.random(x.shape, dtype)

def randn_like(x, dtype=None) -> Var:
    ''' samples random values from a standard normal distribution with the same shape as x.

    :param x: reference variable.
    :type x: jt.Var
    :param dtype: if None, the dtype of the output is the same as x.
        Otherwise, use the specified dtype. Defaults to None.
    :type dtype: str, optional

    Example::

        >>> x = jt.zeros((2, 3))
        >>> jt.randn_like(x)
        jt.Var([[-1.1647032   0.34847224 -1.3061888 ]
         [ 1.068085   -0.34366122  0.13172573]], dtype=float32)
    '''
    if dtype is None:
        dtype = x.dtype
    return jt.random(x.shape, dtype, "normal")

def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
    ''' samples random integers from a uniform distribution on the interval [low, high).

    :param low: lowest integer to be drawn from the distribution, defaults to 0.
    :type low: int, optional
    :param high: one above the highest integer to be drawn from the distribution.
    :type high: int
    :param shape: shape of the output size, defaults to (1,).
    :type shape: tuple, optional
    :param dtype: data type of the output, defaults to "int32".
    :type dtype: str, optional

    Example::

        >>> jt.randint(3, shape=(3, 3))
        jt.Var([[2 0 2]
         [2 1 2]
         [2 0 1]], dtype=int32)
        >>> jt.randint(1, 3, shape=(3, 3))
        jt.Var([[2 2 2]
         [1 1 2]
         [1 1 1]], dtype=int32)
    '''
    if high is None:
        low, high = 0, low
    v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
    v = jt.floor_int(v)
    return v.astype(dtype)

def randint_like(x, low, high=None) -> Var:
    ''' samples random integers from a uniform distribution on the interval [low, high),
    with the same shape as x.

    :param x: reference variable.
    :type x: jt.Var
    :param low: lowest integer to be drawn from the distribution, defaults to 0.
    :type low: int, optional
    :param high: one above the highest integer to be drawn from the distribution.
    :type high: int

    Example::

        >>> x = jt.zeros((2, 3))
        >>> jt.randint_like(x, 10)
        jt.Var([[9. 3. 4.]
         [4. 8. 5.]], dtype=float32)
        >>> jt.randint_like(x, 10, 20)
        jt.Var([[17. 11. 18.]
         [14. 17. 15.]], dtype=float32)
    '''
    return randint(low, high, x.shape, x.dtype)

def normal(mean, std, size=None, dtype="float32") -> Var:
    ''' samples random values from a normal distribution.

    :param mean: means of the normal distributions.
    :type mean: int or jt.Var
    :param std: standard deviations of the normal distributions.
    :type std: int or jt.Var
    :param size: shape of the output size. If not specified, the shape of the output is
        determined by mean or std. An exception will be raised if mean and std are both
        integers or have different shapes in this case. Defaults to None.
    :type size: tuple, optional
    :param dtype: data type of the output, defaults to "float32".
    :type dtype: str, optional

    Example::

        >>> jt.normal(5, 3, size=(2,3))
        jt.Var([[ 8.070848   7.654219  10.252696 ]
         [ 6.383718   7.8817277  3.0786133]], dtype=float32)
        >>> mean = jt.randint(low=0, high=10, shape=(10,))
        >>> jt.normal(mean, 0.1)
        jt.Var([1.9524184 1.0749301 7.9864206 5.9407325 8.1596155 4.824019  7.955083
         8.972998  6.0674286 8.88026  ], dtype=float32)
    '''
    if size is None:
        if isinstance(mean, Var) and isinstance(std, Var):
            assert mean.shape == std.shape
            size = mean.shape
        else:
            if isinstance(mean, Var): size = mean.shape
            if isinstance(std, Var): size = std.shape
    return jt.init.gauss(size, dtype, mean, std)
def attrs(var):
    return {
        "is_stop_fuse": var.is_stop_fuse(),
        "is_stop_grad": var.is_stop_grad(),
        "shape": var.shape,
        "dtype": var.dtype,
    }
Var.attrs = attrs
def fetch(*args):
    ''' Async fetch vars with function closure.

    Example 1::

        for img,label in enumerate(your_dataset):
            pred = your_model(img)
            loss = critic(pred, label)
            acc = accuracy(pred, label)
            jt.fetch(acc, loss,
                lambda acc, loss: print(f"loss:{loss} acc:{acc}"))

    Example 2::

        for i,(img,label) in enumerate(your_dataset):
            pred = your_model(img)
            loss = critic(pred, label)
            acc = accuracy(pred, label)
            # variable i will be bound into the function closure
            jt.fetch(i, acc, loss,
                lambda i, acc, loss: print(f"#{i}, loss:{loss} acc:{acc}"))
    '''
    assert len(args)>=1
    func = args[-1]
    assert callable(func)
    args = list(args[:-1])
    if len(args)>0 and isinstance(args[0], Sequence) \
        and len(args[0])>=1 and isinstance(args[0][0], Var):
        raise TypeError("jt.Var should not inside a list or tuple.")
    var_map = []
    variables = []
    for i, v in enumerate(args):
        if isinstance(v, Var):
            variables.append(v)
            var_map.append(i)
            args[i] = None
    def callback(*results):
        for i,v in enumerate(results):
            args[var_map[i]] = v
        func(*args)
    core.ops.fetch(variables, callback)
Var.fetch = fetch
def display_memory_info():
    import inspect, os
    f = inspect.currentframe()
    fileline = inspect.getframeinfo(f.f_back)
    fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
    core.display_memory_info(fileline)

def load(path: str):
    ''' loads an object from a file.
    '''
    model_dict = safeunpickle(path)
    return model_dict

def save(params_dict, path: str):
    ''' saves the parameter dictionary to a file.

    :param params_dict: parameters to be saved
    :type params_dict: list or dictionary
    :param path: file path
    :type path: str
    '''
    safepickle(params_dict, path)
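# Illustrative sketch (not part of the original source): jt.save/jt.load round-trip a
# parameter dictionary through safepickle/safeunpickle; the file name is hypothetical
# and this helper is never called here.
def _example_save_load():
    params = {"weight": jt.rand(3, 3), "step": 10}
    jt.save(params, "checkpoint.pkl")     # checksummed pickle written by safepickle
    restored = jt.load("checkpoint.pkl")  # Vars come back as numpy arrays
    return restored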
def _uniq(x):
    a = set()
    b = []
    for i in x:
        j = id(i)
        if j not in a:
            a.add(j)
            b.append(i)
    return b
class Module:
    def __init__(self, *args, **kw):
        pass

    def execute(self, *args, **kw):
        ''' Executes the module computation.

        Raises NotImplementedError if the subclass does not override the method.
        '''
        raise NotImplementedError("Please implement 'execute' method of "+str(type(self)))

    def __call__(self, *args, **kw):
        return self.execute(*args, **kw)

    def __repr__(self):
        return self.__str__()

    def _get_name(self):
        return self.__class__.__name__

    def __name__(self):
        pass

    def dfs(self, parents, k, callback, callback_leave=None, recurse=True):
        ''' A utility function to traverse the module.
        '''
        n_children = 0
        for v in self.__dict__.values():
            if isinstance(v, Module):
                n_children += 1
        ret = callback(parents, k, self, n_children)
        if ret == False:
            return
        if recurse:
            for k,v in self.__dict__.items():
                if not isinstance(v, Module):
                    continue
                parents.append(self)
                v.dfs(parents, k, callback, callback_leave)
                parents.pop()
        if callback_leave:
            callback_leave(parents, k, self, n_children)

    def __str__(self):
        ss = []
        def callback(parents, k, v, n):
            # indent key:class_name(extra_repr)
            k = f"{k}: " if k is not None else ""
            s = f"{' '*(len(parents)*4)}{k}{v.__class__.__name__}"
            if n:
                s += '('
            else:
                s += f"({v.extra_repr()})"
            ss.append(s)
        def callback_leave(parents, k, v, n):
            if n:
                ss.append(' '*(len(parents)*4)+')')
        self.dfs([], None, callback, callback_leave)
        return "\n".join(ss)
    def parameters(self, recurse=True) -> List:
        ''' Returns a list of module parameters.

        ----------------

        Example::

            >>> net = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2))
            >>> for p in net.parameters():
            ...     print(p.name())
            ...
            0.weight
            0.bias
            2.weight
            2.bias
        '''
        ps = []
        stack = []
        def callback(parents, k, v, n):
            stack.append(str(k))
            dc = v.__dict__
            if isinstance(v, nn.ParameterList):
                dc = v.params
            for k2, p in dc.items():
                if isinstance(k2, str) and k2.startswith("_"): continue
                if isinstance(p, Var):
                    ps.append(p)
                    pname = ".".join(stack[1:]+[str(k2)])
                    if len(pname) > len(p.name()):
                        p.name(pname)
        def callback_leave(parents, k, v, n):
            stack.pop()
        self.dfs([], None, callback, callback_leave, recurse)
        return _uniq(ps)

    def state_dict(self, to=None, recurse=True):
        ''' Returns a dictionary containing the Jittor Vars of the module and its descendants.

        Args:
            to: target type of var, can be None, 'numpy' or 'torch'

        Return:
            dictionary of the module's states.

        Example::

            import jittor as jt
            from jittor.models import resnet50
            jittor_model = resnet50()
            dict = jittor_model.state_dict()
            jittor_model.load_state_dict(dict)

        Example 2 (export Jittor params to PyTorch)::

            import jittor as jt
            from jittor.models import resnet50
            jittor_model = resnet50()
            import torch
            from torchvision.models import resnet50
            torch_model = resnet50()
            torch_model.load_state_dict(jittor_model.state_dict(to="torch"))
        '''
        uniq_set = set()
        ps = {}
        stack = []
        def callback(parents, k, v, n):
            stack.append(str(k))
            dc = v.__dict__
            if isinstance(v, nn.ParameterList):
                dc = v.params
            for k2, p in dc.items():
                if isinstance(k2, str) and k2.startswith("_"): continue
                if isinstance(p, Var):
                    if id(p) in uniq_set: continue
                    if not getattr(p, "persistent", True): continue
                    uniq_set.add(id(p))
                    pname = ".".join(stack[1:]+[str(k2)])
                    ps[pname] = p
                    if len(pname) > len(p.name()):
                        p.name(pname)
        def callback_leave(parents, k, v, n):
            stack.pop()
        self.dfs([], None, callback, callback_leave, recurse)
        if to == "numpy":
            for k,v in ps.items():
                if isinstance(v, Var):
                    ps[k] = v.numpy()
        elif to == "torch":
            import torch
            for k,v in ps.items():
                if isinstance(v, Var):
                    ps[k] = torch.Tensor(v.numpy())
        return ps

    def named_parameters(self, recurse=True) -> List[Tuple[str, Var]]:
        ''' Returns a list of module parameters and their names.

        ----------------

        Example::

            >>> net = nn.Linear(2, 5)
            >>> net.named_parameters()
            [('weight', jt.Var([[ 0.5964666  -0.3175258 ]
             [ 0.41493994 -0.66982657]
             [-0.32677156  0.49614117]
             [-0.24102807 -0.08656466]
             [ 0.15868133 -0.12468725]], dtype=float32)),
             ('bias', jt.Var([-0.38282675  0.36271113 -0.7063226   0.02899247  0.52210844], dtype=float32))]
        '''
        state_dict = self.state_dict(recurse=recurse)
        return list(state_dict.items())

    def load_state_dict(self, params) -> None:
        ''' Loads the module's parameters from a dictionary.
        '''
        self.load_parameters(params)

    def _load_from_state_dict(self, state, prefix="", *args, **kw):
        if len(prefix):
            new_state = {}
            for k,v in state.items():
                if k.startswith(prefix):
                    new_state[k[len(prefix):]] = v
            state = new_state
        self.load_state_dict(state)
    def cuda(self, device=None):
        flags.use_cuda = 1
        return self

    def npu(self, device=None):
        flags.use_cuda = 1
        return self

    def modules(self) -> List:
        ''' Returns a list of sub-modules in the module recursively.

        ----------------

        Example::

            >>> net = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2))
            >>> net.modules()
            [Sequential(
                0: Linear(2, 10, float32[10,], None)
                1: relu()
                2: Linear(10, 2, float32[2,], None)
            ), Linear(2, 10, float32[10,], None), relu(), Linear(10, 2, float32[2,], None)]
        '''
        ms = []
        def callback(parents, k, v, n):
            if isinstance(v, Module):
                ms.append(v)
        self.dfs([], None, callback, None)
        return _uniq(ms)

    def named_modules(self):
        ''' Returns a list of sub-modules and their names recursively.

        ----------------

        Example::

            >>> net = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2))
            >>> net.named_modules()
            [('', Sequential(
                0: Linear(2, 10, float32[10,], None)
                1: relu()
                2: Linear(10, 2, float32[2,], None)
            )), ('0', Linear(2, 10, float32[10,], None)), ('1', relu()), ('2', Linear(10, 2, float32[2,], None))]
        '''
        ms = []
        stack = []
        def callback(parents, k, v, n):
            if isinstance(v, Module):
                stack.append(str(k))
                name = ".".join(stack[1:])
                ms.append((name, v))
        def callback_leave(parents, k, v, n):
            stack.pop()
        self.dfs([], "", callback, callback_leave)
        return ms

    def add_module(self, name, module):
        setattr(self, name, module)
        return self

    @property
    def _modules(self):
        return { k:v for k,v in self.__dict__.items() if isinstance(v, Module) }

    @property
    def _parameters(self):
        return { k:v for k,v in self.__dict__.items() if isinstance(v, Var) }
    def requires_grad_(self, requires_grad=True):
        ''' Sets requires_grad for all parameters and sub-modules.
        '''
        self._requires_grad = requires_grad
        self._place_hooker()
        return self

    def __hooked_call__(self, *args, **kw):
        if hasattr(self, "__fhook2__"):
            if len(kw):
                self.__fhook2__(self, args, kw)
            else:
                self.__fhook2__(self, args)
        if hasattr(self, "__bihook__"):
            if len(kw):
                LOG.w("backward hook not support kw")
            args = grad_hooker(args, self.__bihook__)
        if hasattr(self, "_requires_grad") and not self._requires_grad:
            with jt.no_grad():
                ret = self.__hooked_call__(*args, **kw)
        else:
            ret = self.__hooked_call__(*args, **kw)
        if hasattr(self, "__bohook__"):
            if len(kw):
                LOG.w("backward hook not support kw")
            if isinstance(ret, Var):
                ret = grad_hooker((ret,), self.__bohook__)[0]
            else:
                ret = grad_hooker(ret, self.__bohook__)
        if hasattr(self, "__fhook__"):
            if len(kw):
                self.__fhook__(self, args, ret, kw)
            else:
                self.__fhook__(self, args, ret)
        return ret

    def _place_hooker(self):
        cls = self.__class__
        if hasattr(cls, "__hooked__"):
            return
        cls.__hooked__ = True
        cls.__call__, cls.__hooked_call__ = \
            cls.__hooked_call__, cls.__call__
    def register_forward_hook(self, func):
        ''' Register a forward function hook that will be called after Module.execute.

        The hook function will be called with the following arguments::

            hook(module, input_args, output)

        or::

            hook(module, input_args, output, input_kwargs)
        '''
        self.__fhook__ = func
        self._place_hooker()

    def remove_forward_hook(self):
        ''' Removes the current forward hook.
        '''
        if hasattr(self,"__fhook__"):
            delattr(self,"__fhook__")

    def register_pre_forward_hook(self, func):
        ''' Register a forward function hook that will be called before Module.execute.

        The hook function will be called with the following arguments::

            hook(module, input_args)

        or::

            hook(module, input_args, input_kwargs)
        '''
        self.__fhook2__ = func
        self._place_hooker()

    def remove_pre_forward_hook(self):
        ''' Removes the current pre-forward hook.
        '''
        if hasattr(self,"__fhook2__"):
            delattr(self,"__fhook2__")
    def register_input_backward_hook(self, func):
        self.__bihook__ = func
        self._place_hooker()

    def remove_input_backward_hook(self):
        if hasattr(self,"__bihook__"):
            delattr(self,"__bihook__")

    def register_output_backward_hook(self, func):
        self.__bohook__ = func
        self._place_hooker()

    def remove_output_backward_hook(self):
        if hasattr(self,"__bohook__"):
            delattr(self,"__bohook__")
    def register_backward_hook(self, func):
        ''' Hook both the input and output gradients during backpropagation of this module.

        Arguments of the hook are defined as::

            hook(module, grad_input:tuple(jt.Var), grad_output:tuple(jt.Var)) -> tuple(jt.Var) or None

        `grad_input` is the original gradient of the input of this module,
        `grad_output` is the gradient of the output of this module, and the
        return value is used to replace the gradient of the input.
        '''
        _grad_output = None
        def bohook(grad_output):
            nonlocal _grad_output
            _grad_output = grad_output
        def bihook(grad_input):
            return func(self, grad_input, _grad_output)
        self.register_input_backward_hook(bihook)
        self.register_output_backward_hook(bohook)

    def remove_backward_hook(self):
        ''' Removes the backward input and output hooks.
        '''
        self.remove_input_backward_hook()
        self.remove_output_backward_hook()

    # A usage sketch for these hook APIs follows the Module class definition below.
    def children(self) -> List:
        ''' Returns a List of the children modules.
        '''
        cd = []
        def callback(parents, k, v, n):
            if len(parents) == 1 and isinstance(v, Module):
                cd.append(v)
                return False
        self.dfs([], None, callback, None)
        return cd

    def extra_repr(self):
        ss = []
        n = len(self.__init__.__code__.co_varnames)
        if self.__init__.__defaults__ is not None:
            n -= len(self.__init__.__defaults__)
        for i, k in enumerate(self.__init__.__code__.co_varnames[1:]):
            v = getattr(self, k) if hasattr(self, k) else None
            if isinstance(v, Var): v = v.peek()
            s = f"{k}={v}" if i >= n else str(v)
            ss.append(s)
        return ", ".join(ss)
    def apply(self, func):
        ''' Applies a function to all sub-modules recursively.
        '''
        for m in self.modules():
            func(m)

    def load_parameters(self, params):
        ''' loads parameters to the Module.

        :param params: dictionary of parameter names and parameters.
        '''
        n_failed = 0
        for key in params.keys():
            v = self
            key_ = key.split('.')
            end = 0
            for k in key_:
                if isinstance(v, nn.Sequential):
                    if (k in v.layers):
                        v = v[k]
                    elif k.isdigit() and (ori_int(k) in v.layers):
                        v = v[ori_int(k)]
                    else:
                        end = 1
                        break
                else:
                    if hasattr(v, k):
                        v = getattr(v, k)
                        assert isinstance(v, (Module, Var)), \
                            f"expect a jittor Module or Var, but got <{v.__class__.__name__}>, key: {key}"
                    else:
                        end = 1
                        break
            if end == 1:
                if not key.endswith("num_batches_tracked"):
                    n_failed += 1
                    LOG.w(f'load parameter {key} failed ...')
            else:
                assert isinstance(v, Var), \
                    f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}"
                if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
                    param = array(params[key])
                elif isinstance(params[key], Var):
                    param = params[key]
                else:
                    # assume is pytorch tensor
                    param = array(params[key].cpu().detach().numpy())
                if param.shape == v.shape:
                    LOG.v(f'load parameter {key} success ...')
                    v.update(param)
                    v.sync(False, False)
                else:
                    n_failed += 1
                    LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}')
        if n_failed:
            LOG.w(f"load total {len(params)} params, {n_failed} failed")
    def save(self, path: str):
        ''' saves parameters to a file.

        :param path: path to save.
        :type path: str

        Example::

            >>> class Net(nn.Module):
            >>>     ...
            >>> net = Net()
            >>> net.save('net.pkl')
            >>> net.load('net.pkl')
        '''
        params = self.state_dict()
        safepickle(params, path)

    def load(self, path: str):
        ''' loads parameters from a file.

        :param path: path to load.
        :type path: str

        Example::

            >>> class Net(nn.Module):
            >>>     ...
            >>> net = Net()
            >>> net.save('net.pkl')
            >>> net.load('net.pkl')

        This method also supports loading a state dict from a pytorch .pth file.

        .. note::

            When the loaded parameters do not match the model definition, jittor prints
            an error message but does not raise an exception.

            If a loaded parameter name does not exist in the model definition, the
            following message is printed and the parameter is ignored:

            >>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ...

            If the shape of a loaded parameter does not match the model definition, the
            following message is printed and the parameter is ignored:

            >>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,]

            If any error occurs during loading, jittor prints a summary line; please
            check the error messages carefully:

            >>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed
        '''
        self.load_parameters(load(path))
    def eval(self):
        ''' Sets the module in evaluation mode.
        '''
        def callback(parents, k, v, n):
            if isinstance(v, Module):
                v.is_train = False
        self.dfs([], None, callback, None)

        # backup stop grad or not
        if not hasattr(self, "backup_grad_state"):
            self.backup_grad_state = {}
        for p in self.parameters():
            if id(p) not in self.backup_grad_state:
                self.backup_grad_state[id(p)] = not p.is_stop_grad()
            p.stop_grad()
        return self

    def train(self):
        ''' Sets the module in training mode.
        '''
        def callback(parents, k, v, n):
            if isinstance(v, Module):
                v.is_train = True
        self.dfs([], None, callback, None)

        # backup stop grad or not
        if hasattr(self, "backup_grad_state"):
            for p in self.parameters():
                if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
                    p.start_grad()
        return self

    def is_training(self) -> bool:
        ''' Returns whether the module is in training mode.'''
        if not hasattr(self, "is_train"):
            self.is_train = True
        return self.is_train

    @property
    def training(self):
        if not hasattr(self, "is_train"):
            self.is_train = True
        return self.is_train

    @training.setter
    def training(self, value):
        self.is_train = value

    def mpi_param_broadcast(self, root=0):
        if not in_mpi:
            return
        for p in self.parameters():
            p.update(p.mpi_broadcast(root))

    def __setattr__(self, key, value):
        object.__setattr__(self, key, value)

    def __getattr__(self, key):
        return object.__getattribute__(self, key)

    def register_buffer(self, key, value, persistent=True):
        value.persistent = persistent
        object.__setattr__(self, key, value)
        return value

    @property
    def _buffers(self):
        buffers = {}
        for k,v in self.__dict__.items():
            if isinstance(v, jt.Var):
                buffers[k] = v
        return buffers

    def named_buffers(self, recurse=False):
        buffers = []
        for k,v in self.__dict__.items():
            if isinstance(v, jt.Var):
                buffers.append((k,v))
        return buffers

    def named_children(self):
        childs = []
        for k,v in self.__dict__.items():
            if isinstance(v, Module):
                childs.append((k,v))
        return childs
    def float64(self):
        '''convert all parameters to float64'''
        self._amp_level = 0
        for p in self.parameters():
            if p.dtype.is_float():
                p.assign(p.float64())
        return self

    def float32(self):
        '''convert all parameters to float32'''
        self._amp_level = 0
        for p in self.parameters():
            if p.dtype.is_float():
                p.assign(p.float32())
        return self

    def float16(self):
        '''convert all parameters to float16'''
        # self._amp_level = 3 if flags.th_mode else 4
        # amp level better set globally
        self._amp_level = -1
        if self._amp_level >= 0:
            cls = self.__class__
            cls.__call__ = cls.__half_call__
        for p in self.parameters():
            if p.dtype.is_float():
                p.assign(p.float16())
        return self

    def bfloat16(self):
        '''convert all parameters to bfloat16'''
        # self._amp_level = 3 if flags.th_mode else 4
        # amp level better set globally
        self._amp_level = -1
        if self._amp_level >= 0:
            cls = self.__class__
            cls.__call__ = cls.__half_call__
        for p in self.parameters():
            if p.dtype.is_float():
                p.assign(p.bfloat16())
        return self

    def __half_call__(self, *args, **kw):
        amp_level = getattr(self, "_amp_level", -1)
        if amp_level >= 0:
            with flag_scope(amp_level=amp_level):
                return self.execute(*args, **kw)
        else:
            return self.execute(*args, **kw)

    def half(self):
        '''convert all parameters to float16'''
        return self.float16()

    def float_auto(self):
        '''convert all parameters to float16 or float32 automatically
        by jt.flags.auto_mixed_precision_level and jt.flags.amp_reg'''
        self._amp_level = -1
        for p in self.parameters():
            if p.dtype.is_float():
                p.assign(p.float_auto())
        return self
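# Illustrative sketch (not part of the original source): how the hook APIs defined on
# Module above are typically used. The module and shapes are hypothetical and this
# helper is never called here.
def _example_module_hooks():
    lin = nn.Linear(4, 2)

    def forward_hook(module, inputs, output):
        # called after execute: hook(module, input_args, output)
        print("forward:", module._get_name(), output.shape)

    def backward_hook(module, grad_input, grad_output):
        # called on backpropagation: returning None keeps the original input gradients
        print("backward:", module._get_name())
        return None

    lin.register_forward_hook(forward_hook)
    lin.register_backward_hook(backward_hook)
    y = lin(jt.rand(3, 4)).sum()
    jt.grad(y, lin.parameters())     # triggers both hooks
    lin.remove_forward_hook()
    lin.remove_backward_hook()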
class Function(Module):
    ''' Function Module for customized backward operations.

    Example 1 (a Function can have multiple inputs and multiple outputs, and the user
    can store values for the backward computation)::

        import jittor as jt
        from jittor import Function

        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            def grad(self, grad0, grad1):
                return grad0 * self.y, grad1 * self.x

        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc.apply
        c,d = func(a, b)
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 9

    Example 2 (a Function can return None for no gradient, and a gradient can also be None)::

        import jittor as jt
        from jittor import Function

        class MyFunc(Function):
            def execute(self, x, y):
                self.x = x
                self.y = y
                return x*y, x/y

            def grad(self, grad0, grad1):
                assert grad1 is None
                return grad0 * self.y, None

        a = jt.array(3.0)
        b = jt.array(4.0)
        func = MyFunc.apply
        c,d = func(a, b)
        d.stop_grad()
        da, db = jt.grad(c+d*3, [a, b])
        assert da.data == 4
        assert db.data == 0
    '''
    def __call__(self, *args):
        if flags.no_grad:
            return self.execute(*args)
        backup = args
        args = list(args)
        taped_inputs = []
        taped_outputs = []
        input_mask = [-1] * len(args)
        for i,v in enumerate(args):
            if isinstance(v, Var):
                if v.is_stop_grad():
                    # -2 in input_mask represents it is stop_grad
                    input_mask[i] = -2
                    continue
                v = v.tape()
                input_mask[i] = len(taped_inputs)
                args[i] = v
                taped_inputs.append(v)
        ori_res = self.execute(*args)
        if not isinstance(ori_res, Sequence):
            res = [ori_res]
        else:
            res = list(ori_res)
        output_mask = [-1] * len(res)
        for i,v in enumerate(res):
            if isinstance(v, Var):
                v = v.tape()
                output_mask[i] = len(taped_outputs)
                res[i] = v
                taped_outputs.append(v)
        self.input_mask = input_mask
        self.output_mask = output_mask
        # tape output and input together so
        # backward treat them as one operator
        tape_together(taped_inputs, taped_outputs, self._grad)
        if isinstance(ori_res, Sequence):
            return res
        else:
            return res[0]

    def _grad(self, *args):
        new_args = ( (args[i] if i>=0 else None) for i in self.output_mask )
        ret = self.grad(*new_args)
        if not isinstance(ret, Sequence):
            ret = (ret,)
        new_ret = []
        for i, r in enumerate(ret):
            j = self.input_mask[i]
            if j<0:
                # -2 in input_mask represents it is stop_grad
                assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\
                    "because the input value is not jittor variable."
            else:
                new_ret.append(r)
        return new_ret

    def dfs(self, parents, k, callback, callback_leave=None, recurse=True):
        pass

    @classmethod
    def apply(cls, *args, **kw):
        func = cls()
        return func(*args, **kw)
class GradHooker(Function):
    def __init__(self, hook):
        self.hook = hook

    def execute(self, *args):
        return args

    def grad(self, *grad_input):
        ret = self.hook(grad_input)
        if ret:
            grad_input = ret
        return grad_input

def grad_hooker(args, hook):
    hooker = GradHooker(hook)
    return hooker(*args)
def register_hook(v, hook):
    """ Registers a hook on any jittor variable. If the hook returns a value that is
    not None, the gradient of this variable will be replaced by that value.

    Example::

        x = jt.array([0.0, 0.0])
        y = x * [1,2]
        y.register_hook(lambda g: g*2)
        dx = jt.grad(y, x)
        print(dx)
        # will be [2, 4]

    """
    def _hook(grads):
        g = hook(grads[0])
        if g is not None:
            return (g,)
        return None
    hooker = GradHooker(_hook)
    v.swap(hooker(v)[0])
    return v
Var.register_hook = register_hook
def make_module(func, exec_n_args=1):
    class MakeModule(Module):
        def __init__(self, *args, **kw):
            self.args = args
            self.kw = kw
        def execute(self, *args):
            return func(*args, *self.args, **self.kw)
        def __str__(self):
            return f"{func.__name__}({self.extra_repr()})"
        def extra_repr(self):
            return ",".join(map(str, self.args))
    MakeModule.__name__ = func.__name__
    return MakeModule
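# Illustrative sketch (not part of the original source): make_module turns a functional
# op into a Module class whose constructor captures the extra arguments; jittor.nn uses
# this pattern for many thin wrappers. `my_scale` is a hypothetical function and this
# helper is never called here.
def _example_make_module():
    def my_scale(x, factor=1.0):
        return x * factor
    MyScale = make_module(my_scale)
    layer = MyScale(2.0)                # stores factor=2.0
    return layer(jt.rand(3))            # calls my_scale(x, 2.0)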
def dirty_fix_pytorch_runtime_error():
    ''' This function should be called before pytorch.

    Example::

        import jittor as jt
        jt.dirty_fix_pytorch_runtime_error()
        import torch
    '''
    import os, platform

    if platform.system() == 'Linux':
        os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND
        import jittor_utils
        with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW):
            import torch
import atexit
class ExitHooks(object):
    def __init__(self):
        self.exit_code = None
        self.exception = None

    def hook(self):
        self._orig_exit = sys.exit
        sys.exit = self.exit
        sys.excepthook = self.exc_handler

    def exit(self, code=0):
        self.exit_code = code
        self._orig_exit(code)

    def exc_handler(self, exc_type, exc, *args):
        self.exception = exc
        traceback.print_exception(exc_type, exc, *args)

hooks = ExitHooks()
hooks.hook()

def jittor_exit():
    if hooks.exit_code is not None:
        pass
    elif hooks.exception is not None:
        pass
    else:
        pass
        # core.sync_all(True)
    core.cleanup()
atexit.register(jittor_exit)
def vtos(v):
    data_str = f"jt.Var({v.numpy()}, dtype={v.dtype})"
    data_str = data_str.replace("\n", "\n ")
    return data_str

Var.__str__ = vtos
Var.__repr__ = vtos
Var.peek = lambda x: f"{x.dtype}{x.shape}"

def size(v, dim=None):
    if dim is None:
        return v.shape
    return v.shape[dim]
Var.size = size
def to_int(v):
    return ori_int(v.item())

def to_float(v):
    return ori_float(v.item())

def to_bool(v):
    assert v.dtype.is_int() or v.dtype.is_bool()
    return ori_bool(v.item())

Var.__int__ = to_int
Var.__float__ = to_float
Var.__bool__ = to_bool

def format(v, spec):
    return v.item().__format__(spec)
Var.__format__ = format
def get_len(var):
    return var.shape[0]

Var.__len__ = get_len

int = int32
Var.int = Var.int32
Var.long = Var.int32
float = float32
Var.float = Var.float32
double = float64
Var.double = Var.float64
half = float16
Var.half = Var.float16

def is_var(v):
    return isinstance(v, Var)

# __array__ interface is used for np.array(jt_var)
Var.__array__ = Var.numpy
Var.__array_priority__ = 2000
# __reduce__, __module__ is used for pickle.dump and pickle.load
Var.__module__ = "jittor"
Var.__reduce__ = lambda self: (Var, (self.data,))

from . import nn
from . import attention
from . import lr_scheduler
from . import linalg
from .linalg import einsum
from .nn import matmul, \
    bmm, bmm_transpose, \
    baddbmm
from . import contrib
from . import numpy2cupy
from .contrib import concat, cat
from .misc import *
from . import sparse
from . import optim
from . import dataset
from . import init

dtype = NanoString

import jittor_utils
for backend in jittor_utils.backends:
    if hasattr(backend, "post_process"):
        backend.post_process()

# impl x.func(...) -> func_(...)
args = {"x", "input", "self"}
_white_list = {"mul", "add", "sub"}
for k,v in list(Var.__dict__.items()):
    if k.startswith("_"): continue
    if k.endswith("_"): continue
    if not callable(v): continue
    if k not in _white_list:
        if not hasattr(v, "__code__"): continue
        conames = v.__code__.co_varnames
        if len(conames) == 0: continue
        arg_name = conames[0]
        if arg_name not in args: continue
    new_k = k+"_"
    if hasattr(Var, new_k): continue

    def inplace_wrapper(new_k, prev_func):
        setattr(Var, new_k,
            lambda x, *args, **kw: x.assign(prev_func(x, *args, **kw)))

    inplace_wrapper(new_k, v)

from . import math_util
from .math_util import *
from . import distributions
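# Illustrative sketch (not part of the original source): the loop above auto-generates
# in-place variants (suffix "_") of Python-level Var methods such as Var.pow by assigning
# the result back into the variable. This helper is only for demonstration and is never
# called here.
def _example_generated_inplace():
    x = jt.array([2.0, 3.0])
    x.pow_(2)        # generated from Var.pow: x.assign(x.pow(2)) -> [4.0, 9.0]
    return x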