Implementing Data Capture for ML Observability and Drift Detection // Pushkar Garg // DE4AI
Pushkar is a Machine Learning and Artificial Intelligence Engineer working as a Team Lead at Clari in the San Francisco Bay Area. He has more than a decade of engineering experience. His specialization is building ML models and the platforms for training and deploying them.
Modern ML systems consist of complex data pipelines, with transformations happening across multiple layers of the system such as the data warehouse, the offline feature store, and the online feature store. One important aspect of productionizing any ML model is implementing ML observability, and the key component for enabling it is efficient data capture running on the prediction endpoints. In this talk, I will share my experience of implementing data capture by coding an in-memory buffer, the lessons learned while doing so, and how downstream monitoring jobs consume these data capture logs to close the loop on ML observability.
Skylar [00:00:07]: Welcome, Pushkar.
Pushkar Garg [00:00:10]: Hey, everyone. How's it going?
Skylar [00:00:14]: Good.
Pushkar Garg [00:00:16]: All right, let's get started. Let me share my screen. All right, let's get started. So, hi everyone. My name is Pushkar Garg.
Pushkar Garg [00:00:28]: I work as a staff machine learning platform engineer at Clari, and I'm excited to be presenting this topic. It's a slightly niche area within data quality: data capture on model endpoints and how it helps in running data observability on the models deployed in production. So let's get started. The agenda for today: first, we'll go over the different data components you may see in a data platform that powers all of the model-building capabilities in an AI and ML platform. Then we'll go over the importance of monitoring data quality, and model monitoring in general. After that, we'll move on to how we actually measure data quality with metrics: what metrics do we need, and how do we build the infrastructure to identify and capture them? Next, we'll cover managed data capture. Certain tools we use to deploy models provide managed, meaning automatic, data capture, and we'll look at some of its limitations, what may create the need for a custom data capture architecture, and how we go about implementing one. Finally, we'll bring all of that together into how it fits into the downstream systems, along with some best practices and lessons learned from implementing custom data capture.
Pushkar Garg [00:02:25]: So let's start with the different data components. Any company will already have its production databases, and with the advent of generative AI we're now seeing a lot of use cases for data that used to sit in silos in object storage like S3 buckets: files, call transcripts, email bodies, anything like that. So how do we bring all of that data together into one component where we can run transformations and ultimately build features, so that we can use those features to build models? In this diagram, I've tried to separate out the data world and the ML world so that it becomes easier to assign responsibilities and also to implement data quality checks. With this setup, it's possible to implement data quality checks at every checkpoint where data flows and ultimately gets persisted. We start with the production databases and the object storage to bring the structured and unstructured data together into the data lake. Then, following the medallion architecture popularized by Databricks, we transform that data over certain schemas, moving from a bronze to a silver to a gold schema, which gives you the confidence that the data will always be ready to consume for model building. Once the data has been transformed into the modeled, or gold, layer, you can sync it into an offline feature store and use that to train the models.
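[Editor's note: to make the medallion flow concrete, here is a minimal PySpark sketch. The paths, table layout, and column names are placeholders invented for illustration, not the pipeline described in the talk.]

```python
# Hypothetical bronze -> silver -> gold flow; paths and columns are placeholders.
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("medallion-sketch").getOrCreate()

# Bronze: raw, schema-on-read ingest from production DBs and object storage.
bronze = spark.read.json("s3://lake/bronze/opportunities/")

# Silver: enforce types, drop obvious junk, deduplicate.
silver = (
    bronze
    .withColumn("amount", F.col("amount").cast("double"))
    .withColumn("updated_at", F.to_timestamp("updated_at"))
    .dropna(subset=["opportunity_id"])
    .dropDuplicates(["opportunity_id", "updated_at"])
)

# Gold: modeled, feature-ready aggregates that sync to the offline feature store.
gold = (
    silver.groupBy("account_id")
    .agg(
        F.count("*").alias("num_opportunities"),
        F.avg("amount").alias("avg_deal_size"),
        F.max("updated_at").alias("last_activity_at"),
    )
)
gold.write.mode("overwrite").parquet("s3://lake/gold/account_features/")
```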
Pushkar Garg [00:04:29]: From there, you run training jobs on the offline feature store, develop models, and ultimately deploy them to a serving endpoint that serves prediction requests from upstream services. Once the model is deployed to that serving endpoint, it can handle whatever kind of model prediction is needed. There may also be a need to query real-time features from an online feature store, which ingests real-time features from upstream services, with or without transformation. Your real-time services may be creating data for a particular user you want to predict for, so it's necessary to get the most up-to-date data, which may go through transformation. All of that comes together in the online feature store, and the serving endpoint can use it to get the latest data about a user or a product and then make a prediction. The focus of the talk today will be on the serving endpoint, on data quality at the serving endpoint, and how we enable that by enabling data capture in the first place.
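[Editor's note: a rough sketch of what this serving path could look like. `OnlineFeatureStoreClient`, the feature names, and the model interface are hypothetical stand-ins, not a real API from the talk.]

```python
# Sketch of a prediction handler that enriches a request with online features.
# The client class, feature names, and model interface are all hypothetical.
from typing import Any, Dict, List


class OnlineFeatureStoreClient:
    """Placeholder for whatever online store the platform uses (Redis, DynamoDB, ...)."""

    def get_features(self, entity_id: str, feature_names: List[str]) -> Dict[str, float]:
        raise NotImplementedError


def handle_prediction(request: Dict[str, Any], model, store: OnlineFeatureStoreClient) -> Dict[str, Any]:
    # The upstream service may only send an entity id; the freshest feature values
    # come from the online store at prediction time.
    features = store.get_features(
        entity_id=request["user_id"],
        feature_names=["days_since_last_login", "num_open_deals", "avg_deal_size"],
    )
    # Assumes an sklearn-style predict() that takes a 2D array of feature values.
    score = model.predict([list(features.values())])[0]
    return {"user_id": request["user_id"], "score": float(score), "features": features}
```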
Pushkar Garg [00:05:38]: So, yeah, excited about that. The importance of monitoring: as all of you already know, a model's prediction quality is only as good as the data quality. Garbage in, garbage out; whatever is fed to the model determines the quality of the prediction it generates. The other reason monitoring, and data quality monitoring specifically, matters is to identify training-serving skew. When we train a model on a certain dataset, it identifies the relationship between the independent variables and the dependent variable, and attaches a certain weight to each of those variables to come up with an equation; then, for a given set of features, those weights produce the output, which is the dependent variable. Now, if the distribution of the data in production is much different than it was in training, you can expect that some adjustment to the weights is needed for accurate predictions. To identify that this has happened in the production data, meaning there has been a distribution drift in the inference data, we need to be able to capture data, detect the drift, and then retrain the models on the latest data to capture those changes.
Pushkar Garg [00:07:17]: The next thing we need to identify is concept drift. The data itself may not have a large distribution change, but the relationship between the features and the dependent variable may change over time, for example due to Covid. When Covid happened, a lot of the models in production would not have performed up to the mark, because the overall situation had changed and the relationship between the features and the prediction had changed with it. Monitoring is necessary for identifying that case as well. And finally, like I said, data quality checks can be implemented at all checkpoints of a data pipeline, but the serving endpoint is the last frontier where you can still identify issues in the pipeline. Anything that was ignored, or any failure or breakage in the data pipeline that was not captured upstream, can ultimately be identified at the serving endpoint if you decide to have monitoring there. So we've established the importance of data quality. Now let's see how we actually measure it. There are different types of metrics; this is not an exhaustive list, but these are metrics that can be easily enabled on a serving endpoint to surface the usual data quality issues.
Pushkar Garg [00:09:06]: Let's start with the data type check. What we're trying to do here is take the training data and compare it with the inference data: we look at every column in the training data, compare it with the corresponding column in the inference data, and compute these checks. The data type check asks whether each column has the same data type as it had in the training data. The completeness check asks how many nulls there are compared to the training data. If you're getting more nulls in your inference dataset, it means something is broken, either a broken pipeline or some change that has not been accounted for, and that issue needs to be fixed. Being able to identify that it is happening is the first step, and rectifying it is the next. And then I think the most important one is the baseline drift check: how do we measure that there has been a distribution change? To do that, we look at each column and compare the distribution of the data in the training dataset to the inference dataset to see if the distribution has changed, and trigger alerts if there has been a significant change.
Pushkar Garg [00:10:38]: There are built-in algorithms that can be used. It's actually non-trivial to implement drift detection natively yourself, but there are libraries, such as Deequ, that support identifying drift in the data directly: you just use the library and supply it with the two columns to compare. So that's something that can be used, and it's a good metric to have. The last two checks are straightforward: if any column is missing, or if there is an extra column you did not expect in the inference data, that needs to be identified and rectified as well. Okay, so we've established how to measure data quality by generating these metrics. Now, how do we generate them? Let's start with the training data. You have a training dataset, you run a training job, build a model, deploy the model to an endpoint, and that model endpoint serves prediction requests from the prediction service.
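[Editor's note: before going further, here is a minimal sketch of the per-column checks just described, using pandas plus SciPy's two-sample Kolmogorov-Smirnov test as a simple stand-in for a dedicated library such as Deequ. The thresholds and column handling are illustrative only.]

```python
# Illustrative train-vs-inference checks: dtypes, null rates, drift, missing/extra columns.
import pandas as pd
from scipy.stats import ks_2samp


def data_quality_report(train: pd.DataFrame, inference: pd.DataFrame,
                        null_tol: float = 0.05, drift_p_value: float = 0.01) -> dict:
    report = {
        "missing_columns": sorted(set(train.columns) - set(inference.columns)),
        "extra_columns": sorted(set(inference.columns) - set(train.columns)),
        "type_mismatches": {},
        "null_rate_violations": {},
        "drifted_columns": {},
    }
    for col in set(train.columns) & set(inference.columns):
        # Data type check: same dtype as in training?
        if train[col].dtype != inference[col].dtype:
            report["type_mismatches"][col] = (str(train[col].dtype), str(inference[col].dtype))
        # Completeness check: are nulls noticeably more frequent than in training?
        extra_nulls = inference[col].isna().mean() - train[col].isna().mean()
        if extra_nulls > null_tol:
            report["null_rate_violations"][col] = round(float(extra_nulls), 4)
        # Baseline drift check: two-sample KS test on numeric columns.
        if pd.api.types.is_numeric_dtype(train[col]) and train[col].notna().any() and inference[col].notna().any():
            stat, p = ks_2samp(train[col].dropna(), inference[col].dropna())
            if p < drift_p_value:
                report["drifted_columns"][col] = {"ks_stat": float(stat), "p_value": float(p)}
    return report
```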
Pushkar Garg [00:11:54]: To be able to generate those metrics, we need the training data and the inference data. We already have the training data, and we can use data capture to capture the requests coming into the model endpoint to create the inference dataset. Once both are available, we pass them to a monitoring job, which looks at each column in the training data and the inference data, generates the metrics we just discussed, and persists them to your observability tool of choice. It's important to version the training data, because a model goes through multiple experiments and A/B tests, so you need to record in a model registry which version of the training data the currently deployed model was trained on. The monitoring job can then pull that reference from the model registry and load the data from S3. It's also possible to generate a baseline and constraints from the training data up front, and then only use those constraints when looking at the inference data: for each column, what is the expected range of feature values compared to what was seen in the training dataset? So it's possible to optimize this even further; the diagram here is an abstraction over that.
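[Editor's note: as a rough illustration of precomputing a baseline from the versioned training data, here is a sketch that derives simple per-column constraints and stores them as JSON next to the model artifacts, so a monitoring job can validate captured inference data without reloading the full training set. The file names and constraint fields are assumptions, not a specific tool's format.]

```python
# Sketch: derive per-column baseline constraints from training data and persist as JSON.
import json

import pandas as pd


def build_baseline_constraints(train: pd.DataFrame) -> dict:
    constraints = {}
    for col in train.columns:
        series = train[col]
        spec = {
            "dtype": str(series.dtype),
            "max_null_fraction": float(series.isna().mean()),
        }
        if pd.api.types.is_numeric_dtype(series):
            spec["min"] = float(series.min())
            spec["max"] = float(series.max())
        constraints[col] = spec
    return constraints


# Placeholder file names; in practice these would be versioned alongside the model in the registry.
train = pd.read_parquet("training_data_v42.parquet")
with open("baseline_constraints_v42.json", "w") as f:
    json.dump(build_baseline_constraints(train), f, indent=2)
```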
Pushkar Garg [00:13:32]: But this is like an abstraction on that front. Okay, so now let's talk about data capture. So we identified that we can enable data capture on a model endpoint and then that can help us in capturing the inference data. So there are automatic ways of capturing data. So if you, for example, if you deploy a sage maker endpoint, your model code gets deployed on an EC two instance and it would automatically provide you a server, a flask server, which is the which, which gets all of the prediction requests in HTTP or GRPC, whatever you have it. And it will, it will forward the request to the model. Model will make, what model will use those features to make the prediction. It may also query the online features store like we discussed, and then it will make the prediction, return the request to the server, and then the server will persist the combination of the features and the prediction so as it received from the prediction service and the model.
Pushkar Garg [00:14:47]: So imagine you're using protobuf as the request object sent to the server. The server persists that serialized request, while the model deserializes it to get the features in plain text, generates the prediction, serializes the response again, and sends it back to the server, which returns it to the prediction service. In that case, the server captures the requests in protobuf format. That's how managed data capture works. Now, what is the need for coding custom data capture? First of all, it doesn't work for real-time features in this particular case, for example with SageMaker. Whatever is sent to the server is what gets captured in the inference dataset if we go the managed way. So imagine the prediction service sending just a user ID to the server, and the model then querying the online feature store for the latest feature values for that particular user.
Pushkar Garg [00:16:01]: It would have not capture those requests or those features that the model got from the online features store and then it will just persist that user id and then the corresponding prediction. So there are ways to still recreate that data set by doing time travel queries. So ultimately what happens is whatever is, the data that comes in the online feature store gets synced in the offline feature store at a Ydez specified cadence. So what you can do is you can run another job which, which parses all of the data in, in the data capture logs and takes the request time, and then regenerates the set of features from the offline feature store by doing time travel queries that whatever is the feature set that the model would have used at that time to make the prediction. You can recreate that data set and then run the monitoring block to have like the full set of features to compare against the training data. And one thing that I identified was that it does not work for multi prediction models. So like, what I mean by multi prediction model is like for example, a ranking model where you would have to predict a set of requests. It's not a multi model endpoint per se, but it's one model, but it is serving multiple requests.
Pushkar Garg [00:17:27]: Ranking, for example, you would have multiple items that you would want to rank and then generate, generate a score for all of them and then rank them in order of user interactivity or whatever. So in that case, the persistence to s three for the data capture would involve a whole array of objects. And then at that time, Sagemaker was not able to handle that. So there was this another need there. And then, like I discussed earlier as well, if you have, if you're using protobuf as your request object, you need to keep track of the serialization and deserialization class to be able to deserialize the requests that are captured in s three from whenever you run the monitoring job to get a, uh, the data in plain text. So these are some of the limitations which, because of which we implemented custom data capture. Excuse me. Now let's look at custom data capture, right? So how do we, how did we go about implementing custom data capture? So the, you can implement a data capture buffer on the model itself.
Pushkar Garg [00:18:51]: So it will use the hardware of the model endpoint that it is running on, like in terms of memory and compute. But you can still code that within the model code. So what we did was we created a package, and that package has to be imported by any data scientist who is looking to implement this data capture. And what can be done is you instantiate that buffer and then you, after the, after the model has made the prediction. So let's say it queried something from the online feature store, it made that prediction, and then ultimately it's able to get everything in plain text. And then just before serializing the request to send back to the server, it can capture all of that data. It can add, append that data to the, to the buffer that it instantiated. So and then the buffer from time to time.
Pushkar Garg [00:19:51]: What it will do is it will flush it out to s three based on different knobs that we would want to enable. Right? So what we did was like buffer memory size, message count, and then the active buffer time. So let's say from after every five minutes we would persist the data to s three. So this is something that we did for implementing custom data capture. And then let's go to how do downstream systems use this? So like we discussed training data and then the inference data. So the inference data is generated by the data capture. Monitoring job, takes both of these as input generates the metrics, pushes the metrics to your observability tool, and then you can set thresholds for these metrics that if they go over a certain value, you would send alerts, and then those alerts go to a notification queue, and then that notification queue can have subscribers downstream which on notification send an alert to pagerduty or slack. And then retraining can be done basically too.
Pushkar Garg [00:21:08]: So that sort of completes the loop on observability. So you started with data capture identified, created the metrics identified if there is anything wrong in the model, and then completed the loop with retraining. And then this data capture can also be used for running model quality. So you can imagine that the inference data can be augmented with the ground truth, and then ultimately that data can be joined with the data captured in production. So, so the monitoring job will have the input of for this set of features, what was the prediction that the model did and then what was the actual ground truth? And use those to compare and then generate model quality monitoring metrics as well. So this is how data capture overall comes together and then creates the foundation of running observability. What ML observability. So some of the lessons learned around implementing custom data capture is like batching requests is obviously a very good idea.
Pushkar Garg [00:22:16]: You would not want to persist every request that you get to s three because that hampers your IO performance. Experimenting with different knobs to find the right threshold is a good idea. Like we discussed, memory size, the total time after which the buffer would write data to s three, or the message count, whatever may work for you. But integrating those knobs into the code is a good idea. Sampling of requests is something that you can save money on, maybe not a lot, but like in terms of compute, there can be some savings. So if you there is not a need to capture everything. So in our case like 75% of the requests are also able to identify if there is any kind of data drift. And then performance benchmarking is also important because this is something that is running on the endpoint itself.
Pushkar Garg [00:23:14]: So it's important to be able to set some SLA around how much latency this data capture buffer may add. So it's important to do benchmarking to identify the average max and 99 percentile agency that may be added as part of the data capture buffer. Yeah, that's my time, I think. Thank you very much for listening in. Happy to answer any questions if there are.
Skylar [00:23:42]: Awesome. Thank you so much.