Tricks to Fine-Tuning
SPEAKERS

Raj is an Assistant Professor of Computer Science at the University of California, San Diego, leading the PEARLS Lab in the Department of Computer Science and Engineering (CSE). He is also a Research Scientist at Mosaic AI, Databricks, where his team is actively recruiting research scientists and engineers with expertise in reinforcement learning and distributed systems.
Previously, he was part of the Mosaic team at the Allen Institute for AI. He earned his PhD in Computer Science from the School of Interactive Computing at Georgia Tech, advised by Professor Mark Riedl in the Entertainment Intelligence Lab.

At the moment Demetrios is immersing himself in Machine Learning by interviewing experts from around the world in the weekly MLOps.community meetups. Demetrios is constantly learning and engaging in new activities to get uncomfortable and learn from his mistakes. He tries to bring creativity into every aspect of his life, whether that be analyzing the best paths forward, overcoming obstacles, or building lego houses with his daughter.
SUMMARY
Prithviraj Ammanabrolu drops by to break down TAO fine-tuning, a clever way to train models without labeled data. Using reinforcement learning and synthetic data, TAO teaches models to evaluate and improve themselves. Raj explains how this works, where it shines (think small models punching above their weight), and why it could be a game-changer for efficient deployment.
TRANSCRIPT
Prithviraj Ammanabrolu [00:00:02]: I'm Prithviraj, or I also go by Raj. I am an assistant professor at the University of California, San Diego, and also a research scientist at Databricks. I was previously with Mosaic before we got acquired. So I actually have two jobs. I do two full-time jobs, effectively. And I usually take my coffee black but with sugar. So usually no cream, but some amount of sugar in it.
Demetrios [00:00:37]: Welcome back to the MLOps Community podcast. I'm your host, Demetrios. And today I'm talking with Raj all about fine-tuning, but not just any kind of fine-tuning. We're talking about TAO fine-tuning, the method that he came up with and wrote an incredible blog post about. Why don't we just start this conversation? Do you call it the Dao? Do you call it Dao or do you call it Tao?
Prithviraj Ammanabrolu [00:01:08]: Yeah, we're just going to call it Tao. It's gone through multiple naming changes internally. Tao is actually the final version, which is what ended up sticking, that we came up with maybe a day before the blog post went live. So internally it's actually keyed to something else in my head, because for the majority of the time we were thinking about it as a different name, but we decided that this was a better name overall.
Demetrios [00:01:39]: Because whenever I see spelling like that, I always think of the Tao Te Ching, you know?
Prithviraj Ammanabrolu [00:01:46]: Yeah, yeah, I think they were just, yeah, we were just thinking like, like, Tao.
Demetrios [00:01:52]: Yeah, Tao is perfect. So what is it?
Prithviraj Ammanabrolu [00:01:59]: So TAO is basically, at a high level, a way for people to fine-tune their own models for their own domains. So there's always been this sort of tension between, well, do you have one model to rule them all, or do you have a separate model for every single individual domain? And given that a lot of people's data is private, we seem to be gravitating towards a world in which these aren't necessarily entirely orthogonal, but very much a world where a lot of people do need customized models for their data and their use cases. And so that's what it is. It's a way of doing that. And most importantly, it's a way of doing that without people having labels. Right. So the bane of every single machine learning person trying to build a custom model is like, oh no, where am I going to get the labels for this data?
Speaker C [00:03:15]: Right.
Prithviraj Ammanabrolu [00:03:15]: Expensive. Yeah, these are expensive. The annotation costs are going to be super expensive. And we've heard this like time and time again from like so many customers, right? Like, we were. That's actually where some of the initial motivation came from. So we were thinking, okay, well, the initial version was like, okay, well what would it take to have some kind of a system where people are deploying their models, collecting feedback in some form with these models and then using those to continuously improve them through time? But not everyone has access to be able to like collect this, this sort of like feedback and whatnot in real time. And a lot of people just straight up, they have like an idea of the tasks they want to do. Right.
Prithviraj Ammanabrolu [00:04:12]: They have some prompts and whatnot, but they don't have a particularly like, large data set on how to do it.
Demetrios [00:04:19]: And so the basic premise was: okay, you don't have a large data set. You also don't have a well-labeled data set, or you don't have any labels at all, because it takes a lot of time and effort to create those labels. So how can we still get something of value without any of that?
Prithviraj Ammanabrolu [00:04:40]: Yeah, and it's a little bit counterintuitive from all the classical machine learning, supervised learning perspectives, right? We're taught, you know, always think really hard about your data. You need good data. And then if you think about how a lot of the big companies are making their models too, they do have a ton of human-annotated data. So that was the sort of initial key challenge that we spent a lot of time brainstorming through: okay, well, if we want to get people to use something like this, right, this RLxF as a service, and I'll talk about the RL bits later, but basically this continuously improving, domain-specific, specialized models thing, then we need to make the barrier to entry as small as possible.
Demetrios [00:05:42]: That's funny. Anybody that has taken the Andrew Ng course, he's like, you can't do that. No, it's all about the data. What are you talking about, no labeled data?
Prithviraj Ammanabrolu [00:05:54]: Yeah, well, it turns out that it kind of still is about the data. It's just like we abstract a lot of it away for you.
Demetrios [00:06:04]: Tell me more about that.
Prithviraj Ammanabrolu [00:06:05]: And the data is like, it's synthetically generated under the hood, so to speak. And it's synthetically generated from the model itself.
Speaker C [00:06:16]: Right.
Prithviraj Ammanabrolu [00:06:17]: So here's the way I like to think about it.
Speaker C [00:06:19]: Right.
Prithviraj Ammanabrolu [00:06:19]: So the way we do this is very reinforcement learning based. And the way I like to teach reinforcement learning usually is: you think of this as personalized, supervised learning for the model, right. Where if you do normal supervised learning, you just have a ton of labels that a human has annotated and you're trying to learn how to mimic the human's labels or whatever. Whereas in the reinforcement learning case, the model is itself generating data.
Speaker C [00:06:58]: Right.
Prithviraj Ammanabrolu [00:06:59]: The model is making these generations. So you have a prompt. The model generates, you know, one or N responses for this particular prompt. And then it gets feedback. It's like, okay, which one of these things was good? Which one of these things was not? And then it tries to learn from that.
Speaker C [00:07:20]: Right.
Prithviraj Ammanabrolu [00:07:20]: And so it's learning from its own mistakes here, right? Whereas if you were training it entirely on human data, it would be learning from a human's mistakes. And the key aspect of this is that the mistakes models make when they're reasoning about something are different from the mistakes humans make. So it actually turns out that in a lot of cases the human data is not necessarily ground truth. And so that's probably one of the key reasons this actually works: being able to have these models learn from their own mistakes in a way that's specialized for the models.
Demetrios [00:08:08]: Okay, so break down exactly how this is working, because I think we jumped into it just saying TAO. The TAO stands for something, right? And it's not the land of 10,000 things. It is Test-time Adaptive Optimization.
Prithviraj Ammanabrolu [00:08:27]: Yeah. So I guess at a higher level, it's actually very much the process that I was just explaining. So people will come and they have a bunch of prompts, right? So that's the barrier to entry. Now some customer comes and they say, hey, we have these tasks. These are the types of prompts that we want to give this model. And then what we do is we take those prompts internally. We have our base models, these sort of, up to some extent, pre-trained LLMs or whatnot.
Prithviraj Ammanabrolu [00:09:07]: We have them generate a bunch of responses, right? There's various sorts of techniques on how to do this, but the general gist they all boil down to is you are trying to generate different responses. And then we have another model, which is called a reward model. And this is pretty important. This reward model will take the responses that the generating model, which I'm going to call a policy from here on out to use more RL language, has outputted, and then score them. It's going to be like, this is plus 1, plus 0.5, minus 1, minus 0.5. And then via reinforcement learning, we can then train this policy to say, okay, well, some of these responses were really good.
Speaker C [00:10:23]: Right.
Prithviraj Ammanabrolu [00:10:24]: And so because these responses were really good, we're going to change the weights of the model to upweight those.
Speaker C [00:10:32]: Right.
Prithviraj Ammanabrolu [00:10:32]: To make it more likely to produce outputs that are like the ones that have high scores and make the model less likely to produce outputs that are like the ones that have received low scores. Right. And then that process sort of like.
Demetrios [00:10:50]: Repeats iteratively, which until now is standard reinforcement learning.
Speaker C [00:10:56]: Right.
Prithviraj Ammanabrolu [00:10:56]: So this is mostly like standard reinforcement learning.
Speaker C [00:11:01]: Right.
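For readers who want to see the shape of the loop Raj is describing, here is a minimal sketch, not the TAO implementation: GPT-2 stands in for the policy, a placeholder heuristic stands in for a learned reward model like DBRM, and the update is plain REINFORCE, upweighting responses scored above the batch average and downweighting the rest.

```python
# Minimal REINFORCE-style sketch of the loop described above (illustration only,
# not the TAO implementation): sample N responses per prompt, score them with a
# reward model, and upweight the responses that beat the batch average.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")               # stand-in policy
policy = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-6)

def reward(prompt: str, response: str) -> float:
    """Placeholder for a learned reward model such as the DBRM Raj mentions."""
    return float(len(response.split()) > 5)               # toy heuristic only

prompts = ["Translate this request into SQL: total sales by region."]
N = 4                                                      # responses per prompt

for prompt in prompts:
    inputs = tok(prompt, return_tensors="pt")
    with torch.no_grad():
        sequences = policy.generate(**inputs, do_sample=True, top_p=0.95,
                                    num_return_sequences=N, max_new_tokens=64,
                                    pad_token_id=tok.eos_token_id)
    prompt_len = inputs.input_ids.shape[1]
    responses = [tok.decode(s[prompt_len:], skip_special_tokens=True)
                 for s in sequences]
    scores = torch.tensor([reward(prompt, r) for r in responses])
    advantages = scores - scores.mean()                    # centre the scores

    loss = torch.zeros(())
    for seq, adv in zip(sequences, advantages):
        labels = seq.clone()
        labels[:prompt_len] = -100                         # don't train on the prompt
        nll = policy(input_ids=seq.unsqueeze(0), labels=labels.unsqueeze(0)).loss
        loss = loss + adv * nll        # positive advantage pushes likelihood up
    (loss / N).backward()
    optimizer.step()
    optimizer.zero_grad()
```

A production system would also mask padding tokens, regularize the update (for example with a KL penalty against the starting policy), and batch across many prompts; the sketch only shows where the reward model's scores enter the gradient.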
Prithviraj Ammanabrolu [00:11:02]: What I've described, the key things that we've had to do in this particular case is, well, one, if you think about it, this reward model that we have is kind of providing labels, and it's doing a lot of the job of ranking these inputs and outputs of the policy. It's giving some kinds of scores. And so it's really important to have a reward model that is trained really well across a wide range of tasks. And so that's one of the things that we talk about in the blog post. We have this reward model called DBRM, so a very enterprise-focused reward model that's able to judge tasks across a wide variety of possibilities. The reason that this is simpler is that it is easier to train models to judge whether something is correct or not, as opposed to actually trying to come up with the correct answer yourself.
Speaker C [00:12:25]: Right.
Prithviraj Ammanabrolu [00:12:27]: So this is the verification problem. And the verification problem is easier than the generation problem, and usually cheaper.
Speaker C [00:12:39]: Right.
Prithviraj Ammanabrolu [00:12:40]: And so we spent a lot of effort gathering the data to make as wide-ranging and generic a reward model as we possibly could. And then for the vast majority of use cases, at least everything that we've shown in the blog post, we use the same reward model: across text-to-SQL, across FinanceBench, which is the finance-domain question answering kind of thing. It's actually the same reward model. It turns out that same reward model is able to judge outputs for all of those tasks.
Demetrios [00:13:20]: The reward model, was it a fine-tuned model, like some Llama base or something that was fine-tuned, or was it a completely separate smaller model?
Prithviraj Ammanabrolu [00:13:35]: It was another model that was a base, so think, you know, Llama-style or open-source style, Llama, DBRX, whatever style base model, that we did fine-tune on top of. So the conventional wisdom for training reward models is you usually do start with a pre-trained and then somewhat instruction-tuned model, and then you add an additional head on top. So instead of giving you logits, a probability distribution across your entire vocabulary to figure out what token to predict next, it's giving you a single scalar value now.
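A hedged sketch of that conventional recipe: take a pretrained backbone, skip the vocabulary head, and map the last hidden state to one scalar per (prompt, response) pair. The backbone name below is a stand-in; this is the general pattern Raj describes, not the DBRM code.

```python
# Conventional reward-model shape: a pretrained backbone plus a scalar "value
# head" instead of a vocabulary head. (Sketch only; backbone name is a stand-in.)
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class RewardModel(nn.Module):
    def __init__(self, backbone_name: str = "gpt2"):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(backbone_name)
        self.value_head = nn.Linear(self.backbone.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        hidden = self.backbone(input_ids=input_ids,
                               attention_mask=attention_mask).last_hidden_state
        # Score each sequence from its last non-padding token.
        last_idx = attention_mask.sum(dim=1) - 1
        last_hidden = hidden[torch.arange(hidden.size(0)), last_idx]
        return self.value_head(last_hidden).squeeze(-1)    # one scalar per sequence

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token
rm = RewardModel()
batch = tok(["Prompt ... Response A", "Prompt ... Response B"],
            return_tensors="pt", padding=True)
scores = rm(batch.input_ids, batch.attention_mask)         # shape: (2,)
```

Such a head is typically trained with a pairwise preference loss, pushing the preferred response's score above the rejected one's; the point from the conversation is just the output shape, a single scalar rather than logits over the vocabulary.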
Demetrios [00:14:21]: Yeah. Okay. Now the other piece that I wasn't super clear on here was that you're saying the reward model is basically giving you, if you squint, some annotated data.
Prithviraj Ammanabrolu [00:14:37]: Kind of. It's very sparse annotated data. Right. So it's feedback data, effectively. So the reason this is possible is, let's look at this in terms of information density. If you have pure imitation learning data, supervised fine-tuning data, where you have to have a human write down exactly how to do a task, exactly all of the outputs for how to do a particular task given the input prompt, that's very information dense. Whereas the information that this reward model is giving it is: okay, you're generating this entire sequence, and after the entire sequence, or at breakpoints in between the sequence, we're just giving you a scalar.
Demetrios [00:15:34]: Value.
Prithviraj Ammanabrolu [00:15:36]: At, you know, predefined breakpoints. So it's a lot fewer bits of information that it needs to be able to predict, and so that problem is easier.
Demetrios [00:15:48]: So now where does the test-time part come in? I feel like, yes, yeah, there's another part to this story.
Prithviraj Ammanabrolu [00:15:56]: So the other part to the story is, okay, well, so you have the reward model that's scoring stuff. Right, but what's it actually scoring?
Speaker C [00:16:10]: Right.
Prithviraj Ammanabrolu [00:16:10]: And what it's actually scoring is what's on the policy side. So the policy is generating responses. And you can generate one response and score that one response, right? And then learn from that. You can generate two responses and score those. And then the way you do it is also important. You can condition the second response that you generate based on the first one. It's like, okay, generate something different from what I generated the first time around, or whatever. But if you think about it, that is inference-time compute that's being used.
Prithviraj Ammanabrolu [00:16:54]: The more responses you're generating, the more test-time compute, so to speak, is actually being used here. Now, the terminology can get a little bit confusing here. The reason we call this test-time compute is that it's just the model, during training, using additional inference-time compute. Right. But once this whole training process is done and the model's actually deployed, it doesn't do multiple response generation or whatever.
Speaker C [00:17:42]: Right.
Prithviraj Ammanabrolu [00:17:43]: And so the inference cost for the actual user is the same. We are eating up a lot of the inference cost during the training process. The additional adaptive test-time bit is eaten up ahead of time, so that when it's actually deployed, you can expect the same inference latencies and whatnot.
Demetrios [00:18:15]: Yeah. So it's not like a deep research or whatever R1 reasoning model where you're going to give it a prompt and then come back 15 minutes later. It is doing that test time compute at the training level.
Prithviraj Ammanabrolu [00:18:35]: So the process would go: customer comes, they have a bunch of prompts that they give us, and then we would do this kind of reasoning process at training time. And so we would just drop it in and then come back a couple of days later, or whatever amount of test-time compute we want to use for this particular task. These responses all get scored in various ways. And it turns out how you generate the responses also matters.
Speaker C [00:19:18]: Right.
Prithviraj Ammanabrolu [00:19:18]: So there's very popular techniques like best-of-N, which is you just generate N different responses all IID, right, independently sampled from each other. Turns out that's probably not the greatest idea.
Speaker C [00:19:34]: Right.
Prithviraj Ammanabrolu [00:19:34]: It's like you're actually burning a lot of redundant compute. It works, but in terms of trying to be at least somewhat more compute efficient, it gets a little bit redundant. So if your underlying model doesn't have very much diversity in the range of things that it can say, then if you generate N responses, it'll turn out that some subset K of those N responses will be very similar to each other. You've just burnt a bunch of compute, the scores for those will all be the same, and there's no additional learning signal to be had in just trying the same thing over and over again. So you have to be a little bit smart about how you actually do this kind of sampling. And that's one of the things that we internally developed, being smart about this, because.
Demetrios [00:20:26]: It's not the number of responses, or I guess to put that differently, you quickly realize that if you're scoring the same thing, basically it's the same words, or it's the same idea in different words, and so it's getting the same score, and that's not adding any richness to that base model.
Prithviraj Ammanabrolu [00:20:50]: Yeah, you're not extracting any new information, right? From an information perspective, you're kind of dead in the water. So you have to be careful about how these responses are generated in the first place.
Speaker C [00:21:11]: Right.
Prithviraj Ammanabrolu [00:21:12]: And how you're doing this exploration, from the RL perspective, such that the feedback that's given by the reward model is actually useful.
Demetrios [00:21:24]: I was figuring you were just going to turn the temperature up to 11 and then see what happens.
Prithviraj Ammanabrolu [00:21:29]: You know, we did try that. That was actually probably one of the first things we tried: okay, let's just turn the temperature up and see. But it turns out that doesn't work particularly well, and it kind of depends on the use cases. So if you turn up the temperature really high, and you're doing creative writing, right, if you want Character AI sorts of chatbots, it could work. If you're trying to do enterprise things, you probably don't really want that level of super high temperature stochasticity. So you have to do things like, you have to be able to condition on previous responses. You have to make sure that if you have one response and you have another response, you have a definition of how close or far apart those two are from each other.
Prithviraj Ammanabrolu [00:22:23]: And then make sure that the next responses that you generate are sufficiently different from all the previous responses that you've generated.
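Databricks hasn't published the exact sampler, but one simple way to get that "sufficiently different" behaviour is to measure how close each new candidate is to the responses already kept and reject near-duplicates. A crude sketch using TF-IDF cosine similarity as the distance, purely for illustration:

```python
# Crude diversity filter (illustration, not the TAO sampler): keep a candidate
# response only if it is not too similar to anything already kept.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def select_diverse(candidates, max_keep=4, max_similarity=0.8):
    vec = TfidfVectorizer().fit(candidates)
    kept, kept_vecs = [], []
    for cand in candidates:
        v = vec.transform([cand])
        too_close = any(cosine_similarity(v, kv)[0, 0] > max_similarity
                        for kv in kept_vecs)
        if too_close:
            continue                         # redundant: same idea, same score
        kept.append(cand)
        kept_vecs.append(v)
        if len(kept) == max_keep:
            break
    return kept

samples = ["SELECT region, SUM(sales) FROM t GROUP BY region;",
           "SELECT region, SUM(sales) FROM t GROUP BY region",   # near-duplicate
           "SELECT region, total_sales FROM sales_by_region;",
           "I'm sorry, I can't help with that."]
print(select_diverse(samples))
```

The conditioning trick Raj mentions, prompting the policy with its own earlier answers and asking for something different, attacks the same problem at generation time instead of as a post-hoc filter.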
Demetrios [00:22:33]: Have you tried to do this over and over? And do you see loss when you've done it like two or three times? Because I can imagine a world where you say, all right, cool, let's just set up retraining pipeline that kicks off every week or every couple days or every month, whatever it may be. But then you start to see almost like compression loss in a way.
Prithviraj Ammanabrolu [00:22:57]: Yeah, yeah, that's a great question. So, like the answer to most things in ML land: it depends, right? And so I'll tell you the cases where it works and some cases where it's not so great.
Speaker C [00:23:13]: Right.
Prithviraj Ammanabrolu [00:23:14]: So it turns out that actually, like, one of the things that we did in TAO is we just kind of like ran this process sort of like iteratively more and more, and we're like, oh, damn, the number is just going straight up.
Speaker C [00:23:30]: Right?
Prithviraj Ammanabrolu [00:23:30]: Like this graph. Yeah, this is the best graph ever. It's just going up and to the right, like, just continuously like this. This just keeps working. And this was like a year, year and a half ago. At this point, we're like, this is great. Our reward model works. All of this works.
Prithviraj Ammanabrolu [00:23:50]: But it turns out there's two things that could go wrong here. One is this obvious issue, or in hindsight, obvious issue, of reward hacking, where it turns out that your reward model is not entirely 100% accurate. There's some amount of noise in it. And so if you spend too much compute and you're just doing this looping around, at some point you've extracted all of the useful signal from the reward model and now you're kind of just learning the remaining error noise.
Speaker C [00:24:32]: Right?
Prithviraj Ammanabrolu [00:24:33]: And, you know, I don't know what overfitting really means anymore, but for a classically trained ML person, the way to think of it is that this would be an RL-ish version, possibly, of overfitting, where you've gotten really good at optimizing for the reward from a particular reward model, but because that itself is just a proxy for what a customer or someone might want at the end, it doesn't really do so well downstream.
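A common guard against this kind of reward over-optimization, offered here only as an illustration of the cutoff idea rather than the threshold logic Raj alludes to, is to track reward on held-out prompts and stop once it stalls, rolling back to the best checkpoint:

```python
# RL-flavoured early stopping against reward hacking (illustration only).
# Stops once held-out reward has not improved for `patience` evaluation rounds
# and returns the best checkpoint seen so far.
def train_with_reward_cutoff(train_step, eval_heldout_reward, save_checkpoint,
                             max_rounds=100, patience=3):
    best_reward, best_ckpt, stale = float("-inf"), None, 0
    for round_idx in range(max_rounds):
        train_step()                          # one chunk of RL optimisation
        score = eval_heldout_reward()         # reward model on held-out prompts
        if score > best_reward:
            best_reward, best_ckpt, stale = score, save_checkpoint(round_idx), 0
        else:
            stale += 1                        # signal is flattening out
        if stale >= patience:
            break                             # likely optimising reward-model noise now
    return best_ckpt
```

Because the held-out score still comes from the same imperfect reward model, practitioners usually pair a cutoff like this with a KL penalty against the starting policy or with periodic checks on an independent downstream eval.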
Demetrios [00:25:09]: It's almost like, I think about when you're squeezing a lime and in the beginning you get all that juice, and then later you're like trying to squeeze, squeeze, squeeze, and it's so much effort just for like one drop. And really it would be even better.
Prithviraj Ammanabrolu [00:25:24]: Like it. Yeah, that starts getting bitter too.
Speaker C [00:25:26]: Boy.
Demetrios [00:25:27]: Yeah.
Prithviraj Ammanabrolu [00:25:28]: If you start squeezing at the end.
Demetrios [00:25:30]: Right.
Prithviraj Ammanabrolu [00:25:32]: It's somewhat well known, but it's really interesting to kind of see it in practice. And it's very much somewhat of an art to figure out what is that tolerance threshold, what is that cutoff? And how do you get that cutoff in a generic way such that multiple different customers can use a singular reward model?
Demetrios [00:25:58]: Because, well, what would you do then? Do you just swap out the base model and start over from zero again? Or how do you.
Prithviraj Ammanabrolu [00:26:08]: Well, I mean, you have checkpoints from sort of all through the training process, right? So you could roll back to a previous checkpoint. Ideally, the way this would work, right. The ideal loop of what this is would be you have some kind of generating model that burns a bunch of inference compute, and simultaneously you have this reward model that's trained, and then at some point you cut it off, you just deploy the model, and then you collect new feedback so that you can actually Update your reward model. And therefore now there's more signal in the REWARD model. And now that there's more signal in the reward model.
Speaker C [00:26:49]: Right.
Prithviraj Ammanabrolu [00:26:50]: You can now train again on it.
Speaker C [00:26:52]: Right.
Prithviraj Ammanabrolu [00:26:52]: That would be the ideal loop: you're doing this cycle of redeployment. And this is what you were talking about, right? This redeployment thing, retraining across deployments, only really works if you are able to get new signal in between each.
Speaker C [00:27:11]: Right.
Prithviraj Ammanabrolu [00:27:11]: So if you ask questions.
Demetrios [00:27:13]: Yeah, I get it. Yeah.
Prithviraj Ammanabrolu [00:27:15]: If it's the same questions, right, and you've kind of optimized up to a certain extent already, you're not going to get anything more after the initial training run, because we're already doing a lot of that optimization for you, and we're figuring out dynamically what those thresholds and stuff are. So ideally, if somebody comes to us and they're like, okay, we have these prompts, this is cool, we will then use this TAO method overall, figure out all these sorts of thresholds, optimize against this reward model that we have dynamically, and then we'll give you. And the model goes back to them. The reason you would want to then retrain it further would be: actually, we have these new tasks now that we want the model to be able to do. Here's a bunch of new prompts.
Prithviraj Ammanabrolu [00:28:14]: And we are reasonably confident that these new prompts are actually doing something different from the old prompts. And to use your own words, they're not the same thing but in different words, because we kind of account for that too, under the hood, by taking your prompt data set and making sure we're covering all our bases of, oh, here's all the ways that people could possibly say or ask for this same set of prompts. Yeah. And then, so if you have a bunch of new tasks that you're reasonably certain are different from the capabilities that the model has already been trained for, then you come back to us and say, okay, well, we would like to optimize again, please. And then you do it, right. And that would be the reason to retrain the model.
Demetrios [00:29:07]: So what are some of the results that you saw when you did this? Like, compared against different ways of fine tuning, I guess, would be the main thing. And most of the time.
Speaker C [00:29:19]: Right.
Demetrios [00:29:20]: And I'm sure you've seen this a lot. It's just try and optimize the prompts try and prompt, tune it to the best of your ability because you get such a fast feedback loop. And then if you can't do that, then figure out if you need to go and, and you need to fine tune and then what kind of fine tuning with how many GPUs and all that, what data, what labels, all of that fun stuff. There's so much of a headache that comes along with fine tuning and not doing it.
Speaker C [00:29:48]: Right.
Demetrios [00:29:48]: And then later, after you fine-tune, you realize the model actually performs worse. What's going on? And so here I see the value prop really clearly. You say: look, you don't need to worry about the labeled data. Just give us your prompts and we'll make a model that works better on these prompts.
Prithviraj Ammanabrolu [00:30:09]: Yeah, yeah. So again, you know, the standard machine learning like, adage of there's no free lunch is very much true. Right. So you know, these prompts can be from like multiple different tasks and we can make your model better at these like multiple different tasks, but you're trading off performance somewhere else.
Speaker C [00:30:35]: Right.
Prithviraj Ammanabrolu [00:30:35]: So most of the customers will come to us, and they'll have a bunch of enterprise-based prompts. But you're probably, you know, reducing the model's underlying creative writing.
Speaker C [00:30:50]: Right.
Demetrios [00:30:51]: Which they may or may not care about.
Prithviraj Ammanabrolu [00:30:52]: Yeah. Which they like, probably don't really care about.
Speaker C [00:30:55]: Right.
Prithviraj Ammanabrolu [00:30:55]: And if they do care about it, they can come back to us with like creative writing prompts as part of it and then we can optimize for that too.
Demetrios [00:31:03]: Give me my financial statements in the style of a Bob Dylan song.
Prithviraj Ammanabrolu [00:31:08]: Yeah, exactly. That'd be cool actually. Yeah. Maybe I should figure out ways of teaching students that way.
Demetrios [00:31:18]: But yeah, the other thing that just popped up there in my head as you were saying that is do you need a certain amount of diverse prompts or does it perform better if you give it a much bigger set of prompts? Because you were saying at the end of the day, all of the output that gets scored, you reach that threshold. So in my mind I'm wondering, oh, well, if the input is more vast, then does that mean that you can get more out on the other side? And I'm sure you looked at that.
Prithviraj Ammanabrolu [00:31:51]: Yes, we absolutely did look at that. I guess the high level here is that it is absolutely true that the bigger your prompt data set and the more diverse your prompt data set, the better the performance is going to be. And for the longest time, actually, we found that we were only able to get good performance, where good performance means we could train a relatively small parameter-count model to match the performance of GPT-4o or something on a particular task, if we had a really big, really diverse sort of prompt set, but.
Demetrios [00:32:32]: And really big is like a thousand.
Prithviraj Ammanabrolu [00:32:35]: Yeah, big would be in the thousands.
Speaker C [00:32:38]: Okay, right.
Prithviraj Ammanabrolu [00:32:41]: But that's a pretty big barrier to entry. Again, if we're asking for no labels, but we're like, okay, only prompts, but actually the asterisk is that you have to give us 10,000 prompts or whatever, that's a pretty big barrier to entry. And so we don't require that. We found a way around that, basically, such that the end customer can give us just a subsample of the kind of range of things that they think the model will be used for. And then we'll take it from there, and we'll figure out how to get that kind of diverse prompt set or whatever and learn from it.
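One hedged way to picture that expansion step, turning a handful of seed prompts into a larger, more diverse set, is to ask a stronger LLM for rephrasings of each seed. The model and template below are placeholders (GPT-2 is far too weak for real use), not the TAO pipeline:

```python
# Toy prompt-set expansion (illustration only): paraphrase each seed prompt so
# the policy later sees many surface forms of the same underlying tasks.
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")      # weak stand-in LLM

def expand_prompts(seed_prompts, n_variants=5):
    expanded = list(seed_prompts)
    for seed in seed_prompts:
        template = f"Rewrite the following request in a different way:\n{seed}\nRewrite:"
        outs = generator(template, do_sample=True, num_return_sequences=n_variants,
                         max_new_tokens=40, pad_token_id=50256)
        # The pipeline returns the template plus the continuation; keep the continuation.
        expanded += [o["generated_text"][len(template):].strip() for o in outs]
    return expanded

seeds = ["Summarise this quarterly finance report.",
         "Turn this question into a SQL query over the sales table."]
print(expand_prompts(seeds)[:4])
```

Whatever the real mechanism is, the effect Raj describes is the same: the customer brings 10 to 50 prompts, and the training side widens that into the diverse distribution the RL loop needs.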
Demetrios [00:33:33]: Yeah, so it's like I come to you with my three prompts and you're.
Prithviraj Ammanabrolu [00:33:37]: Hopefully you have more than three. But if you come to me with only three prompts, we would do a great job of that. That makes our life easy: I just need it to do well on these three types of prompts. Okay, sure. But we'd be okay with that range, like, oh, you know, 10 to 50 prompts even.
Speaker C [00:34:08]: Right?
Prithviraj Ammanabrolu [00:34:09]: Like that range of stuff. And then we can go from there and internally get to the stage where we see prompts that are a lot more along those lines, and, you know, a much bigger data set, a much more diverse data set, so on and so forth.
Demetrios [00:34:33]: And so the thing that is also fascinating here is that you were able to get such high performance from small models. And I wonder, if you had no real constraints on how you were to do it, but I just told you, hey, smallest model, best performance, highest accuracy on X amount of tasks, how would you go about it? And it can be anything from, all right, I'm going to distill this model and then I'm going to fine-tune it and then I'm going to, whatever you want. Like, how would you look at that and try? Or is it not even worth the headache of doing all these extra hoops that you jump through, and you just say, I'm going to do TAO and it's going to be good enough.
Prithviraj Ammanabrolu [00:35:26]: I actually think I'm going to do TAO, and it's probably going to be good enough, because I think we've done a pretty good job of figuring out what the sort of optimal scenario is, even given all of these constraints. Right. So these constraints are fairly realistic, but under the hood, we're relatively unconstrained. Right. We have no compute constraints and things like that under the hood. And so I think we've done a good job of being as unconstrained as we can on our side, while still accepting constraints and making sure that the customers don't have a particularly high barrier to entry. Okay.
Prithviraj Ammanabrolu [00:36:20]: Yeah. So I actually wouldn't change too much, honestly. I think. I think, like, if I'm thinking about it, if you're like, oh, you're, like, fully constrained, how would I do it? That's actually kind of the question I asked myself and. Because that's what would be ideal, right? It's like, oh, because I. I guess the. The worst feeling would be like, oh, like, somebody will come to you, and they'll be like, oh, okay, well, you know, we have this particular task. Like, you know, we want to.
Prithviraj Ammanabrolu [00:36:49]: We want to train models on this task or whatever. And then we're like, okay, here's this model that kind of works. Asterisk.
Speaker C [00:37:00]: Right?
Prithviraj Ammanabrolu [00:37:01]: But if we were more unconstrained, it would do better.
Speaker C [00:37:04]: Right?
Prithviraj Ammanabrolu [00:37:04]: And so, like, that's. That's a scenario I wanted to avoid.
Demetrios [00:37:07]: Yeah, yeah, yeah, I hear you. So the other interesting piece here is: all right, how small of a model can you get to do how great of things, or how good can you get these small models?
Prithviraj Ammanabrolu [00:37:24]: Yeah, yeah. Again, it's a no free lunch question. But it turns out that the smaller models, like, the sort of like, 8B range of stuff, you can get them to do pretty darn impressive things.
Speaker C [00:37:43]: Right.
Prithviraj Ammanabrolu [00:37:44]: And again, it really depends on the difficulties of the actual tasks that people are trying to do. But what we found is that in a narrow range, if you take away its broad intelligence and specialize it for more narrow data intelligence, so to speak, the small models, the relatively smaller models, are actually perfectly fine. So there's a threshold at which they're good enough. Now, obviously, you might be like, oh, Raj, does that mean you don't believe in scaling?
Demetrios [00:38:28]: That's like.
Prithviraj Ammanabrolu [00:38:28]: That's not true. I do believe in scaling. Right. It's just that scaling is an axis: the more complicated a task you have, the bigger the model you will need in order to be able to achieve it. But not everyone needs that. So you see the new Llama 4, right? The Behemoth is 2 trillion parameters. You know, only a fraction of those are active.
Prithviraj Ammanabrolu [00:38:59]: But that's still going to be a massive pain to deploy for the vast majority of people. Yeah. People are not stoked. People are not stoked about the size of the Behemoth. Even Scout is relatively large.
Speaker C [00:39:14]: Right?
Demetrios [00:39:14]: Yeah.
Prithviraj Ammanabrolu [00:39:15]: And the performance is. Yeah. So people have these kinds of cost-quality trade-offs, and every single customer, I guess, lies somewhere different on that Pareto frontier of cost versus quality. And so when you're thinking about setting up any sort of a product where you're trying to get people to make their own customized models, things like adaptive compute or whatnot are great, because they basically allow them to control where on the cost-quality frontier they are. It's like, oh, spend more cost, burn more compute during training, get better quality at the end.
Demetrios [00:40:10]: Yeah, yeah, that's a great point: you may need to optimize for a certain place in this whole life cycle. You're okay with spending a little bit more money on that training to fine-tune, because later it means you have that small model. And I imagine, have you played around with trying to compress these models or distill these models even more, to make them even smaller?
Prithviraj Ammanabrolu [00:40:43]: Um, we haven't really, but yes, internally we've played around a little bit with distilling it down from the 8B size even further. And you can still do pretty reasonably well for certain tasks on those. But you start seeing pretty sharp drop-offs if you're doing a 1B or a 3B sort of a model, even for some of the benchmarks that we were thinking of or that we showed in the blog post, like the text-to-SQL one, BIRD-SQL and whatnot. The drop-offs were pretty high, but for some of the other benchmarks it worked okay.
Speaker C [00:41:24]: Right.
Prithviraj Ammanabrolu [00:41:24]: So I guess it's like again it, it kind of depends a little bit on how much or like how difficult the, the task is. But my, my impression is that the difference between 3B and 8B in terms of performance is disproportionately high compared to the cost difference.
Speaker C [00:41:50]: Oh, wow.
Demetrios [00:41:51]: Okay. Huh.
Prithviraj Ammanabrolu [00:41:53]: Yeah. Which is kind of why we presented a lot of results at the 8B scale. Right. So it's like, oh, it was not a linear correlation. Like the performance dropped off A lot more than like, you know, say like, you know, 3, 8 of the compute or whatever.
Demetrios [00:42:11]: Yeah. And the cost wasn't outrageously different.
Prithviraj Ammanabrolu [00:42:14]: Yeah, the cost isn't outrageously different between a 3B and an 8B.
Speaker C [00:42:21]: Right.
Prithviraj Ammanabrolu [00:42:22]: Now, of course, there's all sorts of inference-time things that you can do. You can do multi-tenant serving, all that kind of jazz. But we felt that an 8B was a reasonable trade-off: it is still fast and lightweight enough to deploy in a lot of scenarios, but it's still able to perform pretty well. And the beautiful thing about TAO is that it just works. It's size independent.
Speaker C [00:42:57]: Right.
Prithviraj Ammanabrolu [00:42:58]: We have results on 70B. We have results internally on 405B. And it just keeps working better and better.
Demetrios [00:43:08]: Did you notice any drastic differences, or notable differences, maybe not drastic, between the models that you were fine-tuning on? Did some take better to this method than others?
Prithviraj Ammanabrolu [00:43:25]: Yes, kind of. So it really boils down to the models being able to generate diverse responses. So there are certain models that are out there that have kind of been RLHF'd to death.
Speaker C [00:43:47]: Right.
Prithviraj Ammanabrolu [00:43:49]: Okay. My primary field of research is rl. And so what is RLHF really doing here underneath if you have some kind of distribution? So the pre training lets you learn all these token distributions over the entire Internet. Language distribution, supervised fine tuning makes that distribution smaller. It's like, okay, well I'm going to bias you towards this range of instructions that I think the model is most likely to hear people say to it and interact with it. Now, like RLHF, there's been a lot of sort of contemporary work that's been out in the open, which it shows that it makes these probability distributions super spiky.
Speaker C [00:44:38]: Right.
Prithviraj Ammanabrolu [00:44:39]: It's like instead of now being able to generate a wide range of responses, you are now super constrained.
Speaker C [00:44:48]: Right.
Prithviraj Ammanabrolu [00:44:48]: You can only say one thing. And this is why you get all these responses of the safety types of responses, like, I can't do that for you. Yeah, right. Like it's kind of been beaten out the model. That's not great from the perspective of doing further optimization with rl. Right. So if you have a reward, like I was saying, right. A lot depends on being able to extract good training signal from this really great reward model that we have.
Prithviraj Ammanabrolu [00:45:24]: Which means that the more diverse the responses the underlying base model we're using is able to generate, the better it is.
Speaker C [00:45:37]: Right.
Prithviraj Ammanabrolu [00:45:39]: So that's part of the trick: don't use things that are already too RLHF'd. They won't take particularly well to this method, because they've already been RL-trained for a different set of tasks, and now they're no longer particularly great at this.
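A quick way to see the "spikiness" Raj is describing is to measure the entropy of a candidate base model's next-token distribution on your own prompts; heavily RLHF'd checkpoints tend to concentrate almost all probability on a handful of tokens, which leaves the RL loop little diversity to exploit. A sketch, with a placeholder model name:

```python
# Rough "spikiness" probe (illustration only): average next-token entropy over
# your prompts. Lower entropy = more peaked distribution = less diversity for
# the RL step to work with.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def mean_next_token_entropy(model_name, prompts):
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    entropies = []
    for p in prompts:
        ids = tok(p, return_tensors="pt").input_ids
        with torch.no_grad():
            logits = model(ids).logits[0, -1]       # distribution over the next token
        probs = torch.softmax(logits, dim=-1)
        entropies.append(-(probs * torch.log(probs + 1e-12)).sum().item())
    return sum(entropies) / len(entropies)

# Compare a base checkpoint against a heavily aligned one on the same prompts,
# e.g. print(mean_next_token_entropy("gpt2", ["Write a SQL query that ..."]))
```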
Demetrios [00:46:01]: Do you remember way back in the day, two years ago, when the first Mistral model came out and it had pretty much no RLHF on it? And I think that's why people loved it so much. It was like, whoa, we get to do it ourselves. That's amazing.
Prithviraj Ammanabrolu [00:46:22]: Yeah. And there's pros and cons of releasing a model with no RLHF, but it certainly became really popular in certain subsets of the open source community that wanted to use it for all sorts of use cases that a model like Llama was kind of too RLHF'd to do well on. Right. But yeah, so that's kind of the thing to think about here. Well, I guess this is from a researchy kind of perspective. From a customer perspective, they don't really have to think about it, because we've already thought about it for them.
Demetrios [00:47:06]: Yeah, you've already tested a bunch of different models.
Prithviraj Ammanabrolu [00:47:09]: Yeah, we already have a pretty good idea of like the, the distributions and which models are spiky, which are not. If they are spiky, how to get rid of some of the spikiness in the distribution such that it's more useful to generate a diverse range of things. We've kind of already taken care of this with internal research.
Demetrios [00:47:34]: Okay. We've been dancing around this and really talking about how the sausage is made. But what we didn't talk about is the actual results. So maybe you can give us a quick TLDR of how this is better, and in what ways, I guess, is the big question that's going through my mind right now.
Prithviraj Ammanabrolu [00:47:57]: Yeah, yeah. So the one thing that's super obvious: the very obvious baseline is just supervised fine-tuning. Right. If you had supervised fine-tuning labels, it turns out that this does way better than supervised fine-tuning with labels, even though this doesn't have labels.
Speaker C [00:48:20]: Right.
Prithviraj Ammanabrolu [00:48:21]: That's pretty clear. We tested it on text to SQL and then some other benchmarks that we developed that other people in my team have developed previously, which is like the enterprise arena and we have our own version of like finance bench and whatnot which test all these sorts of like, you know, question answering, open and closed book question answering scenarios in very like enterprise Y ways. Like I would never go back to like supervised fine tuning.
Speaker C [00:48:52]: Right?
Prithviraj Ammanabrolu [00:48:53]: Like it, it's. I think, right, I think it's, it's this is like just way better. Unless you have like a ridiculously like massive data set of, of labels that's like rich and whatnot, which like effectively nobody does. And even then like TAO is not orthogonal to that. Like you could still run TAO on top of that.
Speaker C [00:49:21]: Oh, nice, right.
Prithviraj Ammanabrolu [00:49:23]: Because if you have a labeled data set with prompts and responses, guess what you also have: you have the prompts, which is all TAO needs.
Speaker C [00:49:32]: Right?
Prithviraj Ammanabrolu [00:49:33]: So you can still run TAO together with the supervised fine-tuning step up front. The other interesting thing, and I'm going back and looking at the exact numbers, is that even the 8B models end up being comparable to GPT-4o for that narrow task, for that range of tasks. On some tasks, the 8B models, if you have a task like FinanceBench or whatnot, will do as well as GPT-4o or o3-mini or whatever, which are at least one to two orders of magnitude bigger in parameter count. Now, for the tasks that are a little bit more complicated, like BIRD-SQL or whatever, which is a more complicated task, you see that the performance of the 8B models is still slightly less than GPT-4o or o3-mini, but that gap narrows and basically vanishes when you get to 70B.
Demetrios [00:50:58]: Oh, nice.
Speaker C [00:51:01]: Right, yeah.
Prithviraj Ammanabrolu [00:51:02]: And so this is what I was talking about: it's where you are on the cost-quality trade-off and what would be good enough for you. And so if you are a person that is doing financial QA along the lines of FinanceBench, 8B is probably good enough for you, right? If you're doing something significantly more complicated, you'd probably need something that's a little bit larger. And if you're doing something that's a little bit more complicated and you want it to be absolutely at o3-mini level, then you're going to need to spend a little bit more on training it. The interesting thing I will say also is, in the comparison that we did do, we didn't show results of the actual inference-time comparison costs. We just presented them as all equal.
Prithviraj Ammanabrolu [00:52:05]: But O3 Vinnie is really doing test time inference, right? They're doing a lot of additional compute, whereas we're, we've baked a lot of that additional compute in during training time itself. And so the real correct sort of to speak comparison would be probably like 4.0 in terms of test time latencies after the model is deployed. Yeah. So it's like what's your threshold? How complicated is the task that you're doing? But the results are pretty promising in the sense that with a 70B model, even the 70B is still smaller than 4.0 and 03. Midi and all of these things you can erase and basically get to near parity. And then now you control your fate, so to speak. You have this exact model. Nobody does anything under the hood that changes the model behavior.
Prithviraj Ammanabrolu [00:53:20]: The models don't swap out under the API hood. If you give a particular prompt one day and the model does a certain way, you can build guarantees around that the model is going to continue behaving the same way because you.
Demetrios [00:53:36]: No deprecating a model after a six-month stint.
Prithviraj Ammanabrolu [00:53:40]: No deprecating a model. Yeah, like GPT-4.5 is deprecating already, which is crazy, dude.
Demetrios [00:53:46]: Well, this is, this is awesome, man.