Require: prefix_tokens, embeddings, known_servers
generated_sequence = list()
cache = dictionary()
streams = dictionary()
chain = find_best_chain(known_servers)
for server in chain do
    streams[server] = rpc_inference(server)
    cache[server] = list()
end for

inputs = embeddings(prefix_tokens)
while should_continue(generated_sequence) do
    tail_servers = copy(chain)
    while not empty(tail_servers) do
        server = tail_servers.pop_left()
        try:
            {servers keep their attention caches between steps}
            outputs = streams[server].send(inputs)
            cache[server].append(inputs)
            inputs = outputs
        catch ServerFailed:
            {no response from server; rebuild this part of the chain}
            streams.pop(server).close()
            past_inputs = cache.pop(server)
            new_servers = replace_failed_server(server, past_inputs,
                                                cache, streams, known_servers)
            chain.replace(server, new_servers)
            tail_servers.push_left(new_servers)
    end while

    logits = compute_logits(outputs, embeddings)
    next_token = choose_next(logits) {e.g. greedy}
    generated_sequence.append(next_token)
    inputs = embeddings(next_token)
end while

for server in chain do
    streams[server].close()
end for
return generated_sequence
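
For concreteness, below is a minimal runnable Python sketch of this loop. Everything infrastructural is a hypothetical stand-in: Server, ServerStream, rpc_inference, find_best_chain, and replace_failed_server model the RPC layer as local function calls rather than any real client API, the chain is simply the list of known servers, and the replacement logic ignores which layers a candidate actually serves (it assumes spare replicas exist among known_servers).

from collections import deque
from typing import Callable, List

Vector = List[float]


class ServerFailed(Exception):
    """Raised when a server stops responding mid-stream."""


class Server:
    """Hypothetical server whose pipeline stage is a local function."""
    def __init__(self, name: str, layer_fn: Callable[[Vector], Vector]):
        self.name = name
        self.layer_fn = layer_fn


class ServerStream:
    """Stand-in for an open rpc_inference() stream to one server."""
    def __init__(self, server: Server):
        self.server = server

    def send(self, inputs: Vector) -> Vector:
        # A real stream would perform an RPC here and could raise ServerFailed.
        return self.server.layer_fn(inputs)

    def close(self) -> None:
        pass


def rpc_inference(server: Server) -> ServerStream:
    return ServerStream(server)


def find_best_chain(servers: List[Server]) -> List[Server]:
    # Real selection minimizes latency over servers covering all layers;
    # this sketch just takes the known servers in order.
    return list(servers)


def replace_failed_server(failed, past_inputs, cache, streams, known_servers):
    # Naive pick: any known server without an open stream. A real client
    # must choose servers that serve the failed server's layers.
    spare = next(s for s in known_servers if s is not failed and s not in streams)
    streams[spare] = rpc_inference(spare)
    cache[spare] = []
    for x in past_inputs:          # replay earlier steps so the new server
        streams[spare].send(x)     # rebuilds its attention state
        cache[spare].append(x)
    return [spare]


def generate(prefix_tokens, embed, compute_logits, choose_next,
             known_servers, max_new_tokens=8):
    generated = []
    chain = find_best_chain(known_servers)
    streams = {s: rpc_inference(s) for s in chain}
    cache = {s: [] for s in chain}  # every input each server has consumed

    inputs = embed(prefix_tokens)
    while len(generated) < max_new_tokens:
        tail = deque(chain)         # servers still to run for this token
        while tail:
            server = tail.popleft()
            try:
                outputs = streams[server].send(inputs)
                cache[server].append(inputs)
                inputs = outputs
            except ServerFailed:
                streams.pop(server).close()
                past_inputs = cache.pop(server)
                new_servers = replace_failed_server(
                    server, past_inputs, cache, streams, known_servers)
                i = chain.index(server)        # splice replacements in
                chain[i:i + 1] = new_servers
                tail.extendleft(reversed(new_servers))  # retry this step
        logits = compute_logits(outputs)   # the paper ties this to embeddings
        next_token = choose_next(logits)   # e.g. greedy argmax
        generated.append(next_token)
        inputs = embed([next_token])

    for s in chain:
        streams[s].close()
    return generated

The invariant that makes recovery cheap is that cache[server] holds every hidden state that server has consumed, so on failure only the failed span is replayed into a replacement (rebuilding its server-side attention state) while the rest of the chain keeps its caches; generation then resumes from the interrupted step instead of restarting from the prefix.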