语音识别:使用torchaudio快速实现音频特征提取

jupiter
2022-01-08 / 0 评论 / 1,258 阅读 / 正在检测是否收录...
温馨提示:
本文最后更新于2022年01月08日,已超过839天没有更新,若内容或图片失效,请留言反馈。

1.fbank特征

import torch.nn as nn
import torchaudio
class ExtractAudioFeature(nn.Module):
    """Extract Kaldi-compatible fbank or MFCC features from an audio file.

    Args:
        feat_type: ``"fbank"`` selects ``torchaudio.compliance.kaldi.fbank``;
            any other value selects ``torchaudio.compliance.kaldi.mfcc``.
        feat_dim: Size of the feature axis of the output. It is used as the
            number of Mel bins and, for MFCC, also as the number of cepstral
            coefficients.
    """

    def __init__(self, feat_type="fbank", feat_dim=40):
        super(ExtractAudioFeature, self).__init__()
        self.feat_type = feat_type
        self.extract_fn = (
            torchaudio.compliance.kaldi.fbank
            if feat_type == "fbank"
            else torchaudio.compliance.kaldi.mfcc
        )
        self.num_mel_bins = feat_dim

    def forward(self, filepath):
        """Load ``filepath`` and return its features.

        Args:
            filepath: Path of the audio file handed to ``torchaudio.load``.

        Returns:
            A detached tensor laid out as ``(1, feat_dim, num_frames)``.
        """
        waveform, sample_rate = torchaudio.load(filepath)

        extra_kwargs = {}
        if self.extract_fn is torchaudio.compliance.kaldi.mfcc:
            # The MFCC output size is set by ``num_ceps`` (default 13), not by
            # ``num_mel_bins``. Without this, ``feat_dim`` would not control
            # the output dimension, and any ``feat_dim`` below 13 would trip
            # torchaudio's ``num_ceps <= num_mel_bins`` check.
            extra_kwargs["num_ceps"] = self.num_mel_bins

        y = self.extract_fn(waveform,
                            num_mel_bins=self.num_mel_bins,
                            channel=-1,
                            sample_frequency=sample_rate,
                            frame_length=25,  # window length of each frame, in ms
                            frame_shift=10,  # hop between consecutive frames, in ms
                            dither=0,  # no random dithering, so output is repeatable
                            **extra_kwargs)
        # Kaldi layout is (num_frames, feat_dim); swap the axes and add a
        # leading singleton dimension.
        return y.transpose(0, 1).unsqueeze(0).detach()
# Build a 40-dimensional fbank extractor and run it on a sample recording.
extracter = ExtractAudioFeature(feat_type="fbank", feat_dim=40)
wav = "./data/wav/day0914_990.wav"
wav_feature = extracter(wav)
# Layout of the result: (1, feature dimension, number of frames).
print(wav_feature.size())
torch.Size([1, 40, 489])
# 40:特征维度
# 489:音频帧数≈(音频时长-帧长25ms)/帧移10ms+1(帧数由帧移10ms决定,而不是帧长25ms)
  • 查看图示
import matplotlib.pyplot as plt

# Render the feature matrix as an image with both axes' ticks hidden.
fig = plt.figure(dpi=200)
ax = fig.gca()
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(wav_feature[0])
plt.show()

2.mfcc特征

import torch.nn as nn
import torchaudio
class ExtractAudioFeature(nn.Module):
    """Extract Kaldi-compatible fbank or MFCC features from an audio file.

    Args:
        feat_type: ``"fbank"`` selects ``torchaudio.compliance.kaldi.fbank``;
            any other value selects ``torchaudio.compliance.kaldi.mfcc``.
        feat_dim: Size of the feature axis of the output. It is used as the
            number of Mel bins and, for MFCC, also as the number of cepstral
            coefficients.
    """

    def __init__(self, feat_type="mfcc", feat_dim=13):
        super(ExtractAudioFeature, self).__init__()
        self.feat_type = feat_type
        self.extract_fn = (
            torchaudio.compliance.kaldi.fbank
            if feat_type == "fbank"
            else torchaudio.compliance.kaldi.mfcc
        )
        self.num_mel_bins = feat_dim

    def forward(self, filepath):
        """Load ``filepath`` and return its features.

        Args:
            filepath: Path of the audio file handed to ``torchaudio.load``.

        Returns:
            A detached tensor laid out as ``(1, feat_dim, num_frames)``.
        """
        waveform, sample_rate = torchaudio.load(filepath)

        extra_kwargs = {}
        if self.extract_fn is torchaudio.compliance.kaldi.mfcc:
            # The MFCC output size is set by ``num_ceps`` (default 13), not by
            # ``num_mel_bins``. Without this, ``feat_dim`` would not control
            # the output dimension, and any ``feat_dim`` below 13 would trip
            # torchaudio's ``num_ceps <= num_mel_bins`` check.
            extra_kwargs["num_ceps"] = self.num_mel_bins

        y = self.extract_fn(waveform,
                            num_mel_bins=self.num_mel_bins,
                            channel=-1,
                            sample_frequency=sample_rate,
                            frame_length=25,  # window length of each frame, in ms
                            frame_shift=10,  # hop between consecutive frames, in ms
                            dither=0,  # no random dithering, so output is repeatable
                            **extra_kwargs)
        # Kaldi layout is (num_frames, feat_dim); swap the axes and add a
        # leading singleton dimension.
        return y.transpose(0, 1).unsqueeze(0).detach()
# Build a 13-dimensional MFCC extractor and run it on a sample recording.
extracter = ExtractAudioFeature(feat_type="mfcc", feat_dim=13)
wav = "./data/wav/day0914_990.wav"
wav_feature = extracter(wav)
# Layout of the result: (1, feature dimension, number of frames).
print(wav_feature.size())
torch.Size([1, 13, 489])
# 13:特征维度
# 489:音频帧数≈(音频时长-帧长25ms)/帧移10ms+1(帧数由帧移10ms决定,而不是帧长25ms)
  • 查看图示
import matplotlib.pyplot as plt

# Render the feature matrix as an image with both axes' ticks hidden.
fig = plt.figure(dpi=200)
ax = fig.gca()
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(wav_feature[0])
plt.show()

参考资料

  1. https://github.com/neil-zeng/asr
0

评论 (0)

打卡
取消