A look at DiffRhythm, a diffusion model for generative music. This one is FAST. I’m looking over the demos, sharing some installation notes, trying some demo generations, seeing what works and what doesn’t, and trying to make a decent-sounding tune. This is not a detailed tutorial.
DiffRhythm Huggingface demo – https://huggingface.co/spaces/ASLP-lab/DiffRhythm
DiffRhythm demo page – https://aslp-lab.github.io/DiffRhythm.github.io/
DiffRhythm GitHub – https://github.com/ASLP-lab/DiffRhythm
I modified the inference scripts to add timestamps to the output filenames, write the metadata and lyrics to a matching .txt file, run multiple generations in one pass, and (hopefully) load the correct model for the requested generation length.
Please note: I am not affiliated with this project. Do not direct questions about my terribly written modifications to them. They have nothing to do with this.
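For reference, here is roughly how I run the modified script from the repo root. The lyrics path and style prompt below are placeholders from my own setup, not anything shipped with the official repo, so adjust them to taste:

python infer/infer.py --lrc-path path/to/my_lyrics.lrc --ref-prompt "indie rock, male vocals, energetic" --audio-length 95 --num-attempts 3 --chunked --output-dir infer/example/output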
Modified infer.py
# Copyright (c) 2025 ASLP-LAB
# 2025 Huakang Chen (huakang@mail.nwpu.edu.cn)
# 2025 Guobin Ma (guobin.ma@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import time
import datetime

import torch
import torchaudio
from einops import rearrange

print("Current working directory:", os.getcwd())

from infer_utils import (
    decode_audio,
    get_lrc_token,
    get_negative_style_prompt,
    get_reference_latent,
    get_style_prompt,
    prepare_model,
)


def inference(
    cfm_model,
    vae_model,
    cond,
    text,
    duration,
    style_prompt,
    negative_style_prompt,
    start_time,
    chunked=False,
):
    with torch.inference_mode():
        generated, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=32,
            cfg_strength=4.0,
            start_time=start_time,
        )

        generated = generated.to(torch.float32)
        latent = generated.transpose(1, 2)  # [b d t]

        output = decode_audio(latent, vae_model, chunked=chunked)

        # Rearrange audio batch to a single sequence
        output = rearrange(output, "b d n -> d (b n)")
        # Peak normalize, clip, convert to int16, and save to file
        output = (
            output.to(torch.float32)
            .div(torch.max(torch.abs(output)))
            .clamp(-1, 1)
            .mul(32767)
            .to(torch.int16)
            .cpu()
        )

        return output


def generate_unique_filename():
    """Generate a unique filename with date and timestamp."""
    now = datetime.datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    return f"output_{timestamp}"


def save_lyrics_metadata(lyrics_content, output_path):
    """Save lyrics to a text file with the same base filename."""
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(lyrics_content)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lrc-path",
        type=str,
        help="lyrics of target song",
    )  # lyrics of target song
    parser.add_argument(
        "--ref-prompt",
        type=str,
        help="reference prompt as style prompt for target song",
        required=False,
    )  # reference prompt as style prompt for target song
    parser.add_argument(
        "--ref-audio-path",
        type=str,
        help="reference audio as style prompt for target song",
        required=False,
    )  # reference audio as style prompt for target song
    parser.add_argument(
        "--chunked",
        action="store_true",
        help="whether to use chunked decoding",
    )  # whether to use chunked decoding
    parser.add_argument(
        "--audio-length",
        type=int,
        default=95,
        choices=[95, 285],
        help="length of generated song",
    )  # length of target song
    parser.add_argument(
        "--repo_id", type=str, default="ASLP-lab/DiffRhythm-base", help="target model"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="infer/example/output",
        help="output directory of generated song",
    )  # output directory of target song
    parser.add_argument(
        "--num-attempts",
        type=int,
        default=1,
        help="number of generation attempts to run",
    )  # number of generation attempts
    args = parser.parse_args()

    assert (
        args.ref_prompt or args.ref_audio_path
    ), "either ref_prompt or ref_audio_path should be provided"
    assert not (
        args.ref_prompt and args.ref_audio_path
    ), "only one of them should be provided"

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.mps.is_available():
        device = "mps"

    audio_length = args.audio_length
    if audio_length == 95:
        max_frames = 2048
    elif audio_length == 285:  # currently not available
        max_frames = 6144

    cfm, tokenizer, muq, vae = prepare_model(max_frames, device, repo_id=args.repo_id)

    if args.lrc_path:
        with open(args.lrc_path, "r", encoding="utf-8") as f:
            lrc = f.read()
    else:
        lrc = ""

    # Create output directory if it doesn't exist
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Get additional metadata for saving
    ref_source = args.ref_prompt if args.ref_prompt else args.ref_audio_path
    metadata = (
        f"Lyrics source: {args.lrc_path}\n"
        f"Reference source: {ref_source}\n"
        f"Audio length: {args.audio_length} seconds\n"
        f"Model: {args.repo_id}\n"
        f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
        f"LYRICS:\n{lrc}"
    )

    # Run the generation N times based on num-attempts
    for attempt in range(1, args.num_attempts + 1):
        print(f"Starting generation attempt {attempt} of {args.num_attempts}")

        # Get necessary components for generation
        lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
        if args.ref_audio_path:
            style_prompt = get_style_prompt(muq, args.ref_audio_path)
        else:
            style_prompt = get_style_prompt(muq, prompt=args.ref_prompt)
        negative_style_prompt = get_negative_style_prompt(device)
        latent_prompt = get_reference_latent(device, max_frames)

        # Generate unique filename for this attempt
        unique_filename = generate_unique_filename()
        if args.num_attempts > 1:
            unique_filename = f"{unique_filename}_attempt{attempt}"

        # Run inference
        s_t = time.time()
        generated_song = inference(
            cfm_model=cfm,
            vae_model=vae,
            cond=latent_prompt,
            text=lrc_prompt,
            duration=max_frames,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            start_time=start_time,
            chunked=args.chunked,
        )
        e_t = time.time() - s_t
        print(f"Inference attempt {attempt} cost {e_t:.2f} seconds")

        # Save audio file with unique filename
        audio_output_path = os.path.join(output_dir, f"{unique_filename}.wav")
        torchaudio.save(audio_output_path, generated_song, sample_rate=44100)
        print(f"Saved audio to: {audio_output_path}")

        # Save lyrics and metadata to corresponding text file
        lyrics_output_path = os.path.join(output_dir, f"{unique_filename}.txt")
        attempt_metadata = (
            metadata
            + f"\n\nAttempt: {attempt} of {args.num_attempts}\nGeneration time: {e_t:.2f} seconds"
        )
        save_lyrics_metadata(attempt_metadata, lyrics_output_path)
        print(f"Saved lyrics and metadata to: {lyrics_output_path}")

    print(f"All {args.num_attempts} generation attempts completed successfully.")
Modified infer_utils.py
# Copyright (c) 2025 ASLP-LAB
# 2025 Huakang Chen (huakang@mail.nwpu.edu.cn)
# 2025 Guobin Ma (guobin.ma@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import librosa
import random
import json
from muq import MuQMuLan
from mutagen.mp3 import MP3
import os
import numpy as np
from huggingface_hub import hf_hub_download

from sys import path

path.append(os.getcwd())

from model import DiT, CFM


def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
    downsampling_ratio = 2048
    io_channels = 2
    if not chunked:
        return vae_model.decode_export(latents)
    else:
        # chunked decoding
        hop_size = chunk_size - overlap
        total_size = latents.shape[2]
        batch_size = latents.shape[0]
        chunks = []
        i = 0
        for i in range(0, total_size - chunk_size + 1, hop_size):
            chunk = latents[:, :, i : i + chunk_size]
            chunks.append(chunk)
        if i + chunk_size != total_size:
            # Final chunk
            chunk = latents[:, :, -chunk_size:]
            chunks.append(chunk)
        chunks = torch.stack(chunks)
        num_chunks = chunks.shape[0]
        # samples_per_latent is just the downsampling ratio
        samples_per_latent = downsampling_ratio
        # Create an empty waveform, we will populate it with chunks as we decode them
        y_size = total_size * samples_per_latent
        y_final = torch.zeros((batch_size, io_channels, y_size)).to(latents.device)
        for i in range(num_chunks):
            x_chunk = chunks[i, :]
            # decode the chunk
            y_chunk = vae_model.decode_export(x_chunk)
            # figure out where to put the audio along the time domain
            if i == num_chunks - 1:
                # final chunk always goes at the end
                t_end = y_size
                t_start = t_end - y_chunk.shape[2]
            else:
                t_start = i * hop_size * samples_per_latent
                t_end = t_start + chunk_size * samples_per_latent
            # remove the edges of the overlaps
            ol = (overlap // 2) * samples_per_latent
            chunk_start = 0
            chunk_end = y_chunk.shape[2]
            if i > 0:
                # no overlap for the start of the first chunk
                t_start += ol
                chunk_start += ol
            if i < num_chunks - 1:
                # no overlap for the end of the last chunk
                t_end -= ol
                chunk_end -= ol
            # paste the chunked audio into our y_final output audio
            y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
        return y_final


def prepare_model(max_frames, device, repo_id="ASLP-lab/DiffRhythm-base"):
    """
    Prepare the model based on the repository ID and max frames.

    Args:
        max_frames: Maximum number of frames (2048 for 95s, 6144 for 285s audio)
        device: Device to load the model on ('cpu', 'cuda', 'mps')
        repo_id: HuggingFace repository ID for the model

    Returns:
        A tuple of (cfm_model, tokenizer, muq, vae)
    """
    # Determine appropriate repository ID based on max_frames
    model_repo_id = repo_id
    if max_frames > 2048 and repo_id == "ASLP-lab/DiffRhythm-base":
        # If using longer audio with base model repo_id, switch to full model
        model_repo_id = "ASLP-lab/DiffRhythm-full"
        print(f"Using full model repository ({model_repo_id}) for {max_frames} frames")

    # Prepare CFM model
    dit_ckpt_path = hf_hub_download(
        repo_id=model_repo_id, filename="cfm_model.pt", cache_dir="./pretrained"
    )
    dit_config_path = "./config/diffrhythm-1b.json"
    with open(dit_config_path) as f:
        model_config = json.load(f)
    dit_model_cls = DiT
    # Initialize the model using the original structure
    cfm = CFM(
        transformer=dit_model_cls(**model_config["model"], max_frames=max_frames),
        num_channels=model_config["model"]["mel_dim"],
        max_frames=max_frames,
    )
    cfm = cfm.to(device)
    cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)

    # Prepare tokenizer
    tokenizer = CNENTokenizer()

    # Prepare MuQ
    muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./pretrained")
    muq = muq.to(device).eval()

    # Prepare VAE
    vae_ckpt_path = hf_hub_download(
        repo_id="ASLP-lab/DiffRhythm-vae",
        filename="vae_model.pt",
        cache_dir="./pretrained",
    )
    vae = torch.jit.load(vae_ckpt_path, map_location="cpu").to(device)

    return cfm, tokenizer, muq, vae


# for song edit, will be added in the future
def get_reference_latent(device, max_frames):
    return torch.zeros(1, max_frames, 64).to(device)


def get_negative_style_prompt(device):
    file_path = "./infer/example/negative_prompt.npy"
    vocal_style = np.load(file_path)

    vocal_style = torch.from_numpy(vocal_style).to(device)  # [1, 512]
    vocal_style = vocal_style.half()

    return vocal_style


def get_audio_style_prompt(model, wav_path):
    vocal_flag = False
    mulan = model
    audio, _ = librosa.load(wav_path, sr=24000)
    audio_len = librosa.get_duration(y=audio, sr=24000)

    if audio_len <= 1:
        vocal_flag = True

    if audio_len > 10:
        start_time = int(audio_len // 2 - 5)
        wav = audio[start_time * 24000 : (start_time + 10) * 24000]
    else:
        wav = audio
    wav = torch.tensor(wav).unsqueeze(0).to(model.device)

    with torch.no_grad():
        audio_emb = mulan(wavs=wav)  # [1, 512]

    audio_emb = audio_emb.half()

    return audio_emb, vocal_flag


@torch.no_grad()
def get_style_prompt(model, wav_path=None, prompt=None):
    mulan = model

    if prompt is not None:
        return mulan(texts=prompt).half()

    ext = os.path.splitext(wav_path)[-1].lower()
    if ext == ".mp3":
        meta = MP3(wav_path)
        audio_len = meta.info.length
    elif ext in [".wav", ".flac"]:
        audio_len = librosa.get_duration(path=wav_path)
    else:
        raise ValueError("Unsupported file format: {}".format(ext))

    if audio_len < 10:
        print(
            f"Warning: The audio file {wav_path} is too short ({audio_len:.2f} seconds). Expected at least 10 seconds."
        )
    assert audio_len >= 10

    mid_time = audio_len // 2
    start_time = mid_time - 5
    wav, _ = librosa.load(wav_path, sr=24000, offset=start_time, duration=10)

    wav = torch.tensor(wav).unsqueeze(0).to(model.device)

    with torch.no_grad():
        audio_emb = mulan(wavs=wav)  # [1, 512]

    audio_emb = audio_emb.half()

    return audio_emb


def get_text_style_prompt(model, text_prompt):
    mulan = model

    with torch.no_grad():
        text_emb = mulan(texts=text_prompt)  # [1, 512]
    text_emb = text_emb.half()

    return text_emb


def parse_lyrics(lyrics: str):
    lyrics_with_time = []
    lyrics = lyrics.strip()
    for line in lyrics.split("\n"):
        try:
            time, lyric = line[1:9], line[10:]
            lyric = lyric.strip()
            mins, secs = time.split(":")
            secs = int(mins) * 60 + float(secs)
            lyrics_with_time.append((secs, lyric))
        except:
            continue
    return lyrics_with_time


class CNENTokenizer:
    def __init__(self):
        # with open('./g2p/g2p/vocab.json', 'r') as file:
        with open("./g2p/g2p/vocab.json", "r", encoding="utf-8") as file:
            self.phone2id: dict = json.load(file)["vocab"]
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        from g2p.g2p_generation import chn_eng_g2p

        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        phone, token = self.tokenizer(text)
        token = [x + 1 for x in token]
        return token

    def decode(self, token):
        return "|".join([self.id2phone[x - 1] for x in token])


def get_lrc_token(max_frames, text, tokenizer, device):
    lyrics_shift = 0
    sampling_rate = 44100
    downsample_rate = 2048
    max_secs = max_frames / (sampling_rate / downsample_rate)

    pad_token_id = 0
    comma_token_id = 1
    period_token_id = 2

    if text == "":
        return (
            torch.zeros((max_frames,), dtype=torch.long).unsqueeze(0).to(device),
            torch.tensor(0.0).unsqueeze(0).to(device).half(),
        )

    lrc_with_time = parse_lyrics(text)

    modified_lrc_with_time = []
    for i in range(len(lrc_with_time)):
        time, line = lrc_with_time[i]
        line_token = tokenizer.encode(line)
        modified_lrc_with_time.append((time, line_token))
    lrc_with_time = modified_lrc_with_time

    lrc_with_time = [
        (time_start, line)
        for (time_start, line) in lrc_with_time
        if time_start < max_secs
    ]
    # lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time

    normalized_start_time = 0.0

    lrc = torch.zeros((max_frames,), dtype=torch.long)
    tokens_count = 0
    last_end_pos = 0
    for time_start, line in lrc_with_time:
        tokens = [
            token if token != period_token_id else comma_token_id for token in line
        ] + [period_token_id]
        tokens = torch.tensor(tokens, dtype=torch.long)
        num_tokens = tokens.shape[0]

        gt_frame_start = int(time_start * sampling_rate / downsample_rate)
        frame_shift = random.randint(int(lyrics_shift), int(lyrics_shift))

        frame_start = max(gt_frame_start - frame_shift, last_end_pos)
        frame_len = min(num_tokens, max_frames - frame_start)

        lrc[frame_start : frame_start + frame_len] = tokens[:frame_len]

        tokens_count += num_tokens
        last_end_pos = frame_start + frame_len

    lrc_emb = lrc.unsqueeze(0).to(device)

    normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
    normalized_start_time = normalized_start_time.half()

    return lrc_emb, normalized_start_time


def load_checkpoint(model, ckpt_path, device, use_ema=True):
    if device == "cuda":
        model = model.half()

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path)
    else:
        checkpoint = torch.load(ckpt_path, weights_only=True)

    if use_ema:
        if ckpt_type == "safetensors":
            checkpoint = {"ema_model_state_dict": checkpoint}
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    else:
        if ckpt_type == "safetensors":
            checkpoint = {"model_state_dict": checkpoint}
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)

    return model.to(device)
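One note on the lyrics input: parse_lyrics() above slices each line as a [mm:ss.xx] timestamp followed by the text and silently skips anything that doesn't match, so the .lrc file needs to look roughly like this (the lyric lines here are made up):

[00:10.00]First line of the verse
[00:14.50]Second line of the verse
[00:19.25]One timestamped lyric line per row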