flattenDynamicTensor function

Float32List flattenDynamicTensor(Object? out)

Flattens an arbitrarily nested tensor to a flat Float32List.

Implementation

/// Flattens an arbitrarily nested tensor into a flat [Float32List].
///
/// [out] may be a single [num], or a [List] whose elements are recursively
/// either [num]s or further [List]s. Nesting may be jagged and of any depth.
/// Values are emitted in depth-first order. Each value is stored into a
/// [Float32List], so it is rounded to 32-bit float precision.
///
/// Any `List<double>` found at any depth is bulk-copied. This covers the
/// concrete shapes produced by `allocTensorShape`, so typed tensors do not
/// pay for a per-element type check.
///
/// Throws a [TypeError] if [out] is `null`. Throws a [StateError] if any
/// element is neither a [num] nor a [List].
Float32List flattenDynamicTensor(Object? out) {
  if (out == null) {
    throw TypeError();
  }

  // Single source of truth for the "bad element" error.
  Never unexpected(Object? x) =>
      throw StateError('Unexpected output element type: ${x.runtimeType}');

  // Pass 1: count scalars so the result is allocated exactly once. This pass
  // also validates every element in depth-first order. Pass 2 therefore
  // cannot fail halfway, and any error is reported for the same first bad
  // element a single recursive walk would hit.
  int count(Object? x) {
    if (x is List<double>) return x.length; // Typed leaf: no per-item check.
    if (x is num) return 1;
    if (x is List) {
      var total = 0;
      for (final e in x) {
        total += count(e);
      }
      return total;
    }
    unexpected(x);
  }

  final result = Float32List(count(out));

  // Pass 2: copy values into [result] starting at [start]. Returns the next
  // free write index.
  int fill(Object? x, int start) {
    if (x is List<double>) {
      final end = start + x.length;
      result.setRange(start, end, x);
      return end;
    }
    if (x is num) {
      result[start] = x.toDouble();
      return start + 1;
    }
    if (x is List) {
      var w = start;
      for (final e in x) {
        w = fill(e, w);
      }
      return w;
    }
    // Unreachable in practice: pass 1 already rejected such elements.
    unexpected(x);
  }

  fill(out, 0);
  return result;
}