/*
 * Decompiled with CFR 0.152.
 */
package marytts.signalproc.adaptation.codebook;

import java.io.IOException;
import javax.sound.sampled.UnsupportedAudioFileException;
import marytts.signalproc.adaptation.AdaptationUtils;
import marytts.signalproc.adaptation.BaselineAdaptationSet;
import marytts.signalproc.adaptation.BaselineFeatureCollection;
import marytts.signalproc.adaptation.BaselineFeatureExtractor;
import marytts.signalproc.adaptation.BaselinePreprocessor;
import marytts.signalproc.adaptation.BaselineTrainer;
import marytts.signalproc.adaptation.IndexMap;
import marytts.signalproc.adaptation.codebook.WeightedCodebookFeatureCollection;
import marytts.signalproc.adaptation.codebook.WeightedCodebookFeatureMapper;
import marytts.signalproc.adaptation.codebook.WeightedCodebookFile;
import marytts.signalproc.adaptation.codebook.WeightedCodebookFileHeader;
import marytts.signalproc.adaptation.codebook.WeightedCodebookLsfMapper;
import marytts.signalproc.adaptation.codebook.WeightedCodebookMfccMapper;
import marytts.signalproc.adaptation.codebook.WeightedCodebookOutlierEliminator;
import marytts.signalproc.adaptation.codebook.WeightedCodebookTrainerParams;
import marytts.signalproc.adaptation.prosody.PitchMappingFile;
import marytts.signalproc.adaptation.prosody.PitchTrainer;
import marytts.util.io.FileUtils;
import marytts.util.string.StringUtils;

/**
 * Trainer that builds a weighted codebook (and a pitch mapping) from a source and a target
 * training set.
 *
 * <p>The pipeline implemented by {@link #train} is: preprocess both sets, extract features,
 * write one index map per item pair ({@link #collectFeatures}), learn the codebook and pitch
 * mapping ({@link #learnMapping(WeightedCodebookFeatureCollection, BaselineAdaptationSet,
 * BaselineAdaptationSet, int[])}), run the outlier eliminator, and finally delete the temporary
 * files.
 */
public class WeightedCodebookTrainer extends BaselineTrainer {
    /** Training parameters, built from the parameter object handed to the constructor. */
    public WeightedCodebookTrainerParams wcParams;
    /** Outlier eliminator that is run on {@link #wcParams} once the mapping has been learned. */
    public WeightedCodebookOutlierEliminator outlierEliminator;

    /**
     * Creates a trainer.
     *
     * @param pp preprocessor forwarded to the superclass
     * @param fe feature extractor forwarded to the superclass
     * @param pa training parameters; a new {@link WeightedCodebookTrainerParams} is constructed
     *           from them and stored in {@link #wcParams}
     */
    public WeightedCodebookTrainer(BaselinePreprocessor pp, BaselineFeatureExtractor fe, WeightedCodebookTrainerParams pa) {
        super(pp, fe);
        this.wcParams = new WeightedCodebookTrainerParams(pa);
        this.outlierEliminator = new WeightedCodebookOutlierEliminator();
    }

    /**
     * Validates the parameters and, if they are acceptable, builds the source and target
     * adaptation sets from the configured folders and trains on them. Does nothing when
     * {@link #checkParams()} returns {@code false}.
     *
     * @throws IOException if training fails with an I/O error
     * @throws UnsupportedAudioFileException propagated from training
     */
    public void run() throws IOException, UnsupportedAudioFileException {
        if (!this.checkParams()) {
            return;
        }
        BaselineAdaptationSet sourceTrainingSet = new BaselineAdaptationSet(this.wcParams.sourceTrainingFolder);
        BaselineAdaptationSet targetTrainingSet = new BaselineAdaptationSet(this.wcParams.targetTrainingFolder);
        int[] map = this.getIndexedMapping(sourceTrainingSet, targetTrainingSet);
        this.train(sourceTrainingSet, targetTrainingSet, map);
    }

    /**
     * Normalises the folder parameters through {@code StringUtils.checkLastSlash}, asks
     * {@code FileUtils} to create the training base folder, verifies that the base, source and
     * target folders exist as directories, and derives the temporary codebook file name
     * ({@code codebookFile + ".temp"}).
     *
     * <p>Every missing folder is reported on standard output; all three folders are always
     * checked, and the temporary codebook file name is set even when a check fails.
     *
     * @return {@code true} if all three folders exist and are directories
     */
    @Override
    public boolean checkParams() {
        boolean paramsValid = true;
        this.wcParams.trainingBaseFolder = StringUtils.checkLastSlash(this.wcParams.trainingBaseFolder);
        this.wcParams.sourceTrainingFolder = StringUtils.checkLastSlash(this.wcParams.sourceTrainingFolder);
        this.wcParams.targetTrainingFolder = StringUtils.checkLastSlash(this.wcParams.targetTrainingFolder);
        FileUtils.createDirectory(this.wcParams.trainingBaseFolder);
        if (!FileUtils.exists(this.wcParams.trainingBaseFolder) || !FileUtils.isDirectory(this.wcParams.trainingBaseFolder)) {
            System.out.println("Error! Training base folder " + this.wcParams.trainingBaseFolder + " not found.");
            paramsValid = false;
        }
        if (!FileUtils.exists(this.wcParams.sourceTrainingFolder) || !FileUtils.isDirectory(this.wcParams.sourceTrainingFolder)) {
            System.out.println("Error! Source training folder " + this.wcParams.sourceTrainingFolder + " not found.");
            paramsValid = false;
        }
        if (!FileUtils.exists(this.wcParams.targetTrainingFolder) || !FileUtils.isDirectory(this.wcParams.targetTrainingFolder)) {
            System.out.println("Error! Target training folder " + this.wcParams.targetTrainingFolder + " not found.");
            paramsValid = false;
        }
        this.wcParams.temporaryCodebookFile = this.wcParams.codebookFile + ".temp";
        return paramsValid;
    }

    /**
     * Runs the complete training pipeline on the given sets.
     *
     * <p>Returns silently when either set has no {@code items} array or {@code map} is
     * {@code null}. Preprocessing and feature extraction are skipped for empty sets, but the
     * remaining steps (feature collection, mapping, outlier elimination, cleanup) still run.
     *
     * @param sourceTrainingSet source items
     * @param targetTrainingSet target items
     * @param map for each source index {@code i}, {@code map[i]} is the index of the target item
     *            paired with it
     * @throws RuntimeException if the two item arrays and {@code map} differ in length
     * @throws IOException on I/O failure in any stage
     * @throws UnsupportedAudioFileException propagated from preprocessing / feature extraction
     */
    public void train(BaselineAdaptationSet sourceTrainingSet, BaselineAdaptationSet targetTrainingSet, int[] map) throws IOException, UnsupportedAudioFileException {
        if (sourceTrainingSet.items == null || targetTrainingSet.items == null || map == null) {
            return;
        }
        int numItems = sourceTrainingSet.items.length;
        if (numItems != targetTrainingSet.items.length || numItems != map.length) {
            throw new RuntimeException("Lengths of source, target and map must be the same");
        }
        if (numItems > 0) {
            this.preprocessor.run(sourceTrainingSet);
            this.preprocessor.run(targetTrainingSet);
            // The vocal tract feature is combined with the F0 and energy feature constants.
            int desiredFeatures = this.wcParams.codebookHeader.vocalTractFeature + BaselineFeatureExtractor.F0_FEATURES + BaselineFeatureExtractor.ENERGY_FEATURES;
            this.featureExtractor.run(sourceTrainingSet, this.wcParams, desiredFeatures);
            this.featureExtractor.run(targetTrainingSet, this.wcParams, desiredFeatures);
        }
        WeightedCodebookFeatureCollection fcol = this.collectFeatures(sourceTrainingSet, targetTrainingSet, map);
        this.learnMapping(fcol, sourceTrainingSet, targetTrainingSet, map);
        this.outlierEliminator.run(this.wcParams);
        this.deleteTemporaryFiles(fcol, sourceTrainingSet, targetTrainingSet);
    }

    /**
     * Creates the feature collection and writes the index map files it names.
     *
     * <p>For the codebook types {@code FRAMES}, {@code FRAME_GROUPS}, {@code LABELS} and
     * {@code LABEL_GROUPS} one index map is computed per pair
     * ({@code sourceTrainingSet.items[i]}, {@code targetTrainingSet.items[map[i]]}) and written
     * to {@code fcol.indexMapFiles[i]}. For {@code SPEECH} a single map is written to
     * {@code fcol.indexMapFiles[0]}. Any other codebook type writes nothing.
     *
     * <p>A failed write is reported and processing continues (best effort).
     *
     * @param sourceTrainingSet source items
     * @param targetTrainingSet target items
     * @param map source-to-target index mapping
     * @return the newly created feature collection
     * @throws IllegalStateException if an index map is required but the configured vocal tract
     *         feature is neither {@code LSF_FEATURES} nor {@code MFCC_FEATURES_FROM_FILES}
     *         (this case used to surface as a {@code NullPointerException})
     * @throws IOException if computing an index map fails
     */
    public WeightedCodebookFeatureCollection collectFeatures(BaselineAdaptationSet sourceTrainingSet, BaselineAdaptationSet targetTrainingSet, int[] map) throws IOException {
        WeightedCodebookFeatureCollection fcol = new WeightedCodebookFeatureCollection(this.wcParams, map.length);
        if (this.isPerItemCodebookType()) {
            for (int i = 0; i < map.length; i++) {
                IndexMap imap = this.mapItemPair(sourceTrainingSet, i, targetTrainingSet, map[i]);
                this.writeIndexMap(imap, fcol, i);
            }
        } else if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.SPEECH) {
            // Only validates the feature setting; both supported features use the same map.
            this.usesLsfFeatures();
            this.writeIndexMap(AdaptationUtils.mapSpeechFeatures(), fcol, 0);
        }
        return fcol;
    }

    /** Tells whether the configured codebook type needs one index map per item pair. */
    private boolean isPerItemCodebookType() {
        return this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAMES
                || this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAME_GROUPS
                || this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABELS
                || this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABEL_GROUPS;
    }

    /**
     * Returns {@code true} for LSF features and {@code false} for MFCC-from-file features;
     * throws {@link IllegalStateException} for any other vocal tract feature.
     */
    private boolean usesLsfFeatures() {
        if (this.wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.LSF_FEATURES) {
            return true;
        }
        if (this.wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.MFCC_FEATURES_FROM_FILES) {
            return false;
        }
        throw new IllegalStateException("Unsupported vocal tract feature for index mapping: " + this.wcParams.codebookHeader.vocalTractFeature);
    }

    /** Computes the index map for one source/target item pair according to the codebook type. */
    private IndexMap mapItemPair(BaselineAdaptationSet sourceTrainingSet, int sourceIndex, BaselineAdaptationSet targetTrainingSet, int targetIndex) throws IOException {
        boolean useLsf = this.usesLsfFeatures();
        if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAMES) {
            return AdaptationUtils.mapFramesFeatures(
                    sourceTrainingSet.items[sourceIndex].labelFile,
                    targetTrainingSet.items[targetIndex].labelFile,
                    useLsf ? sourceTrainingSet.items[sourceIndex].lsfFile : sourceTrainingSet.items[sourceIndex].mfccFile,
                    useLsf ? targetTrainingSet.items[targetIndex].lsfFile : targetTrainingSet.items[targetIndex].mfccFile,
                    this.wcParams.codebookHeader.vocalTractFeature,
                    this.wcParams.labelsToExcludeFromTraining);
        }
        if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAME_GROUPS) {
            return AdaptationUtils.mapFrameGroupsFeatures(
                    sourceTrainingSet.items[sourceIndex].labelFile,
                    targetTrainingSet.items[targetIndex].labelFile,
                    useLsf ? sourceTrainingSet.items[sourceIndex].lsfFile : sourceTrainingSet.items[sourceIndex].mfccFile,
                    useLsf ? targetTrainingSet.items[targetIndex].lsfFile : targetTrainingSet.items[targetIndex].mfccFile,
                    this.wcParams.codebookHeader.numNeighboursInFrameGroups,
                    this.wcParams.codebookHeader.vocalTractFeature,
                    this.wcParams.labelsToExcludeFromTraining);
        }
        if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABELS) {
            return AdaptationUtils.mapLabelsFeatures(
                    sourceTrainingSet.items[sourceIndex].labelFile,
                    targetTrainingSet.items[targetIndex].labelFile,
                    useLsf ? sourceTrainingSet.items[sourceIndex].lsfFile : sourceTrainingSet.items[sourceIndex].mfccFile,
                    useLsf ? targetTrainingSet.items[targetIndex].lsfFile : targetTrainingSet.items[targetIndex].mfccFile,
                    this.wcParams.codebookHeader.vocalTractFeature,
                    this.wcParams.labelsToExcludeFromTraining);
        }
        if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABEL_GROUPS) {
            return AdaptationUtils.mapLabelGroupsFeatures(
                    sourceTrainingSet.items[sourceIndex].labelFile,
                    targetTrainingSet.items[targetIndex].labelFile,
                    useLsf ? sourceTrainingSet.items[sourceIndex].lsfFile : sourceTrainingSet.items[sourceIndex].mfccFile,
                    useLsf ? targetTrainingSet.items[targetIndex].lsfFile : targetTrainingSet.items[targetIndex].mfccFile,
                    this.wcParams.codebookHeader.numNeighboursInLabelGroups,
                    this.wcParams.codebookHeader.vocalTractFeature,
                    this.wcParams.labelsToExcludeFromTraining);
        }
        // Callers check isPerItemCodebookType() first, so this is not expected to be reached.
        throw new IllegalStateException("Codebook type does not use per-item index maps: " + this.wcParams.codebookHeader.codebookType);
    }

    /** Writes {@code imap} to {@code fcol.indexMapFiles[index]}; a failed write is reported, not rethrown. */
    private void writeIndexMap(IndexMap imap, WeightedCodebookFeatureCollection fcol, int index) {
        try {
            imap.writeToFile(fcol.indexMapFiles[index]);
        }
        catch (IOException e) {
            // Best effort: keep going with the remaining items, but say which file failed.
            System.out.println("Error! Could not write index map file " + fcol.indexMapFiles[index]);
            e.printStackTrace();
        }
    }

    /**
     * Convenience overload for callers holding the collection as a
     * {@link BaselineFeatureCollection}; it must actually be a
     * {@link WeightedCodebookFeatureCollection}.
     *
     * @throws ClassCastException if {@code fcol} is of a different type (an
     *         {@code AssertionError} is raised first when assertions are enabled)
     * @throws IOException on I/O failure while learning the mapping
     */
    public void learnMapping(BaselineFeatureCollection fcol, BaselineAdaptationSet sourceTrainingSet, BaselineAdaptationSet targetTrainingSet, int[] map) throws IOException {
        assert (fcol instanceof WeightedCodebookFeatureCollection);
        this.learnMapping((WeightedCodebookFeatureCollection)fcol, sourceTrainingSet, targetTrainingSet, map);
    }

    /**
     * Learns the codebook into the temporary codebook file and then the pitch mapping into the
     * pitch mapping file.
     *
     * <p>If the configured vocal tract feature has no feature mapper, an error is printed and
     * nothing is written. Both output files are closed even when a learning step throws.
     *
     * @param fcol feature collection produced by {@link #collectFeatures}
     * @param sourceTrainingSet source items
     * @param targetTrainingSet target items
     * @param map source-to-target index mapping
     * @throws IOException on I/O failure while opening, writing or closing the output files
     */
    public void learnMapping(WeightedCodebookFeatureCollection fcol, BaselineAdaptationSet sourceTrainingSet, BaselineAdaptationSet targetTrainingSet, int[] map) throws IOException {
        WeightedCodebookFeatureMapper featureMapper = this.createFeatureMapper();
        if (featureMapper == null) {
            System.out.println("Error! Specified feature mapper does not exist...");
            return;
        }
        WeightedCodebookFile temporaryCodebookFile = new WeightedCodebookFile(this.wcParams.temporaryCodebookFile, WeightedCodebookFile.OPEN_FOR_WRITE);
        boolean codebookClosed = false;
        PitchMappingFile pitchMappingFile = null;
        try {
            // Both files are opened before any learning starts, as before.
            pitchMappingFile = new PitchMappingFile(this.wcParams.pitchMappingFile, PitchMappingFile.OPEN_FOR_WRITE);
            this.learnCodebook(featureMapper, temporaryCodebookFile, fcol, sourceTrainingSet, targetTrainingSet, map);
            // Flag first so that a failing close() is not retried in the finally block.
            codebookClosed = true;
            temporaryCodebookFile.close();
            PitchTrainer ptcTrainer = new PitchTrainer(this.wcParams);
            ptcTrainer.learnMapping(pitchMappingFile, fcol, sourceTrainingSet, targetTrainingSet, map);
        }
        finally {
            try {
                if (!codebookClosed) {
                    temporaryCodebookFile.close();
                }
            }
            finally {
                if (pitchMappingFile != null) {
                    pitchMappingFile.close();
                }
            }
        }
    }

    /** Returns the mapper matching the configured vocal tract feature, or {@code null} if there is none. */
    private WeightedCodebookFeatureMapper createFeatureMapper() {
        if (this.wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.LSF_FEATURES) {
            return new WeightedCodebookLsfMapper(this.wcParams);
        }
        if (this.wcParams.codebookHeader.vocalTractFeature == BaselineFeatureExtractor.MFCC_FEATURES_FROM_FILES) {
            return new WeightedCodebookMfccMapper(this.wcParams);
        }
        return null;
    }

    /** Dispatches to the mapper method matching the configured codebook type; unknown types do nothing. */
    private void learnCodebook(WeightedCodebookFeatureMapper featureMapper, WeightedCodebookFile codebookFile, WeightedCodebookFeatureCollection fcol, BaselineAdaptationSet sourceTrainingSet, BaselineAdaptationSet targetTrainingSet, int[] map) throws IOException {
        if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAMES) {
            featureMapper.learnMappingFrames(codebookFile, fcol, sourceTrainingSet, targetTrainingSet, map);
        } else if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.FRAME_GROUPS) {
            featureMapper.learnMappingFrameGroups(codebookFile, fcol, sourceTrainingSet, targetTrainingSet, map);
        } else if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABELS) {
            featureMapper.learnMappingLabels(codebookFile, fcol, sourceTrainingSet, targetTrainingSet, map);
        } else if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.LABEL_GROUPS) {
            featureMapper.learnMappingLabelGroups(codebookFile, fcol, sourceTrainingSet, targetTrainingSet, map);
        } else if (this.wcParams.codebookHeader.codebookType == WeightedCodebookFileHeader.SPEECH) {
            featureMapper.learnMappingSpeech(codebookFile, fcol, sourceTrainingSet, targetTrainingSet, map);
        }
    }

    /**
     * Deletes the index map files of {@code fcol} and the temporary codebook file.
     *
     * @param fcol collection whose {@code indexMapFiles} are deleted
     * @param sourceTrainingSet unused
     * @param targetTrainingSet unused
     */
    public void deleteTemporaryFiles(WeightedCodebookFeatureCollection fcol, BaselineAdaptationSet sourceTrainingSet, BaselineAdaptationSet targetTrainingSet) {
        FileUtils.delete(fcol.indexMapFiles, true);
        FileUtils.delete(this.wcParams.temporaryCodebookFile);
    }
}

