LLaMA-Factory with Flash Attention 2 and Unsloth

It was tough to get this working, but I think I’ve figured it out enough to share. Here’s a quick guide to setting up LLaMA-Factory with Flash Attention 2 and Unsloth training support on Windows. This setup uses an RTX 3060 12GB GPU, Windows 10, and CUDA 12.1.

Unsloth is an optimization library that claims up to 2x faster training with no loss in accuracy. There’s also a quick-and-dirty script to convert bulk raw text into a dataset file, and a short overview of the dataset setup. In the video, I also cover how to fix the error you get when loading the trained adapter in Text Generation WebUI, which is caused by mismatched PEFT library versions.

Install Microsoft Visual C++ Build Tools

https://visualstudio.microsoft.com/visual-cpp-build-tools

In the installer, select the Desktop development with C++, .NET desktop build tools, and Windows application development build tools workloads, and make sure the packages ticked in the panel on the right side match.

Create a Conda Environment and Install Packages

conda create -n unsloth_env python=3.10 pytorch-cuda=12.1 pytorch cudatoolkit -c pytorch -c nvidia -c xformers -y
conda activate unsloth_env
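The Flash Attention 2 and Triton wheels installed later are built for CPython 3.10 (the cp310 tag in their filenames), so it’s worth confirming the active environment really is the one you just created. A minimal sketch; it only reads sys.version_info and the CONDA_DEFAULT_ENV variable that conda activate normally sets, and check_environment is a hypothetical helper name.

```python
import os
import sys


def check_environment(expected_env="unsloth_env", expected_python=(3, 10)):
    """Report whether the running interpreter matches the environment created above."""
    active_env = os.environ.get("CONDA_DEFAULT_ENV")  # normally set by `conda activate`
    python_version = sys.version_info[:2]
    return {
        "active_env": active_env,
        "env_ok": active_env == expected_env,
        "python_version": ".".join(map(str, python_version)),
        "python_ok": python_version == expected_python,
    }


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```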

Install Clang

conda install -y -c conda-forge/label/llvm_rc clangdev

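Set the CC Variable

Find where clang.exe was installed, then set CC to that full path, replacing [pathname] with the path that where prints:

where clang.exe
set CC=[pathname]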
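Because set only lasts for the current Command Prompt session, CC has to be set again in every new window. If you start training from your own Python script instead, the sketch below does the same lookup as where and puts the result in CC for the running process; child processes started with subprocess inherit it. As far as I can tell the variable is there for Triton, which reads CC to pick the C compiler for its runtime helper modules, but treat that as an assumption. configure_cc is a hypothetical helper name.

```python
import os
import shutil


def configure_cc(compiler="clang"):
    """Point CC at the first matching compiler on PATH; return its path or None."""
    # On Windows, shutil.which also tries the PATHEXT suffixes, so "clang" finds clang.exe.
    path = shutil.which(compiler)
    if path is not None:
        # os.environ changes are inherited by processes launched afterwards.
        os.environ["CC"] = path
    return path


if __name__ == "__main__":
    cc = configure_cc()
    print(f"CC={cc}" if cc else "clang not found on PATH")
```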

Clone LLaMA-Factory Repo

https://github.com/hiyouga/LLaMA-Factory

git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git

Install LLaMA-Factory Requirements

pip install transformers==4.41.2 datasets==2.19.2 accelerate==0.30.1 peft==0.11.1 trl==0.9.4 bitsandbytes xformers==0.0.27.dev792
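Later installs can quietly replace these pins, so a quick way to see what actually ended up in the environment helps. A minimal sketch using the standard library’s importlib.metadata; the pins are copied from the command above (bitsandbytes is left out because it isn’t pinned), and check_pins is a hypothetical helper name. It compares version strings exactly.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the pip command above.
PINS = {
    "transformers": "4.41.2",
    "datasets": "2.19.2",
    "accelerate": "0.30.1",
    "peft": "0.11.1",
    "trl": "0.9.4",
    "xformers": "0.0.27.dev792",
}


def check_pins(pins):
    """Map each distribution name to (installed_version or None, matches_pin)."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        report[name] = (installed, installed == wanted)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_pins(PINS).items():
        print(f"{'OK ' if ok else 'BAD'} {name}: installed={installed}")
```

It’s worth re-running after the Unsloth and NumPy steps; check_pins({"numpy": "1.26.4"}) works the same way.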

Install PyTorch for CUDA 12.1

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
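To confirm pip pulled the CUDA 12.1 build rather than a CPU-only wheel, and that the GPU is visible, something like this works. A minimal sketch that only uses standard torch attributes and returns None if torch isn’t installed; torch_cuda_report is a hypothetical name.

```python
import importlib.util


def torch_cuda_report():
    """Summarise the installed PyTorch build, or return None if torch is missing."""
    if importlib.util.find_spec("torch") is None:
        return None
    import torch  # imported lazily so the check also runs before torch is installed

    report = {
        "torch": torch.__version__,  # wheels from the cu121 index carry a "+cu121" suffix
        "built_for_cuda": torch.version.cuda,  # None on CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }
    if report["cuda_available"]:
        report["device"] = torch.cuda.get_device_name(0)
    return report


if __name__ == "__main__":
    print(torch_cuda_report() or "PyTorch is not installed in this environment")
```

If the version shown isn’t 2.3.1, the torch2.3.1 Flash Attention wheel used in the next step probably won’t match it.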

Install Flash Attention 2

https://github.com/bdashore3/flash-attention/releases

pip install flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl
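The releases page has many builds, and the filename tells you which one fits: cp310 and win_amd64 are the standard wheel Python and platform tags, and the local version label after the + encodes the CUDA version (cu122), the PyTorch version (torch2.3.1), and the C++ ABI flag. A small parser, assuming the other releases follow the same naming pattern; parse_wheel_name and matches_interpreter are hypothetical names.

```python
import re
import sys


def parse_wheel_name(filename):
    """Split a wheel filename into its standard fields plus the CUDA/torch local tags."""
    # Layout: name-version[-build]-python-abi-platform.whl
    parts = filename[: -len(".whl")].split("-")
    name, version_str = parts[0], parts[1]
    python_tag, abi_tag, platform_tag = parts[-3:]
    public, _, local = version_str.partition("+")
    cuda = re.search(r"cu(\d+)", local)
    torch_ver = re.search(r"torch(\d+(?:\.\d+)*)", local)
    return {
        "name": name,
        "version": public,
        "cuda": cuda.group(1) if cuda else None,
        "torch": torch_ver.group(1) if torch_ver else None,
        "python_tag": python_tag,
        "abi_tag": abi_tag,
        "platform_tag": platform_tag,
    }


def matches_interpreter(info):
    """True if the wheel's CPython tag matches the running interpreter."""
    return info["python_tag"] == f"cp{sys.version_info.major}{sys.version_info.minor}"


if __name__ == "__main__":
    info = parse_wheel_name(
        "flash_attn-2.5.9.post1+cu122torch2.3.1cxx11abiFALSE-cp310-cp310-win_amd64.whl"
    )
    print(info, matches_interpreter(info))
```

Pick the file whose torch value matches your installed PyTorch and whose python_tag matches your interpreter.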

Install LLaMA-Factory and Dependencies

cd\LLaMA-Factory
pip install -e ".[gptq,awq,metrics]"

Install Unsloth

https://github.com/unslothai/unsloth

pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

Log In to Hugging Face and Set an Access Token

pip install --upgrade huggingface_hub
huggingface-cli login

Downgrade NumPy

pip install -U numpy==1.26.4

Install Windows Triton 2.1.0

https://huggingface.co/madbuda/triton-windows-builds/tree/main

pip install triton-2.1.0-cp310-cp310-win_amd64.whl
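With everything installed, a quick import pass catches most Windows-specific breakage (missing DLLs, wheels built for the wrong Python or PyTorch) before a long training run does. A minimal sketch; the module names are the import names I’d expect for the packages above, so adjust the list if one differs in your setup. try_imports is a hypothetical name.

```python
import importlib

# Import names for the packages installed above; edit to suit your setup.
MODULES = ["torch", "xformers", "flash_attn", "triton", "bitsandbytes", "unsloth"]


def try_imports(modules):
    """Import each module and map it to "ok" or the error it raised."""
    results = {}
    for name in modules:
        try:
            importlib.import_module(name)
            results[name] = "ok"
        except Exception as exc:  # DLL problems usually surface as ImportError or OSError
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    for name, status in try_imports(MODULES).items():
        print(f"{name:12} {status}")
```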

Example Training Command with Flash Attention 2 and Unsloth

Run this from the LLaMA-Factory folder in Command Prompt; the ^ at the end of each line continues the command onto the next.

python ./src/train.py --stage pt --do_train --seed 1337 ^
  --flash_attn fa2 --upcast_layernorm True ^
  --model_name_or_path unsloth/mistral-7b-instruct-v0.3-bnb-4bit --quantization_bit 4 ^
  --dataset c4_demo --template mistral --dataset_dir ./data ^
  --finetuning_type lora ^
  --lora_target q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,embed_tokens,lm_head ^
  --output_dir saves/c4_demo/mistral-7b-instruct-v0.3-bnb-4bit/lora/pretrain/4-bit ^
  --cutoff_len 2048 --preprocessing_num_workers 8 ^
  --per_device_train_batch_size 2 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 ^
  --lr_scheduler_type cosine --logging_steps 1 --warmup_steps 10 ^
  --save_steps 100 --eval_steps 100 --evaluation_strategy steps --load_best_model_at_end ^
  --learning_rate 0.0001 --ddp_timeout 180000000 --num_train_epochs 3.0 --val_size 0.001 ^
  --plot_loss --fp16 --use_unsloth
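A command this long is easier to tweak as data. This sketch keeps every setting from the training command in a dict and turns it into an argument list, writing True values as bare flags (the way --do_train and --use_unsloth are passed) and keeping upcast_layernorm as the literal string True so the arguments match the command exactly. OPTIONS and build_argv are hypothetical names, and it assumes you run it from the LLaMA-Factory folder with unsloth_env active.

```python
import subprocess
import sys

# Same settings as the training command. True = bare flag, False/None = omitted.
OPTIONS = {
    "stage": "pt",
    "do_train": True,
    "seed": 1337,
    "flash_attn": "fa2",
    "upcast_layernorm": "True",
    "model_name_or_path": "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "quantization_bit": 4,
    "dataset": "c4_demo",
    "template": "mistral",
    "dataset_dir": "./data",
    "finetuning_type": "lora",
    "lora_target": "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,embed_tokens,lm_head",
    "output_dir": "saves/c4_demo/mistral-7b-instruct-v0.3-bnb-4bit/lora/pretrain/4-bit",
    "cutoff_len": 2048,
    "preprocessing_num_workers": 8,
    "per_device_train_batch_size": 2,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "lr_scheduler_type": "cosine",
    "logging_steps": 1,
    "warmup_steps": 10,
    "save_steps": 100,
    "eval_steps": 100,
    "evaluation_strategy": "steps",
    "load_best_model_at_end": True,
    "learning_rate": 0.0001,
    "ddp_timeout": 180000000,
    "num_train_epochs": 3.0,
    "val_size": 0.001,
    "plot_loss": True,
    "fp16": True,
    "use_unsloth": True,
}


def build_argv(options):
    """Turn an options dict into argv: True -> bare --flag, False/None -> skipped."""
    argv = []
    for key, value in options.items():
        if value is True:
            argv.append(f"--{key}")
        elif value is False or value is None:
            continue
        else:
            argv += [f"--{key}", str(value)]
    return argv


if __name__ == "__main__":
    # Uses the interpreter running this script, i.e. the one in unsloth_env.
    subprocess.run([sys.executable, "src/train.py", *build_argv(OPTIONS)], check=True)
```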

Text Files to Alpaca Format JSON Script

import json
import os

import chardet
import nltk
from nltk.tokenize import sent_tokenize
from tqdm import tqdm

# sent_tokenize needs NLTK's Punkt sentence model; newer NLTK releases load it as "punkt_tab".
for resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{resource}")
    except LookupError:
        nltk.download(resource, quiet=True)


def read_chunks(filepath, chunk_size=2000, overlap=48):
    """
    Reads a text file, chunks it by sentences, and returns a list of chunks.
    Prepends the filename and the file's first line to each chunk.

    Args:
        filepath: Path to the text file.
        chunk_size: Number of words (approximate) in each chunk.
        overlap: Number of words carried over from the end of one chunk
            to the start of the next.

    Returns:
        A list of strings, where each string represents a chunk of text
        prepended with the filename and first line.
    """
    with open(filepath, 'rb') as rawdata:
        text_bytes = rawdata.read()

    # Detect the encoding (default to UTF-8) and replace any undecodable bytes.
    encoding = chardet.detect(text_bytes)['encoding'] or 'utf-8'
    try:
        text = text_bytes.decode(encoding, errors='replace')
    except LookupError:  # chardet named a codec Python doesn't know
        text = text_bytes.decode('utf-8', errors='replace')

    lines = text.splitlines()
    first_line = lines[0].strip() if lines else ""
    filename = os.path.basename(filepath)  # Extract base filename from filepath
    header = f"{filename}: {first_line}"

    chunks = []
    current_chunk = ""
    for sentence in sent_tokenize(text):
        if len(current_chunk.split()) + len(sentence.split()) <= chunk_size:
            current_chunk += " " + sentence
        else:
            # Skip empty chunks (e.g. when the first sentence is longer than chunk_size).
            if current_chunk.strip():
                chunks.append(f"{header}\n{current_chunk.strip()}")
            # Start the next chunk with the last `overlap` words of this one.
            tail = current_chunk.split()[-overlap:] if overlap > 0 else []
            current_chunk = " ".join(tail + [sentence])

    if current_chunk.strip():
        chunks.append(f"{header}\n{current_chunk.strip()}")
    return chunks


def process_directory(directory, output_file):
    """
    Processes all .txt files in a directory, chunks them, and saves them as a single JSON file.
    Each record is {"text": chunk}, with the filename and first line prepended to each chunk.

    Args:
        directory: Path to the directory containing text files.
        output_file: Path to the output JSON file.
    """
    # Case-insensitive extension check, so .TXT files are included too.
    txt_files = [f for f in os.listdir(directory) if f.lower().endswith(".txt")]
    all_data = []
    for filename in tqdm(txt_files, desc="Processing Files"):
        filepath = os.path.join(directory, filename)
        for chunk in read_chunks(filepath):
            all_data.append({"text": chunk})

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_data, f, ensure_ascii=False)


# Example usage: replace the placeholders with your own paths
if __name__ == "__main__":
    directory = "C:\\INPUT\\DIRECTORY\\"
    output_file = "OUTPUT FILE .JSON"
    process_directory(directory, output_file)
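The script writes a list of {"text": ...} records, the same shape as the c4_demo file the training command uses. To train on your own file, LLaMA-Factory needs an entry for it in data/dataset_info.json. A minimal sketch that adds one, assuming the entry format matches the c4_demo entry in your checkout (a file_name plus a columns mapping of prompt to text); register_text_dataset and the dataset and file names are hypothetical.

```python
import json
from pathlib import Path


def register_text_dataset(dataset_info_path, name, file_name):
    """Add a pretraining-style entry that maps LLaMA-Factory's prompt column to "text"."""
    path = Path(dataset_info_path)
    info = json.loads(path.read_text(encoding="utf-8")) if path.exists() else {}
    # Assumed to mirror the c4_demo entry; check it against your dataset_info.json.
    info[name] = {"file_name": file_name, "columns": {"prompt": "text"}}
    path.write_text(json.dumps(info, indent=2, ensure_ascii=False), encoding="utf-8")
    return info[name]


if __name__ == "__main__":
    # Hypothetical names: copy your JSON output into LLaMA-Factory/data first.
    register_text_dataset("data/dataset_info.json", "my_text_dataset", "my_text_dataset.json")
```

After registering it, pass --dataset my_text_dataset in the training command instead of --dataset c4_demo.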
