In particular we are trying to understand performance vs. cost trade-offs. We don't have a budget to train from scratch.
We are working with a proprietary dataset on the order of 100M tokens and are looking to fine-tune a general-purpose language model, and also to create task-specific models based on the same corpus.
Any help would be appreciated!
A single A100 or H100 with 80GB of VRAM can fine-tune 70B open models using parameter-efficient methods (in 4-bit, the 70B weights alone are roughly 35 GB, leaving room for LoRA adapters and optimizer state). Scaling out to many nodes/GPUs is faster, and you can use much cheaper GPUs for fine-tuning smaller models.
The localllama Reddit sub at https://www.reddit.com/r/LocalLLaMA/ is also an awesome community for the GPU poor :)
If you do end up wanting to fine-tune, use QLoRA with axolotl or unsloth to prove your hypothesis on a smaller model, then evaluate whether you want the marginal gains from full-precision training.
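For reference, a minimal QLoRA run with the plain HF stack (bitsandbytes + PEFT + TRL) looks roughly like this; axolotl and unsloth wrap the same recipe. This is only a sketch: the model name, data path, and hyperparameters are placeholders, and SFTTrainer's exact arguments vary between TRL versions.

```python
# Minimal QLoRA fine-tuning sketch (bitsandbytes + PEFT + TRL). Model name,
# dataset path, and hyperparameters are placeholders -- adjust for your corpus
# and hardware; SFTTrainer argument names differ slightly across TRL releases.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer

base = "mistralai/Mistral-7B-v0.1"  # assumed base model; swap for your pick

# Load the base model in 4-bit (NF4) so a single GPU can hold it.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(
    base, quantization_config=bnb, device_map="auto"
)
model = prepare_model_for_kbit_training(model)

# LoRA adapters: only a few percent of parameters actually get trained.
lora = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora)

# Hypothetical corpus file: one {"text": "..."} record per line.
dataset = load_dataset("json", data_files="corpus.jsonl", split="train")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    tokenizer=tokenizer,
)
trainer.train()
model.save_pretrained("qlora-adapter")
```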
After you fine-tune it on the 100M token dataset, use DPO to polish it off. You need to create a preference dataset for that, but it can be relatively small and still give great gains.
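Once you have preference pairs, the DPO step is a similarly small amount of TRL code. Again only a sketch: the model path, preference file, and settings are placeholders, and argument names shift between TRL releases.

```python
# Minimal DPO sketch with TRL on top of the fine-tuned model. Paths and
# hyperparameters are placeholders; check the docs for the TRL version you pin.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_dir = "qlora-merged"  # assumed: the SFT'd model with adapters merged in
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)

# Preference data: each row has "prompt", "chosen", and "rejected" strings.
prefs = load_dataset("json", data_files="preferences.jsonl", split="train")

trainer = DPOTrainer(
    model=model,
    ref_model=None,  # TRL derives a frozen reference copy when this is None
    args=DPOConfig(output_dir="dpo-out", beta=0.1, per_device_train_batch_size=2),
    train_dataset=prefs,
    tokenizer=tokenizer,
)
trainer.train()
```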
After that, look at applying grammars during inference if you are expecting structured output like JSON.
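If you serve the model through llama.cpp, grammar-constrained decoding looks something like the sketch below (the GGUF path and the toy GBNF grammar are placeholders; libraries like outlines offer similar JSON-schema constraints for other backends).

```python
# Sketch of grammar-constrained decoding with llama-cpp-python, assuming the
# fine-tuned model has been exported to GGUF. The grammar is a toy example.
from llama_cpp import Llama, LlamaGrammar

# GBNF grammar that only admits a tiny JSON object: {"answer": "..."}
grammar = LlamaGrammar.from_string(r'''
root   ::= "{" ws "\"answer\"" ws ":" ws string ws "}"
string ::= "\"" [^"]* "\""
ws     ::= [ \t\n]*
''')

llm = Llama(model_path="model.gguf")  # hypothetical path to your exported model
out = llm("Summarize the document as JSON:", grammar=grammar, max_tokens=128)
print(out["choices"][0]["text"])
```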
You should be able to run the experiments on 4090s from vast.ai, RunPod, or a similar service.
It can cost less than $100 depending on your requirements.
https://fortune.com/2024/03/11/adaptive-startup-funding-falc...
With Unsloth, Mistral 7b fine-tuning is 2x faster than HuggingFace + Flash Attention 2, and Gemma 7b is 2.4x faster than HF + FA2.
Check out https://github.com/unslothai/unsloth for full benchmarks!
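For what it's worth, the Unsloth entry point is a thin wrapper around the same QLoRA recipe; a rough sketch (model name and LoRA settings are placeholders, not a fixed recommendation):

```python
# Rough Unsloth loading sketch; model name and LoRA hyperparameters are
# placeholders taken from typical examples.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b-bnb-4bit",  # pre-quantized 4-bit checkpoint
    max_seq_length=2048,
    load_in_4bit=True,
)

# Attach LoRA adapters; only these weights get trained.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
# From here, hand `model`/`tokenizer` to TRL's SFTTrainer as in a standard QLoRA run.
```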
https://github.com/OpenAccess-AI-Collective/axolotl
Someone from one of the cloud GPU vendors wrote a guide: https://brev.dev/blog/fine-tuning-mistral
There's no one-size-fits-all answer yet, but if you just want to test it out, there are many commercial offerings on which you should be able to get some results for under $10k.
And since your dataset is large, even the longest context windows are insufficient to fit it in-context.
Blazing fast compared to out-of-the-box Transformers. Also make sure to use Flash Attention if you have A100s or better and context length >= 2k.
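Enabling it in plain Transformers is just a load-time flag; a small sketch (model name is a placeholder, and the argument name differs in older releases):

```python
# Loading with Flash Attention 2 in transformers (needs the flash-attn package
# and an Ampere-or-newer GPU); the model name is just a placeholder.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # older releases used use_flash_attention_2=True
    device_map="auto",
)
```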
Add FAISS (https://github.com/facebookresearch/faiss) if you need fast local RAG
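A minimal local-RAG sketch with FAISS: embed your corpus chunks, index them, then pull the nearest chunks into the prompt at query time. The sentence-transformers embedder is an assumption; any embedding model works.

```python
# Tiny FAISS retrieval sketch. Chunks, embedder, and query are placeholders.
import faiss
from sentence_transformers import SentenceTransformer

chunks = ["chunk one of the corpus...", "chunk two..."]  # your split documents
embedder = SentenceTransformer("all-MiniLM-L6-v2")       # assumed embedding model

vecs = embedder.encode(chunks, normalize_embeddings=True).astype("float32")
index = faiss.IndexFlatIP(vecs.shape[1])  # inner product == cosine on normalized vectors
index.add(vecs)

query = embedder.encode(
    ["what does the contract say about termination?"],
    normalize_embeddings=True,
).astype("float32")
scores, ids = index.search(query, 2)  # top-2 nearest chunks
context = "\n".join(chunks[i] for i in ids[0])  # feed this into the model's prompt
print(context)
```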