workOnTask method

@override
Future<void> workOnTask(TaskExecutionContext context)

Implementation

/// Distills one level-of-detail (LOD) layer of chunks for a record.
///
/// Consumes source chunks at level [lod] in windows of [factor], asks the
/// LLM to summarize each window, and writes the resulting chunk at level
/// `lod + 1`. Work is resumable: the loop restarts after [lastChunk] and
/// exits early when [context] reports the time budget is exhausted, so the
/// scheduler can re-run the task later. Newly written chunk ids accumulate
/// in [toEmbed] and are flushed as [TaskEmbed] tasks of [embedBatchSize].
@override
Future<void> workOnTask(TaskExecutionContext context) async {
  // A factor <= 1 would never shrink the chunk count, so distillation would
  // never terminate — reject it before doing any work.
  if (factor <= 1) {
    context.fail("Distillation factor must be greater than 1");
    return;
  }

  // Resume after the last successfully distilled source chunk, if any.
  for (
    int i = (lastChunk != null ? (lastChunk! + 1) : 0);
    i < size;
    i += factor
  ) {
    // Out of time: stop cleanly. Progress is tracked via lastChunk/emitted,
    // so a rescheduled run picks up where this one left off.
    if (!context.hasTimeRemaining) {
      return;
    }

    // Source-chunk indexes for this window, clipped to the record size.
    List<int> chunkIndexes = List.generate(
      factor,
      (j) => i + j,
    ).where((j) => j < size).toList();

    if (chunkIndexes.isEmpty) continue;

    List<String> chunkIds = chunkIndexes
        .map((c) => "${recordLocation.hashId}.L$lod.$c")
        .toList();
    List<Chunk> chunks = await Future.wait(chunkIds.map(get));
    int charStart = chunks.first.charStart;
    int charEnd = chunks.last.charEnd;
    int index = nextOutputIndex;
    String prompt = this.prompt(lod > 0);
    // Join the window's text; only the last chunk contributes fullContent —
    // presumably so overlap between adjacent chunks is not duplicated.
    // TODO(review): confirm against Chunk.content/fullContent semantics.
    String input = chunks
        .mapIndexed(
          (chunk, pos) =>
              pos == chunks.length - 1 ? chunk.fullContent : chunk.content,
        )
        .join("\n\n")
        .trim();

    try {
      ChatResult result = await FireRag.instance.llm.connector.call(
        ChatRequest(
          messages: [Message.system(prompt), Message.user(input)],
          model: FireRag.instance.llm.model,
        ),
      );
      // Advance progress markers only after the LLM call succeeded, so a
      // failed window is retried rather than skipped.
      emitted = index + 1;
      lastChunk = chunkIndexes.last;
      String id = "${recordLocation.hashId}.L${lod + 1}.$index";
      Chunk distilled = Chunk(
        index: index,
        content: result.message.content.toString().trim(),
        metadata: chunks.first.metadata,
        charStart: charStart,
        charEnd: charEnd,
        lod: lod + 1,
        down: chunks.map((c) => c.index).toList(),
      );
      toEmbed ??= [];
      toEmbed!.add(id);
      // Flush a full embedding batch as soon as one is available; smaller
      // remainders are flushed after the loop completes.
      List<String> embed = (toEmbed?.length ?? 0) >= embedBatchSize
          ? toEmbed?.take(embedBatchSize).toList() ?? []
          : [];
      if (embed.isNotEmpty) {
        toEmbed?.removeRange(0, embed.length);
      }
      // Persist the distilled chunk, back-link its sources ("up" pointer),
      // and optionally schedule the embedding batch — all concurrently.
      await Future.wait([
        FirestoreDatabase.instance
            .collection(collection)
            .doc(id)
            .set(distilled.toMap()),
        ...chunkIds.map(
          (c) => FirestoreDatabase.instance
              .collection(collection)
              .doc(c)
              .update({"up": index}),
        ),
        if (embed.isNotEmpty)
          FireRag.instance.taskManager.schedule(
            TaskEmbed(
              lod: lod + 1,
              recordLocation: recordLocation,
              bucket: bucket,
              taskId: "$taskId.iembed.$index",
              collection: collection,
              chunks: embed,
            ),
          ),
      ]);
    } catch (e, es) {
      // Rate limited by the LLM provider: back off so the scheduler retries
      // this task later without marking it failed.
      if (e.toString().contains("429")) {
        context.backoff();
        return;
      }

      context.fail(e.toString());
      error(es);
      // Stop here: without this return the loop kept iterating and
      // context.complete() below would fire after context.fail().
      return;
    }
  }

  // Split any leftover ids into embedBatchSize-sized batches.
  // NOTE(review): accumulateBy is a project extension — assumed to group the
  // stream into lists of at most embedBatchSize elements; confirm.
  List<List<String>> to = (toEmbed?.isNotEmpty ?? false)
      ? await Stream.fromIterable(toEmbed!)
            .accumulateBy(embedBatchSize, (w) => 1, maxAmount: embedBatchSize)
            .toList()
      : [];

  await Future.wait([
    if (to.isNotEmpty)
      ...to.mapIndexed(
        (batch, index) => FireRag.instance.taskManager.schedule(
          TaskEmbed(
            lod: lod + 1,
            recordLocation: recordLocation,
            bucket: bucket,
            taskId: "$taskId.embed.$index",
            collection: collection,
            chunks: batch,
          ),
        ),
      ),
    if (shouldScheduleNextLevel) ...[
      // NOTE(review): `emitted!` assumes at least one chunk was produced
      // whenever shouldScheduleNextLevel is true — confirm that invariant,
      // otherwise this throws a null-assertion error.
      FireRag.instance.pushRecordProgress(
        location: recordLocation,
        deltaDistilledTotal: (emitted! / factor.toDouble()).ceil(),
      ),
      // Kick off distillation of the level we just produced.
      FireRag.instance.taskManager.schedule(
        TaskDistill(
          embedBatchSize: embedBatchSize,
          taskId: "$taskId.distill.L${lod + 1}",
          recordLocation: recordLocation,
          lod: lod + 1,
          collection: collection,
          bucket: bucket,
          targetOutputSize: targetOutputSize,
          factor: factor,
          size: emitted!,
        ),
      ),
    ],
  ]);

  context.complete();
}