ResearchTrend.AI
  • Papers
  • Communities
  • Events
  • Blog
  • Pricing
Papers
Communities
Social Events
Terms and Conditions
Pricing
Parameter LabParameter LabTwitterGitHubLinkedInBlueskyYoutube

© 2025 ResearchTrend.AI, All rights reserved.

  1. Home
  2. Papers
  3. 2403.11366
22
1

JORA: JAX Tensor-Parallel LoRA Library for Retrieval Augmented Fine-Tuning

17 March 2024
Anique Tahir
Lu Cheng
Huan Liu
ArXivPDFHTML
Abstract

The scaling of Large Language Models (LLMs) for retrieval-based tasks, particularly in Retrieval Augmented Generation (RAG), faces significant memory constraints, especially when fine-tuning extensive prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs but fall short of accommodating the efficient parameter distribution required for retrieved context. Addressing this gap, we introduce a novel framework for PEFT-compatible fine-tuning of Llama-2 models, leveraging distributed training. Our framework uniquely utilizes JAX's just-in-time (JIT) compilation and tensor-sharding for efficient resource management, thereby enabling accelerated fine-tuning with reduced memory requirements. This advancement significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources. Our experiments show more than 12x improvement in runtime compared to Hugging Face/DeepSpeed implementation with four GPUs while consuming less than half the VRAM per GPU.

View on arXiv
Comments on this paper