DiffRhythm – Fast, Full-Length Song Generation

A look at DiffRhythm, a diffusion model for generative music. This one is FAST. I look over the demos, share some installation notes, run a few test generations to see what works and what doesn't, and try to make a decent-sounding tune. This is not a detailed tutorial.

DiffRhythm Hugging Face demo: https://huggingface.co/spaces/ASLP-lab/DiffRhythm

DiffRhythm demo page: https://aslp-lab.github.io/DiffRhythm.github.io/

DiffRhythm GitHub: https://github.com/ASLP-lab/DiffRhythm

I modified the inference scripts to add timestamps to the output filenames, write the metadata to a matching .txt file, run multiple generations in one go, and (hopefully) load the correct model for the requested generation length.

Please note: I am not affiliated with this project. Do not direct questions about my terribly written modifications to the DiffRhythm team; they have nothing to do with them.

Modified infer.py

# Copyright (c) 2025 ASLP-LAB
#               2025 Huakang Chen  (huakang@mail.nwpu.edu.cn)
#               2025 Guobin Ma     (guobin.ma@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import time
import datetime

import torch
import torchaudio
from einops import rearrange

print("Current working directory:", os.getcwd())

from infer_utils import (
    decode_audio,
    get_lrc_token,
    get_negative_style_prompt,
    get_reference_latent,
    get_style_prompt,
    prepare_model,
)


def inference(
    cfm_model,
    vae_model,
    cond,
    text,
    duration,
    style_prompt,
    negative_style_prompt,
    start_time,
    chunked=False,
):
    with torch.inference_mode():
        generated, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=32,
            cfg_strength=4.0,
            start_time=start_time,
        )

        generated = generated.to(torch.float32)
        latent = generated.transpose(1, 2)  # [b d t]

        output = decode_audio(latent, vae_model, chunked=chunked)

        # Rearrange audio batch to a single sequence
        output = rearrange(output, "b d n -> d (b n)")
        # Peak normalize, clip, convert to int16, and save to file
        output = (
            output.to(torch.float32)
            .div(torch.max(torch.abs(output)))
            .clamp(-1, 1)
            .mul(32767)
            .to(torch.int16)
            .cpu()
        )

        return output


def generate_unique_filename():
    """Generate a unique filename with date and timestamp."""
    now = datetime.datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
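    # e.g. "output_20250304_153012"; the main loop appends "_attemptN" when --num-attempts > 1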
    return f"output_{timestamp}"


def save_lyrics_metadata(lyrics_content, output_path):
    """Save lyrics to a text file with the same base filename."""
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(lyrics_content)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lrc-path",
        type=str,
        help="lyrics of target song",
    )  # lyrics of target song
    parser.add_argument(
        "--ref-prompt",
        type=str,
        help="reference prompt as style prompt for target song",
        required=False,
    )  # reference prompt as style prompt for target song
    parser.add_argument(
        "--ref-audio-path",
        type=str,
        help="reference audio as style prompt for target song",
        required=False,
    )  # reference audio as style prompt for target song
    parser.add_argument(
        "--chunked",
        action="store_true",
        help="whether to use chunked decoding",
    )  # whether to use chunked decoding
    parser.add_argument(
        "--audio-length",
        type=int,
        default=95,
        choices=[95, 285],
        help="length of generated song",
    )  # length of target song
    parser.add_argument(
        "--repo_id", type=str, default="ASLP-lab/DiffRhythm-base", help="target model"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="infer/example/output",
        help="output directory fo generated song",
    )  # output directory of target song
    parser.add_argument(
        "--num-attempts",
        type=int,
        default=1,
        help="number of generation attempts to run",
    )  # number of generation attempts
    args = parser.parse_args()

    assert (
        args.ref_prompt or args.ref_audio_path
    ), "either ref_prompt or ref_audio_path should be provided"
    assert not (
        args.ref_prompt and args.ref_audio_path
    ), "only one of them should be provided"

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

    audio_length = args.audio_length
    if audio_length == 95:
        max_frames = 2048
    elif audio_length == 285:  # handled by switching to the full model in prepare_model
        max_frames = 6144
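    # 44100 Hz audio with a 2048x VAE downsampling ratio gives ~21.5 latent frames
    # per second, so 2048 frames ≈ 95 s and 6144 frames ≈ 285 s.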

    cfm, tokenizer, muq, vae = prepare_model(max_frames, device, repo_id=args.repo_id)

    if args.lrc_path:
        with open(args.lrc_path, "r", encoding='utf-8') as f:
            lrc = f.read()
    else:
        lrc = ""
    
    # Create output directory if it doesn't exist
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    
    # Get additional metadata for saving
    ref_source = args.ref_prompt if args.ref_prompt else args.ref_audio_path
    metadata = (
        f"Lyrics source: {args.lrc_path}\n"
        f"Reference source: {ref_source}\n"
        f"Audio length: {args.audio_length} seconds\n"
        f"Model: {args.repo_id}\n"
        f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
        f"LYRICS:\n{lrc}"
    )
    
    # Run the generation N times based on num-attempts
    for attempt in range(1, args.num_attempts + 1):
        print(f"Starting generation attempt {attempt} of {args.num_attempts}")
        
        # Get necessary components for generation
        lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
        
        if args.ref_audio_path:
            style_prompt = get_style_prompt(muq, args.ref_audio_path)
        else:
            style_prompt = get_style_prompt(muq, prompt=args.ref_prompt)
        
        negative_style_prompt = get_negative_style_prompt(device)
        latent_prompt = get_reference_latent(device, max_frames)
        
        # Generate unique filename for this attempt
        unique_filename = generate_unique_filename()
        if args.num_attempts > 1:
            unique_filename = f"{unique_filename}_attempt{attempt}"
        
        # Run inference
        s_t = time.time()
        generated_song = inference(
            cfm_model=cfm,
            vae_model=vae,
            cond=latent_prompt,
            text=lrc_prompt,
            duration=max_frames,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            start_time=start_time,
            chunked=args.chunked,
        )
        e_t = time.time() - s_t
        print(f"Inference attempt {attempt} cost {e_t:.2f} seconds")
        
        # Save audio file with unique filename
        audio_output_path = os.path.join(output_dir, f"{unique_filename}.wav")
        torchaudio.save(audio_output_path, generated_song, sample_rate=44100)
        print(f"Saved audio to: {audio_output_path}")
        
        # Save lyrics and metadata to corresponding text file
        lyrics_output_path = os.path.join(output_dir, f"{unique_filename}.txt")
        attempt_metadata = metadata + f"\n\nAttempt: {attempt} of {args.num_attempts}\nGeneration time: {e_t:.2f} seconds"
        save_lyrics_metadata(attempt_metadata, lyrics_output_path)
        print(f"Saved lyrics and metadata to: {lyrics_output_path}")
    
    print(f"All {args.num_attempts} generation attempts completed successfully.")

Modified infer_utils.py

# Copyright (c) 2025 ASLP-LAB
#               2025 Huakang Chen  (huakang@mail.nwpu.edu.cn)
#               2025 Guobin Ma     (guobin.ma@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import librosa
import random
import json
from muq import MuQMuLan
from mutagen.mp3 import MP3
import os
import numpy as np
from huggingface_hub import hf_hub_download

from sys import path
path.append(os.getcwd())

from model import DiT, CFM


def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
    downsampling_ratio = 2048
    io_channels = 2
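    # With the defaults (chunk_size=128, overlap=32), each latent chunk decodes to
    # 128 * 2048 = 262,144 samples (~5.9 s at 44.1 kHz); neighbouring chunks overlap
    # by 32 latents, and half of that overlap is trimmed from each side when pasting.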
    if not chunked:
        return vae_model.decode_export(latents)
    else:
        # chunked decoding
        hop_size = chunk_size - overlap
        total_size = latents.shape[2]
        batch_size = latents.shape[0]
        chunks = []
        i = 0
        for i in range(0, total_size - chunk_size + 1, hop_size):
            chunk = latents[:, :, i : i + chunk_size]
            chunks.append(chunk)
        if i + chunk_size != total_size:
            # Final chunk
            chunk = latents[:, :, -chunk_size:]
            chunks.append(chunk)
        chunks = torch.stack(chunks)
        num_chunks = chunks.shape[0]
        # samples_per_latent is just the downsampling ratio
        samples_per_latent = downsampling_ratio
        # Create an empty waveform; we will populate it with chunks as we decode them
        y_size = total_size * samples_per_latent
        y_final = torch.zeros((batch_size, io_channels, y_size)).to(latents.device)
        for i in range(num_chunks):
            x_chunk = chunks[i, :]
            # decode the chunk
            y_chunk = vae_model.decode_export(x_chunk)
            # figure out where to put the audio along the time domain
            if i == num_chunks - 1:
                # final chunk always goes at the end
                t_end = y_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size * samples_per_latent
                t_end = t_start + chunk_size * samples_per_latent
            #  remove the edges of the overlaps
            ol = (overlap // 2) * samples_per_latent
            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if i > 0:
                # no overlap for the start of the first chunk
                t_start += ol
                chunk_start += ol
            if i < num_chunks - 1:
                # no overlap for the end of the last chunk
                t_end -= ol
                chunk_end -= ol
            # paste the chunked audio into our y_final output audio
            y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
        return y_final

def prepare_model(max_frames, device, repo_id="ASLP-lab/DiffRhythm-base"):
    """
    Prepare the model based on the repository ID and max frames.
    
    Args:
        max_frames: Maximum number of frames (2048 for 95s, 6144 for 285s audio)
        device: Device to load the model on ('cpu', 'cuda', 'mps')
        repo_id: HuggingFace repository ID for the model
        
    Returns:
        A tuple of (cfm_model, tokenizer, muq, vae)
    """
    # Determine appropriate repository ID based on max_frames
    model_repo_id = repo_id
    if max_frames > 2048 and repo_id == "ASLP-lab/DiffRhythm-base":
        # If using longer audio with base model repo_id, switch to full model
        model_repo_id = "ASLP-lab/DiffRhythm-full"
        print(f"Using full model repository ({model_repo_id}) for {max_frames} frames")
    
    # Prepare CFM model
    dit_ckpt_path = hf_hub_download(
        repo_id=model_repo_id, filename="cfm_model.pt", cache_dir="./pretrained"
    )
    dit_config_path = "./config/diffrhythm-1b.json"
    with open(dit_config_path) as f:
        model_config = json.load(f)
    dit_model_cls = DiT
    
    # Initialize the model using the original structure
    cfm = CFM(
        transformer=dit_model_cls(**model_config["model"], max_frames=max_frames),
        num_channels=model_config["model"]["mel_dim"],
        max_frames=max_frames
    )
    cfm = cfm.to(device)
    cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)

    # Prepare tokenizer
    tokenizer = CNENTokenizer()

    # Prepare MuQ
    muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./pretrained")
    muq = muq.to(device).eval()

    # Prepare VAE
    vae_ckpt_path = hf_hub_download(
        repo_id="ASLP-lab/DiffRhythm-vae",
        filename="vae_model.pt",
        cache_dir="./pretrained",
    )
    vae = torch.jit.load(vae_ckpt_path, map_location="cpu").to(device)

    return cfm, tokenizer, muq, vae

# for song editing; to be added in the future
def get_reference_latent(device, max_frames):
    return torch.zeros(1, max_frames, 64).to(device)

def get_negative_style_prompt(device):
    file_path = "./infer/example/negative_prompt.npy"
    vocal_style = np.load(file_path)

    vocal_style = torch.from_numpy(vocal_style).to(device)  # [1, 512]
    vocal_style = vocal_style.half()

    return vocal_style

def get_audio_style_prompt(model, wav_path):
    vocal_flag = False
    mulan = model
    audio, _ = librosa.load(wav_path, sr=24000)
    audio_len = librosa.get_duration(y=audio, sr=24000)
    
    if audio_len <= 1:
        vocal_flag = True
    
    if audio_len > 10:
        start_time = int(audio_len // 2 - 5)
        wav = audio[start_time*24000:(start_time+10)*24000]
    
    else:
        wav = audio
    wav = torch.tensor(wav).unsqueeze(0).to(model.device)
    
    with torch.no_grad():
        audio_emb = mulan(wavs = wav) # [1, 512]
        
    audio_emb = audio_emb.half()

    return audio_emb, vocal_flag


@torch.no_grad()
def get_style_prompt(model, wav_path=None, prompt=None):
    mulan = model

    if prompt is not None:
        return mulan(texts=prompt).half()

    ext = os.path.splitext(wav_path)[-1].lower()
    if ext == ".mp3":
        meta = MP3(wav_path)
        audio_len = meta.info.length
    elif ext in [".wav", ".flac"]:
        audio_len = librosa.get_duration(path=wav_path)
    else:
        raise ValueError("Unsupported file format: {}".format(ext))

    if audio_len < 10:
        print(
            f"Warning: The audio file {wav_path} is too short ({audio_len:.2f} seconds). Expected at least 10 seconds."
        )

    assert audio_len >= 10

    mid_time = audio_len // 2
    start_time = mid_time - 5
    wav, _ = librosa.load(wav_path, sr=24000, offset=start_time, duration=10)

    wav = torch.tensor(wav).unsqueeze(0).to(model.device)

    with torch.no_grad():
        audio_emb = mulan(wavs=wav)  # [1, 512]

    audio_emb = audio_emb.half()

    return audio_emb

def get_text_style_prompt(model, text_prompt):
    mulan = model
    
    with torch.no_grad():
        text_emb = mulan(texts = text_prompt) # [1, 512]
    text_emb = text_emb.half()

    return text_emb



def parse_lyrics(lyrics: str):
    lyrics_with_time = []
    lyrics = lyrics.strip()
    for line in lyrics.split('\n'):
        try:
            time, lyric = line[1:9], line[10:]
            lyric = lyric.strip()
            mins, secs = time.split(':')
            secs = int(mins) * 60 + float(secs)
            lyrics_with_time.append((secs, lyric))
        except:
            continue
    return lyrics_with_time
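
# Example (hypothetical lyrics): parse_lyrics expects standard LRC lines of the form
# "[mm:ss.xx]lyric"; line[1:9] grabs the timestamp and line[10:] the text, so
#   parse_lyrics("[00:10.00]First line\n[00:14.50]Second line")
#   -> [(10.0, "First line"), (14.5, "Second line")]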

class CNENTokenizer():
    def __init__(self):
        with open("./g2p/g2p/vocab.json", "r", encoding="utf-8") as file:
            self.phone2id:dict = json.load(file)['vocab']
        self.id2phone = {v:k for (k, v) in self.phone2id.items()}
        from g2p.g2p_generation import chn_eng_g2p
        self.tokenizer = chn_eng_g2p
    def encode(self, text):
        phone, token = self.tokenizer(text)
        token = [x+1 for x in token]
        return token
    def decode(self, token):
        return "|".join([self.id2phone[x-1] for x in token])
    
def get_lrc_token(max_frames, text, tokenizer, device):

    lyrics_shift = 0
    sampling_rate = 44100
    downsample_rate = 2048
    max_secs = max_frames / (sampling_rate / downsample_rate)
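    # e.g. max_frames = 2048 -> 2048 / (44100 / 2048) ≈ 95.1 s; lyric lines timed
    # past max_secs are dropped below.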
   
    pad_token_id = 0
    comma_token_id = 1
    period_token_id = 2    
    if text == "":
        return torch.zeros((max_frames,), dtype=torch.long).unsqueeze(0).to(device), torch.tensor(0.).unsqueeze(0).to(device).half()

    lrc_with_time = parse_lyrics(text)
    
    modified_lrc_with_time = []
    for i in range(len(lrc_with_time)):
        time, line = lrc_with_time[i]
        line_token = tokenizer.encode(line)
        modified_lrc_with_time.append((time, line_token))
    lrc_with_time = modified_lrc_with_time

    lrc_with_time = [(time_start, line) for (time_start, line) in lrc_with_time if time_start < max_secs]
    # lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
    
    normalized_start_time = 0.

    lrc = torch.zeros((max_frames,), dtype=torch.long)

    tokens_count = 0
    last_end_pos = 0
    for time_start, line in lrc_with_time:
        tokens = [token if token != period_token_id else comma_token_id for token in line] + [period_token_id]
        tokens = torch.tensor(tokens, dtype=torch.long)
        num_tokens = tokens.shape[0]

        gt_frame_start = int(time_start * sampling_rate / downsample_rate)
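        # e.g. a line timed at 10.0 s lands at frame int(10.0 * 44100 / 2048) = 215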
        
        frame_shift = random.randint(int(lyrics_shift), int(lyrics_shift))
        
        frame_start = max(gt_frame_start - frame_shift, last_end_pos)
        frame_len = min(num_tokens, max_frames - frame_start)

        lrc[frame_start:frame_start + frame_len] = tokens[:frame_len]

        tokens_count += num_tokens
        last_end_pos = frame_start + frame_len   
        
    lrc_emb = lrc.unsqueeze(0).to(device)
    
    normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
    normalized_start_time = normalized_start_time.half()
    
    return lrc_emb, normalized_start_time

def load_checkpoint(model, ckpt_path, device, use_ema=True):
    if device == "cuda":
        model = model.half()

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path)
    else:
        checkpoint = torch.load(ckpt_path, weights_only=True)

    if use_ema:
        if ckpt_type == "safetensors":
            checkpoint = {"ema_model_state_dict": checkpoint}
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    else:
        if ckpt_type == "safetensors":
            checkpoint = {"model_state_dict": checkpoint}
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)

    return model.to(device)

3 thoughts on “DiffRhythm – Fast, Full-Length Song Generation”

  1. It sounds like these modifications could really streamline the process of generating outputs with timestamps and metadata. Adding multiple generations seems like a useful feature for testing different scenarios. Ensuring the correct model is loaded for specific lengths is crucial for accuracy. However, clarifying whether these changes have been thoroughly tested would be helpful. Have you considered documenting these updates for others who might use the modified scripts?

  2. The modifications to the inference scripts seem quite useful, especially with the addition of timestamps and meta info. Handling multiple generations and ensuring the correct model for the corresponding length is a great improvement. It’s good to clarify your non-affiliation with the project to avoid confusion. Could you elaborate on how the timestamps are being utilized in this context?
