flattenDynamicTensor function
Flattens an arbitrarily nested tensor (nested lists of numbers) into a flat Float32List, in row-major order.
Implementation
/// Flattens an arbitrarily nested tensor into a flat [Float32List].
///
/// Leaves are emitted depth-first, left to right (row-major order). For
/// example, `[[1.0, 2.0], [3.0]]` yields `[1.0, 2.0, 3.0]`. Ragged nesting is
/// allowed; each inner list contributes however many elements it holds.
///
/// Lists whose static element type is `double` at ranks 1 through 4 take a
/// typed fast path. These include the shapes produced by `allocTensorShape`.
/// Any other input goes through a recursive walk, and any [num] leaf is
/// accepted there. Ints are converted with [num.toDouble]. A bare [num] is
/// treated as a rank-0 tensor and yields a one-element list.
///
/// Values are narrowed to 32-bit floats when stored, so precision beyond
/// float32 is lost.
///
/// Throws a [TypeError] if [out] is `null`. Throws a [StateError] if a leaf
/// is neither a [num] nor a [List].
Float32List flattenDynamicTensor(Object? out) {
  if (out == null) {
    // TypeError (rather than ArgumentError) is kept for compatibility with
    // existing callers.
    throw TypeError();
  }

  // Concatenates innermost rows into one exactly-sized buffer. The first pass
  // sums the lengths and the second bulk-copies each row at its offset.
  Float32List concatRows(List<List<double>> rows) {
    var total = 0;
    for (final row in rows) {
      total += row.length;
    }
    final result = Float32List(total);
    var offset = 0;
    for (final row in rows) {
      result.setAll(offset, row);
      offset += row.length;
    }
    return result;
  }

  // Fast paths for common concrete shapes produced by allocTensorShape.
  // Ranks 3 and 4 gather references to their innermost rows first. That copy
  // is proportional to the number of rows, not the number of elements.
  if (out is List<double>) {
    return Float32List.fromList(out);
  }
  if (out is List<List<double>>) {
    return concatRows(out);
  }
  if (out is List<List<List<double>>>) {
    return concatRows([for (final plane in out) ...plane]);
  }
  if (out is List<List<List<List<double>>>>) {
    return concatRows([
      for (final cube in out)
        for (final plane in cube) ...plane,
    ]);
  }

  // Fallback: recursive walk for unknown shapes, such as `List<dynamic>` or
  // lists with `int` leaves.
  final flat = <double>[];
  void walk(Object? x) {
    if (x is num) {
      flat.add(x.toDouble());
    } else if (x is List) {
      for (final e in x) {
        walk(e);
      }
    } else {
      throw StateError('Unexpected output element type: ${x.runtimeType}');
    }
  }

  walk(out);
  return Float32List.fromList(flat);
}