/*
 * Decompiled with CFR 0.152.
 */
package marytts.unitselection.select;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.BufferUnderflowException;
import java.nio.FloatBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.Vector;
import marytts.exceptions.MaryConfigurationException;
import marytts.features.MaryGenericFeatureProcessors;
import marytts.modules.phonemiser.Allophone;
import marytts.server.MaryProperties;
import marytts.signalproc.display.Histogram;
import marytts.unitselection.data.DiphoneUnit;
import marytts.unitselection.data.Unit;
import marytts.unitselection.select.JoinCostFunction;
import marytts.unitselection.select.PrecompiledJoinCostReader;
import marytts.unitselection.select.Target;
import marytts.unitselection.weightingfunctions.WeightFunc;
import marytts.unitselection.weightingfunctions.WeightFunctionManager;
import marytts.util.MaryUtils;
import marytts.util.data.MaryHeader;
import marytts.util.io.StreamUtils;

/**
 * Join cost function based on per-unit join-cost feature vectors read from a Mary join features file.
 * <p>
 * For every unit the file stores one feature vector for the unit's left edge and one for its right edge. The cost of
 * joining unit {@code u1} to unit {@code u2} is the weighted sum, over all features, of a per-feature weighting
 * function applied to {@code u1}'s right vector and {@code u2}'s left vector. Features whose value is NaN on either
 * side are skipped. When both units are diphones and a {@link PrecompiledJoinCostReader} was configured, the
 * precompiled cost is used instead.
 */
public class JoinCostFeatures
implements JoinCostFunction {
    /** {@link MaryHeader} type that identifies a join features file. */
    private static final int JOIN_FEATURES_FILE_TYPE = 400;

    /** Signal weight passed to {@code load}; within this class it only scales the debug cost graph. */
    protected float wSignal;
    /** Phonetic weight, set to {@code 1 - wSignal} on load. */
    protected float wPhonetic;
    /** True when the "debug.show.cost.graph" property was set at load time. */
    protected boolean debugShowCostGraph = false;
    /** Per-feature running sum of {@code wSignal}-scaled costs, only allocated in debug mode. */
    protected double[] cumulWeightedSignalCosts = null;
    protected int nCostComputations = 0;
    /** Optional precompiled diphone join costs; null when no such file was configured. */
    protected PrecompiledJoinCostReader precompiledCosts;
    /** Debug histogram of average per-feature costs; null unless debug mode is enabled. */
    protected JoinCostReporter jcr;
    private MaryHeader hdr = null;
    /** Weight of each feature. */
    private float[] featureWeight = null;
    /** Weighting function applied to each feature. */
    private WeightFunc[] weightFunction = null;
    /** Cached flag: true where the weighting function is "linear", so the absolute difference is inlined. */
    private boolean[] isLinear = null;
    /** Left-edge feature vector of each unit, indexed [unit][feature]. */
    private float[][] leftJCF = null;
    /** Right-edge feature vector of each unit, indexed [unit][feature]. */
    private float[][] rightJCF = null;

    /**
     * Creates an empty instance; call {@link #init(String)} or {@link #load} before use.
     */
    public JoinCostFeatures() {
    }

    /**
     * Creates an instance and loads the given join features file.
     * <p>
     * The weights stored in the file are used, no precompiled costs are loaded, and the signal weight is 0.5.
     *
     * @param fileName path of the join features file
     * @throws IOException if the file cannot be read, is truncated or has the wrong type
     * @throws MaryConfigurationException if loading fails for configuration reasons
     */
    public JoinCostFeatures(String fileName) throws IOException, MaryConfigurationException {
        this.load(fileName, null, null, 0.5f);
    }

    /**
     * Configures this join cost function from {@link MaryProperties} entries below {@code configPrefix}.
     * <p>
     * Reads {@code .joinCostFile} (required), {@code .precomputedJoinCostFile} (optional),
     * {@code .joincostfunction.wSignal} (default "1.0") and the optional {@code .joinCostWeights} stream.
     *
     * @param configPrefix property prefix, e.g. a voice name
     * @throws MaryConfigurationException if a required property is missing or the files cannot be loaded
     */
    @Override
    public void init(String configPrefix) throws MaryConfigurationException {
        String joinFileName = MaryProperties.needFilename(configPrefix + ".joinCostFile");
        String precomputedJoinCostFileName = MaryProperties.getFilename(configPrefix + ".precomputedJoinCostFile");
        float wSignal = Float.parseFloat(MaryProperties.getProperty(configPrefix + ".joincostfunction.wSignal", "1.0"));
        InputStream joinWeightStream = null;
        try {
            joinWeightStream = MaryProperties.getStream(configPrefix + ".joinCostWeights");
            this.load(joinFileName, joinWeightStream, precomputedJoinCostFileName, wSignal);
        }
        catch (IOException ioe) {
            throw new MaryConfigurationException("Problem loading join file " + joinFileName, ioe);
        }
        finally {
            // The weight reader normally closes the stream; this covers failures before it was consumed.
            closeQuietly(joinWeightStream);
        }
    }

    /**
     * Loads the join features file, optionally overriding its weights and attaching precompiled costs.
     *
     * @param joinFileName path of the join features file
     * @param weightStream optional weight definitions that replace the file's weights; may be null
     * @param precompiledCostFileName optional precompiled join cost file; may be null
     * @param wSignal signal weight; the phonetic weight becomes {@code 1 - wSignal}
     * @throws IOException if a file cannot be read, is truncated or has the wrong type
     * @throws MaryConfigurationException if loading fails for configuration reasons
     * @throws IllegalArgumentException if the weight stream defines a different number of features than the file
     */
    @Override
    public void load(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal) throws IOException, MaryConfigurationException {
        this.loadFromByteBuffer(joinFileName, weightStream, precompiledCostFileName, wSignal);
    }

    /**
     * Loads the join features file through a read-only memory mapping.
     */
    private void loadFromByteBuffer(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal) throws IOException, MaryConfigurationException {
        if (precompiledCostFileName != null) {
            this.precompiledCosts = new PrecompiledJoinCostReader(precompiledCostFileName);
        }
        this.wSignal = wSignal;
        this.wPhonetic = 1.0f - wSignal;
        MappedByteBuffer bb;
        FileInputStream fis = new FileInputStream(joinFileName);
        try {
            FileChannel fc = fis.getChannel();
            // A mapping stays valid after its channel is closed, so the file handle can be released right away.
            bb = fc.map(FileChannel.MapMode.READ_ONLY, 0L, fc.size());
        }
        finally {
            fis.close();
        }
        this.hdr = new MaryHeader(bb);
        if (this.hdr.getType() != JOIN_FEATURES_FILE_TYPE) {
            throw new IOException("File [" + joinFileName + "] is not a valid Mary join features file.");
        }
        try {
            int numberOfFeatures = bb.getInt();
            this.allocateFeatureArrays(numberOfFeatures);
            WeightFunctionManager wfm = new WeightFunctionManager();
            for (int i = 0; i < numberOfFeatures; i++) {
                this.featureWeight[i] = bb.getFloat();
                this.weightFunction[i] = resolveWeightFunction(wfm, StreamUtils.readUTF(bb));
            }
            if (weightStream != null) {
                this.applyWeightOverrides(weightStream, numberOfFeatures, wfm);
            }
            this.computeLinearFlags();
            int numberOfUnits = bb.getInt();
            FloatBuffer fb = bb.asFloatBuffer();
            this.leftJCF = new float[numberOfUnits][];
            this.rightJCF = new float[numberOfUnits][];
            for (int u = 0; u < numberOfUnits; u++) {
                this.leftJCF[u] = new float[numberOfFeatures];
                fb.get(this.leftJCF[u]);
                this.rightJCF[u] = new float[numberOfFeatures];
                fb.get(this.rightJCF[u]);
            }
        }
        catch (EOFException e) {
            throw prematureEof(e);
        }
        catch (BufferUnderflowException e) {
            // Relative ByteBuffer/FloatBuffer reads signal a truncated file this way, not with EOFException.
            throw prematureEof(e);
        }
        this.initDebugCostGraphIfRequested();
    }

    /**
     * Loads the join features file through a buffered {@link DataInputStream}; an alternative to the mapped loader.
     */
    private void loadFromStream(String joinFileName, InputStream weightStream, String precompiledCostFileName, float wSignal) throws IOException, MaryConfigurationException {
        if (precompiledCostFileName != null) {
            this.precompiledCosts = new PrecompiledJoinCostReader(precompiledCostFileName);
        }
        this.wSignal = wSignal;
        this.wPhonetic = 1.0f - wSignal;
        DataInputStream raf = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(joinFileName))));
        try {
            this.hdr = new MaryHeader(raf);
            if (this.hdr.getType() != JOIN_FEATURES_FILE_TYPE) {
                throw new MaryConfigurationException("File [" + joinFileName + "] is not a valid Mary join features file.");
            }
            try {
                int numberOfFeatures = raf.readInt();
                this.allocateFeatureArrays(numberOfFeatures);
                WeightFunctionManager wfm = new WeightFunctionManager();
                for (int i = 0; i < numberOfFeatures; i++) {
                    this.featureWeight[i] = raf.readFloat();
                    this.weightFunction[i] = resolveWeightFunction(wfm, raf.readUTF());
                }
                if (weightStream != null) {
                    this.applyWeightOverrides(weightStream, numberOfFeatures, wfm);
                }
                this.computeLinearFlags();
                int numberOfUnits = raf.readInt();
                this.leftJCF = new float[numberOfUnits][];
                this.rightJCF = new float[numberOfUnits][];
                for (int u = 0; u < numberOfUnits; u++) {
                    this.leftJCF[u] = readFloats(raf, numberOfFeatures);
                    this.rightJCF[u] = readFloats(raf, numberOfFeatures);
                }
            }
            catch (EOFException e) {
                throw prematureEof(e);
            }
        }
        finally {
            raf.close();
        }
        this.initDebugCostGraphIfRequested();
    }

    /** Allocates the per-feature weight, weighting-function and linearity arrays. */
    private void allocateFeatureArrays(int numberOfFeatures) {
        this.featureWeight = new float[numberOfFeatures];
        this.weightFunction = new WeightFunc[numberOfFeatures];
        this.isLinear = new boolean[numberOfFeatures];
    }

    /** Maps a stored weighting-function name to its implementation; an empty name means "linear". */
    private static WeightFunc resolveWeightFunction(WeightFunctionManager wfm, String name) throws IOException, MaryConfigurationException {
        return "".equals(name) ? wfm.getWeightFunction("linear") : wfm.getWeightFunction(name);
    }

    /**
     * Replaces the file's feature weights and weighting functions with those read from {@code weightStream}.
     *
     * @throws IllegalArgumentException if the stream defines a different number of features than the file
     */
    private void applyWeightOverrides(InputStream weightStream, int numberOfFeatures, WeightFunctionManager wfm) throws IOException, MaryConfigurationException {
        MaryUtils.getLogger("JoinCostFeatures").debug("Overwriting join cost weights");
        Object[] weightData = JoinCostFeatures.readJoinCostWeightsStream(weightStream);
        float[] newWeights = (float[])weightData[0];
        String[] newFunctions = (String[])weightData[1];
        if (newWeights.length != numberOfFeatures) {
            throw new IllegalArgumentException("Join cost file contains " + numberOfFeatures + " features, but weight file contains " + newWeights.length + " feature weights!");
        }
        this.featureWeight = newWeights;
        for (int i = 0; i < numberOfFeatures; i++) {
            this.weightFunction[i] = wfm.getWeightFunction(newFunctions[i]);
        }
    }

    /** Caches whether each feature uses the "linear" weighting function. */
    private void computeLinearFlags() {
        for (int i = 0; i < this.weightFunction.length; i++) {
            this.isLinear[i] = this.weightFunction[i].whoAmI().equals("linear");
        }
    }

    /** Reads {@code n} consecutive floats from {@code in}. */
    private static float[] readFloats(DataInputStream in, int n) throws IOException {
        float[] values = new float[n];
        for (int j = 0; j < n; j++) {
            values[j] = in.readFloat();
        }
        return values;
    }

    /** Wraps a low-level end-of-data condition in the IOException reported for truncated join cost files. */
    private static IOException prematureEof(Exception cause) {
        IOException ioe = new IOException("The currently read Join Cost File has prematurely reached EOF.");
        ioe.initCause(cause);
        return ioe;
    }

    /** Opens and starts the debug cost graph when the "debug.show.cost.graph" property is set. */
    private void initDebugCostGraphIfRequested() {
        if (!MaryProperties.getBoolean("debug.show.cost.graph")) {
            return;
        }
        this.debugShowCostGraph = true;
        this.cumulWeightedSignalCosts = new double[this.featureWeight.length];
        this.jcr = new JoinCostReporter(this.cumulWeightedSignalCosts);
        this.jcr.showInJFrame("Average signal join costs", false, false);
        this.jcr.start();
    }

    /** Closes {@code stream} if non-null; a failure to close is deliberately ignored (best-effort cleanup). */
    private static void closeQuietly(InputStream stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        }
        catch (IOException ignored) {
            // Nothing actionable: the stream was only read, and any real error has already been reported.
        }
    }

    /**
     * Reads a join cost weights file.
     *
     * @param fileName path of the weights file
     * @return see {@link #readJoinCostWeightsStream(InputStream)}
     * @throws IOException if the file cannot be read or is malformed
     */
    public static Object[] readJoinCostWeightsFile(String fileName) throws IOException, FileNotFoundException {
        return JoinCostFeatures.readJoinCostWeightsStream(new FileInputStream(fileName));
    }

    /**
     * Reads join cost weight definitions from a UTF-8 stream and closes the stream.
     * <p>
     * Text after {@code #} is ignored, and blank lines are skipped. Each remaining line has the form
     * {@code <key> : <weight> <weightFunction>}; the key is ignored, and definitions are returned in line order.
     * Weights are normalized so that they sum to one.
     *
     * @param weightStream the stream to read; it is closed by this method
     * @return a two-element array: {@code float[]} normalized weights and {@code String[]} weighting-function names
     * @throws IOException if reading fails or a line lacks the ':' separator or the weighting-function name
     * @throws NumberFormatException if a weight is not a valid float
     */
    public static Object[] readJoinCostWeightsStream(InputStream weightStream) throws IOException, FileNotFoundException {
        List<Float> weights = new ArrayList<Float>();
        List<String> functions = new ArrayList<String>();
        float sumOfWeights = 0.0f;
        BufferedReader in = new BufferedReader(new InputStreamReader(weightStream, "UTF-8"));
        try {
            String line;
            int lineNumber = 0;
            while ((line = in.readLine()) != null) {
                ++lineNumber;
                line = line.split("#", 2)[0].trim();
                if (line.length() == 0) {
                    continue;
                }
                String[] keyAndValue = line.split(":", 2);
                if (keyAndValue.length < 2) {
                    throw new IOException("Malformed join cost weight line " + lineNumber + " (missing ':'): [" + line + "]");
                }
                // Split on any whitespace run so multiple spaces/tabs do not leak into the function name.
                String[] fields = keyAndValue[1].trim().split("\\s+", 2);
                if (fields.length < 2) {
                    throw new IOException("Malformed join cost weight line " + lineNumber + " (expected '<weight> <function>'): [" + line + "]");
                }
                float aWeight = Float.parseFloat(fields[0]);
                sumOfWeights += aWeight;
                weights.add(aWeight);
                functions.add(fields[1]);
            }
        }
        finally {
            in.close();
        }
        String[] wfun = functions.toArray(new String[functions.size()]);
        float[] fw = new float[weights.size()];
        for (int i = 0; i < fw.length; i++) {
            fw[i] = weights.get(i).floatValue() / sumOfWeights;
        }
        return new Object[]{fw, wfun};
    }

    /** @return the number of join cost features per unit edge */
    public int getNumberOfFeatures() {
        return this.featureWeight.length;
    }

    /** @return the number of units in the loaded file */
    public int getNumberOfUnits() {
        return this.leftJCF.length;
    }

    /**
     * Returns the left-edge feature vector of unit {@code u}; the internal array is returned, not a copy.
     *
     * @throws IndexOutOfBoundsException if {@code u} is negative or not smaller than {@link #getNumberOfUnits()}
     */
    public float[] getLeftJCF(int u) {
        this.checkUnitIndex(u, "unit index");
        return this.leftJCF[u];
    }

    /**
     * Returns the right-edge feature vector of unit {@code u}; the internal array is returned, not a copy.
     *
     * @throws IndexOutOfBoundsException if {@code u} is negative or not smaller than {@link #getNumberOfUnits()}
     */
    public float[] getRightJCF(int u) {
        this.checkUnitIndex(u, "unit index");
        return this.rightJCF[u];
    }

    /**
     * Validates a unit index against the loaded units.
     *
     * @param label how the index is named in the error message, e.g. "left unit index"
     */
    private void checkUnitIndex(int u, String label) {
        if (u < 0) {
            throw new IndexOutOfBoundsException("The " + label + " [" + u + "] is out of range: a unit index can't be negative.");
        }
        if (u >= this.getNumberOfUnits()) {
            throw new IndexOutOfBoundsException("The " + label + " [" + u + "] is out of range: this file contains [" + this.getNumberOfUnits() + "] units.");
        }
    }

    /**
     * Computes the feature-based cost of joining the right edge of unit {@code u1} to the left edge of unit {@code u2}.
     * <p>
     * This is the sum over features of {@code weight * f(a, b)}, where {@code a} and {@code b} are the two edge
     * values. For linear features, {@code f} is {@code |a - b|}. A feature whose value is NaN on either side
     * contributes nothing.
     *
     * @throws IndexOutOfBoundsException if either index is out of range
     */
    public double cost(int u1, int u2) {
        this.checkUnitIndex(u1, "left unit index");
        this.checkUnitIndex(u2, "right unit index");
        if (this.debugShowCostGraph) {
            this.jcr.tick();
        }
        double res = 0.0;
        float[] v1 = this.rightJCF[u1];
        float[] v2 = this.leftJCF[u2];
        for (int i = 0; i < v1.length; i++) {
            float a = v1[i];
            float b = v2[i];
            if (Float.isNaN(a) || Float.isNaN(b)) {
                continue;
            }
            double c = this.isLinear[i]
                    ? (double)(this.featureWeight[i] * Math.abs(a - b))
                    : (double)this.featureWeight[i] * this.weightFunction[i].cost(a, b);
            res += c;
            if (this.debugShowCostGraph) {
                this.cumulWeightedSignalCosts[i] += (double)this.wSignal * c;
            }
        }
        return res;
    }

    /**
     * Computes the cost of joining {@code u1} (for target {@code t1}) to {@code u2} (for target {@code t2}).
     * <p>
     * The rules, in order:
     * <ul>
     * <li>If either unit has zero duration, the cost is positive infinity.</li>
     * <li>Diphone units are reduced to the halves that meet: the right half of {@code u1} and the left half of
     * {@code u2}.</li>
     * <li>If the reduced units are consecutive in the database ({@code u1.index + 1 == u2.index}), the cost is 0.</li>
     * <li>Otherwise the cost is 1 plus either the precompiled cost (when both units are diphones and precompiled
     * costs are loaded) or {@link #cost(int, int)}.</li>
     * </ul>
     */
    @Override
    public double cost(Target t1, Unit u1, Target t2, Unit u2) {
        if (u1.duration == 0 || u2.duration == 0) {
            return Double.POSITIVE_INFINITY;
        }
        boolean bothDiphones = true;
        if (u1 instanceof DiphoneUnit) {
            u1 = ((DiphoneUnit)u1).right;
        } else {
            bothDiphones = false;
        }
        if (u2 instanceof DiphoneUnit) {
            u2 = ((DiphoneUnit)u2).left;
        } else {
            bothDiphones = false;
        }
        if (u1.index + 1 == u2.index) {
            return 0.0;
        }
        if (bothDiphones && this.precompiledCosts != null) {
            return 1.0 + this.precompiledCosts.cost(t1, u1, t2, u2);
        }
        return 1.0 + this.cost(u1.index, u2.index);
    }

    /**
     * Heuristic phonetic join cost between two targets, capped at 1.0.
     * <p>
     * Fixed amounts are added when either target is stressed, when either is a vowel or a glide, for voicing of
     * either or both targets, and when either is a nasal or a liquid.
     */
    protected double cost(Target t1, Target t2) {
        double cost = 0.0;
        MaryGenericFeatureProcessors.Stressed stressProcessor = new MaryGenericFeatureProcessors.Stressed("", new MaryGenericFeatureProcessors.SyllableNavigator());
        boolean stressed1 = stressProcessor.process(t1) == 1;
        boolean stressed2 = stressProcessor.process(t2) == 1;
        if (stressed1 || stressed2) {
            cost += 0.2;
        }
        Allophone p1 = t1.getAllophone();
        Allophone p2 = t2.getAllophone();
        if (p1.isVowel() || p2.isVowel()) {
            cost += 0.2;
        }
        if (p1.isGlide() || p2.isGlide()) {
            cost += 0.2;
        }
        if (p1.isVoiced() || p2.isVoiced()) {
            cost += 0.1;
        }
        if (p1.isVoiced() && p2.isVoiced()) {
            cost += 0.1;
        }
        if (p1.isNasal() || p2.isNasal()) {
            cost += 0.05;
        }
        if (p1.isLiquid() || p2.isLiquid()) {
            cost += 0.05;
        }
        return Math.min(cost, 1.0);
    }

    /**
     * Debug histogram that shows per-feature cumulative costs divided by the number of cost computations.
     * <p>
     * NOTE(review): {@code tick()} and {@code updateGraph()} run on different threads without synchronization.
     * The displayed values may lag or be slightly inconsistent, which is acceptable for a debug view only.
     */
    public static class JoinCostReporter
    extends Histogram {
        /** Shared accumulator of per-feature costs, written by the owning JoinCostFeatures. */
        private double[] data;
        /** Value of {@code nCostComputations} at the last refresh. */
        private int lastN = 0;
        private int nCostComputations = 0;

        public JoinCostReporter(double[] data) {
            super(0.0, 1.0, data);
            this.data = data;
        }

        /**
         * Starts a background thread that refreshes the graph every 500 ms while the graph is visible.
         * The thread stops when it is interrupted.
         */
        public void start() {
            new Thread(){

                @Override
                public void run() {
                    while (JoinCostReporter.this.isVisible()) {
                        try {
                            Thread.sleep(500L);
                        }
                        catch (InterruptedException e) {
                            // Restore the interrupt status and treat the interrupt as a request to stop refreshing.
                            Thread.currentThread().interrupt();
                            return;
                        }
                        JoinCostReporter.this.updateGraph();
                    }
                }
            }.start();
        }

        /** Records one cost computation. */
        public void tick() {
            ++this.nCostComputations;
        }

        /** Redraws the graph with the averaged costs, but only if new computations happened since the last refresh. */
        protected void updateGraph() {
            if (this.nCostComputations == this.lastN) {
                return;
            }
            this.lastN = this.nCostComputations;
            double[] newCosts = new double[this.data.length];
            for (int i = 0; i < newCosts.length; i++) {
                newCosts[i] = this.data[i] / (double)this.nCostComputations;
            }
            this.updateData(0.0, 1.0, newCosts);
            this.repaint();
        }
    }
}

