Data Provider
The data provider is the interface implemented by the different structures responsible for parsing the training data and creating the samples that clients use when training a model. There are three ways to declare a data provider, depending on where the data is hosted and how the clients request it:
- Local data provider: the client already has the data to be used for training.
- HTTP data provider: the clients receive a list of URLs where all the necessary training data is hosted, and request the data from them.
- TCP data provider: a local server, separate from the clients, that the clients communicate with over TCP to get the training data.
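Here is a minimal Rust sketch of what such a shared interface could look like. It assumes a single get_samples method. That name comes from the fetch flow described later in this section, but its signature, the i32 token type, and the LocalDataProvider type are hypothetical and not Psyche's actual API. The in-memory implementation stands in for the local variant. HTTP and TCP providers would implement the same trait over the network.

```rust
use std::ops::Range;

/// Hypothetical interface: every provider (local, HTTP, TCP) returns the
/// tokenized samples whose data indices fall in the requested range.
/// `&mut self` leaves room for providers that keep connection state.
pub trait DataProvider {
    fn get_samples(&mut self, indices: Range<u64>) -> Result<Vec<Vec<i32>>, String>;
}

/// Hypothetical local provider: the client already holds every sample in memory.
pub struct LocalDataProvider {
    samples: Vec<Vec<i32>>,
}

impl LocalDataProvider {
    pub fn new(samples: Vec<Vec<i32>>) -> Self {
        Self { samples }
    }
}

impl DataProvider for LocalDataProvider {
    fn get_samples(&mut self, indices: Range<u64>) -> Result<Vec<Vec<i32>>, String> {
        let (start, end) = (indices.start as usize, indices.end as usize);
        // Reject reversed or out-of-bounds ranges instead of panicking.
        if start > end || end > self.samples.len() {
            return Err(format!("indices {start}..{end} out of bounds"));
        }
        Ok(self.samples[start..end].to_vec())
    }
}
```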
Overview
The data provider acts as a server that can be accessed via TCP by clients to obtain the data they need for training.
When a client starts a round of training, it receives an ID or a range of IDs from the coordinator, representing the batches it will train on during that round. Each batch contains a specific range of the overall data. The client can then call the data provider with the assigned IDs for the run and fetch the corresponding data to begin training.
To better understand how the data is partitioned for each client, refer to the following diagram:
```mermaid
flowchart TD
    C((Coordinator))
    C1[Client]
    C2[Client]
    C --Batch IDs--> C1
    C --Batch IDs--> C2
    subgraph Data Provider
        B1["Batch<br>1. Data<br>2. Data<br>3. Data"]
        B2["Batch<br>1. Data<br>2. Data<br>3. Data"]
        B3["Batch<br>1. Data<br>2. Data<br>3. Data"]
        B4["Batch<br>1. Data<br>2. Data<br>3. Data"]
        B5["Batch<br>1. Data<br>2. Data<br>3. Data"]
        B6["Batch<br>1. Data<br>2. Data<br>3. Data"]
        B4 ~~~ B1
        B5 ~~~ B2
        B6 ~~~ B3
    end
    B1 --> C1
    B2 --> C1
    B3 --> C1
    B4 --> C2
    B5 --> C2
    B6 --> C2
```
The number of batches used for training in a run, as well as the indexes of data that each batch contains, can be configured.
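The sketch below makes the batch-to-data mapping concrete. It assumes zero-based batch IDs, contiguous batches, and a fixed, configurable number of samples per batch. The names samples_per_batch, data_indices_for_batch and data_indices_for_batches are hypothetical. Under these assumptions, batch b covers data indices from b * samples_per_batch up to, but not including, (b + 1) * samples_per_batch.

```rust
use std::ops::Range;

/// Hypothetical mapping: batch `batch_id` covers a contiguous block of
/// `samples_per_batch` data indices (half-open range).
pub fn data_indices_for_batch(batch_id: u64, samples_per_batch: u64) -> Range<u64> {
    let start = batch_id * samples_per_batch;
    start..start + samples_per_batch
}

/// Maps an inclusive interval of batch IDs to the combined range of data indices.
pub fn data_indices_for_batches(first: u64, last: u64, samples_per_batch: u64) -> Range<u64> {
    assert!(first <= last, "empty batch interval");
    data_indices_for_batch(first, samples_per_batch).start
        ..data_indices_for_batch(last, samples_per_batch).end
}
```

With three samples per batch, as in the diagram, batches 3 to 5 cover data indices 9..18.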
Deep Dive
For the coordinator's initial state, the state.toml file contains configuration details for the entire run. A key section to consider is [model.LLM.data_location], which specifies whether the data will be hosted on a TCP server, accessed via HTTP, or stored in a local folder.
When loading a model, the required configuration depends on the data provider implementation being used:
- TCP Server:
  - If the data provider is configured as a TCP server, an additional file named data.toml is required.
  - This file contains configurations for local training, including:
    - Data location
    - Token size
    - Sequence length
    - A seed to shuffle the data if necessary
  - An example data.toml file can be found in psyche/config within the various initial state examples.
- HTTP Provider:
  - For the HTTP data provider, no additional configuration file is needed.
  - The required fields for this setup include:
    - The URL (or a set of URLs) from which the data will be fetched
    - Token size (in bytes)
    - A shuffle seed, if data shuffling is desired
- Client Hosting the Data:
  - In this case, the client must simply provide the URL where the data is hosted.
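The three configurations above can be modeled as a single Rust enum, which is roughly what a [model.LLM.data_location] section has to express. This is a hypothetical sketch. The names DataLocation, Server, Http, Local, address, urls, token_size_in_bytes, shuffle_seed and path are illustrative and may not match the real schema. Parsing from TOML (for example with serde) is left out so the sketch has no dependencies.

```rust
use std::path::PathBuf;

/// Hypothetical model of the data location options described above.
#[derive(Debug, Clone, PartialEq)]
pub enum DataLocation {
    /// TCP server: further settings (data location, token size, sequence
    /// length, shuffle seed) live in a separate data.toml file.
    Server { address: String },
    /// HTTP provider: all required settings are given inline.
    Http {
        urls: Vec<String>,
        token_size_in_bytes: u8,
        shuffle_seed: Option<u64>,
    },
    /// Local data: the client already hosts the data itself.
    Local { path: PathBuf },
}

impl DataLocation {
    /// True if this variant needs the extra data.toml file.
    pub fn needs_data_toml(&self) -> bool {
        matches!(self, DataLocation::Server { .. })
    }

    /// Basic sanity checks on the required fields.
    pub fn validate(&self) -> Result<(), String> {
        match self {
            DataLocation::Server { address } if address.is_empty() => {
                Err("server address is empty".into())
            }
            DataLocation::Http { urls, .. } if urls.is_empty() => {
                Err("at least one URL is required".into())
            }
            DataLocation::Http { token_size_in_bytes, .. } if *token_size_in_bytes == 0 => {
                Err("token size must be non-zero".into())
            }
            _ => Ok(()),
        }
    }
}
```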
The init_run function initializes the data provider using the configuration and creates a DataFetcher, the structure responsible for managing the data-fetching process. The data fetcher is part of the TrainingStepMetadata, which holds the internal data for the training step within the StepStateMachine, alongside the metadata for the other steps; the state machine keeps one metadata structure per step.
Once the data provider is created and included in the state machine, it will be used at the start of the epoch and during every training step. The client monitors changes in the coordinator's state, and upon detecting a step transition, it calls the apply_state function of the RunManager. This, in turn, calls the apply_state function of the StepStateMachine. If the state indicates that a training round is starting, the start function of the TrainingStepMetadata is invoked.
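The call chain above can be sketched as follows. Everything here is a simplified, hypothetical stand-in. RunState and its variants, the struct fields, the rounds_started counter, and the signatures of apply_state and start are assumptions. The real types carry much more state and run asynchronously. The sketch shows only the dispatch: a new coordinator state flows from RunManager to StepStateMachine, and only a transition into a training round reaches start.

```rust
/// Hypothetical coordinator states observed by the client.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RunState {
    WaitingForMembers,
    RoundTrain,
    Cooldown,
}

/// Stand-in for the per-step training metadata (holds the DataFetcher in practice).
#[derive(Default)]
pub struct TrainingStepMetadata {
    pub rounds_started: u32,
}

impl TrainingStepMetadata {
    pub fn start(&mut self) {
        // The real `start` would assign data, fetch it, and begin training.
        self.rounds_started += 1;
    }
}

#[derive(Default)]
pub struct StepStateMachine {
    pub training: TrainingStepMetadata,
}

impl StepStateMachine {
    pub fn apply_state(&mut self, state: RunState) {
        // Only a transition into a training round triggers `start`.
        if state == RunState::RoundTrain {
            self.training.start();
        }
    }
}

#[derive(Default)]
pub struct RunManager {
    last_state: Option<RunState>,
    pub step_machine: StepStateMachine,
}

impl RunManager {
    /// Called whenever the client observes the coordinator's state.
    pub fn apply_state(&mut self, state: RunState) {
        if self.last_state == Some(state) {
            return; // no transition, nothing to do
        }
        self.last_state = Some(state);
        self.step_machine.apply_state(state);
    }
}
```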
The start function initiates the actual training process on the client side. Its first task is to work out which data is required for training, using the assign_data_for_state function. This function determines the number of batches for the round and the indices of data within each batch. The client is then assigned an interval of batch IDs, called data_assignments, which it fetches from the data provider using the fetch_data function of the DataFetcher.
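One simple way to hand each client an interval of batch IDs, as in the diagram, is to split the round's batches into contiguous chunks of nearly equal size. The sketch below does this. The even-split policy, the function name assign_batch_interval, and its parameters are assumptions, and assign_data_for_state may use a different scheme.

```rust
use std::ops::RangeInclusive;

/// Hypothetical policy: split `total_batches` batch IDs, starting at
/// `first_batch_id`, into contiguous intervals, one per client. The first
/// `total_batches % num_clients` clients get one extra batch.
/// Returns `None` if this client receives no batches.
pub fn assign_batch_interval(
    first_batch_id: u64,
    total_batches: u64,
    num_clients: u64,
    client_index: u64,
) -> Option<RangeInclusive<u64>> {
    assert!(num_clients > 0 && client_index < num_clients);
    let base = total_batches / num_clients;
    let extra = total_batches % num_clients;
    // Each earlier client got `base` batches, plus one if its index < `extra`.
    let start = client_index * base + client_index.min(extra);
    let len = base + u64::from(client_index < extra);
    if len == 0 {
        return None;
    }
    let first = first_batch_id + start;
    Some(first..=first + len - 1)
}
```

With six batches and two clients this yields 0..=2 and 3..=5, matching the split in the diagram.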
The fetch_data function parses the batch IDs using the data indices per batch to calculate the actual intervals of data to use. It creates a channel to send and receive batches. Once the data intervals are calculated, the client calls the get_samples function on the data provider to retrieve the raw data for those IDs. This process repeats in a loop until all batch IDs have been requested and their batches sent through the channel.
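The loop described above follows a standard producer pattern. The sketch below spawns a thread that turns each assigned batch ID into a data-index range, fetches the samples, and sends the batch through a std::sync::mpsc channel. When the loop ends, the sender is dropped, which tells the receiver that no more batches will arrive. Several parts are assumptions: the closure standing in for get_samples (taking a half-open start/end range), the Batch type, and the fixed samples_per_batch. The real DataFetcher is also asynchronous rather than thread-based.

```rust
use std::ops::RangeInclusive;
use std::sync::mpsc::{self, Receiver};
use std::thread;

/// Hypothetical batch as delivered to the training loop.
#[derive(Debug, PartialEq)]
pub struct Batch {
    pub id: u64,
    pub samples: Vec<Vec<i32>>,
}

/// Sketch of a fetch loop: for every assigned batch ID, compute its data
/// indices, fetch the samples, and push the batch into a channel.
pub fn fetch_data<F>(
    assignments: RangeInclusive<u64>,
    samples_per_batch: u64,
    mut get_samples: F,
) -> Receiver<Batch>
where
    F: FnMut(u64, u64) -> Vec<Vec<i32>> + Send + 'static,
{
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        for id in assignments {
            let start = id * samples_per_batch;
            let samples = get_samples(start, start + samples_per_batch);
            if tx.send(Batch { id, samples }).is_err() {
                break; // receiver dropped: training stopped early
            }
        }
        // `tx` is dropped when the thread ends, closing the channel.
    });
    rx
}
```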
On the other end, the receiver is used in the train function. It continuously receives data from the channel and uses it for training until all data is consumed.
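On the receiving side, iterating over an mpsc receiver blocks until the next batch arrives. The loop ends once every sender has been dropped, which is one way to implement "until all data is consumed". In this sketch, train and the per-batch training step (a closure here) are hypothetical simplifications, and the demo function only shows how the two ends are wired together. The real train function runs the model's forward and backward passes.

```rust
use std::sync::mpsc::{self, Receiver};
use std::thread;

/// Hypothetical consumer: runs one training step per received batch and
/// returns how many batches were processed once the channel is closed.
pub fn train<T, F>(rx: Receiver<T>, mut train_step: F) -> usize
where
    F: FnMut(T),
{
    let mut steps = 0;
    // Each iteration blocks on `recv`; the loop stops when all senders are gone.
    for batch in rx {
        train_step(batch);
        steps += 1;
    }
    steps
}

/// Example wiring: a producer thread sends three batches, then drops its sender.
pub fn demo() -> (usize, u64) {
    let (tx, rx) = mpsc::channel::<Vec<u64>>();
    let producer = thread::spawn(move || {
        for id in 0..3u64 {
            tx.send(vec![id; 2]).expect("receiver alive");
        }
    });
    let mut total = 0;
    let steps = train(rx, |batch| total += batch.iter().sum::<u64>());
    producer.join().expect("producer panicked");
    (steps, total)
}
```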