package jbcl.calc.alignment;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.logging.Logger;
import jbcl.calc.numeric.algebra.EigenvalueDecomposition;
import jbcl.calc.numeric.algebra.Matrix;
import jbcl.calc.structural.properties.MinDistance;
import jbcl.data.basic.DataTable;
import jbcl.data.basic.Tuple;
import jbcl.data.basic.TwoTuple;
import jbcl.data.dict.Monomer;
import jbcl.data.dict.MonomersFactory;
import jbcl.data.formats.FlatSequenceProfile;
import jbcl.data.formats.PDB;
import jbcl.data.formats.alignments.ClustalW;
import jbcl.data.types.Residue;
import jbcl.data.types.SequenceProfile;
import jbcl.util.GZipAwareBufferedReader;

/* loaded from: input_file:jbcl/calc/alignment/CorrelatedMutations.class */
/**
 * Pairwise (joint) residue statistics for columns of a multiple sequence alignment.
 *
 * <p>The alignment is stored column-major as residue indices in {@code 0..20}; index
 * {@code 20} is used both for the gap symbol and for any symbol that does not map to an
 * index in {@code 0..19}. A {@link JointAADistribution} built for two positions holds the
 * weighted co-occurrence table of residue pairs and the same table after subtracting the
 * product of the single-position profile probabilities.
 */
public class CorrelatedMutations extends WeightedSequenceProfile {
    /** Number of residue indices scanned by the pair statistics (indices 0..19; the gap index is excluded). */
    private static final int AA_TYPES = 20;
    /** Index stored for the gap symbol and for symbols without a usable residue index. */
    private static final byte GAP_INDEX = 20;
    /** Size of each dimension of the count tables (residue indices plus the gap index). */
    private static final int ALPHABET_SIZE = 21;

    private static final Logger jbcl_logger = Logger.getLogger(CorrelatedMutations.class.getCanonicalName());

    /** Weighting scheme instance for which {@link #sequenceProfile} was last computed. */
    private SequenceWeighting recentWeightingScheme;
    /** Profile cached for {@link #recentWeightingScheme}; recomputed when a different instance is supplied. */
    private SequenceProfile sequenceProfile;
    /** Alignment as residue indices, addressed as {@code [column][sequence]}. */
    private byte[][] msaAsBytes;

    /**
     * Joint distribution of residue pairs observed at two alignment columns.
     *
     * <p>This is an inner class: it reads the alignment, the sequence-to-column map and the
     * cached profile of the enclosing {@link CorrelatedMutations} instance.
     */
    public class JointAADistribution {
        /** Alignment column of the first position (looked up through {@code seqResMap}). */
        public final int posI;
        /** Alignment column of the second position (looked up through {@code seqResMap}). */
        public final int posJ;
        /** Weighted pair counts, indexed {@code [residue at posI][residue at posJ]}. */
        public final double[][] observedCounts = new double[ALPHABET_SIZE][ALPHABET_SIZE];
        /** {@link #observedCounts} minus the product of the two single-position profile probabilities. */
        public final double[][] correctedCounts = new double[ALPHABET_SIZE][ALPHABET_SIZE];
        /** Lazily created decomposition of {@link #correctedCounts}; see {@link #eigenvalues()}. */
        private EigenvalueDecomposition eig = null;

        /**
         * Builds the observed and corrected pair tables for two sequence positions.
         *
         * <p>As a side effect, the enclosing object's cached profile is recomputed when
         * {@code weighting} is not the same instance as the one used previously.
         *
         * @param seqPosI   first position, given as an index into the sequence (not an alignment column)
         * @param seqPosJ   second position, given as an index into the sequence
         * @param weighting source of per-sequence weights for each alignment column
         */
        public JointAADistribution(int seqPosI, int seqPosJ, SequenceWeighting weighting) {
            // Identity comparison on purpose: the cache is keyed on the weighting instance.
            if (weighting != CorrelatedMutations.this.recentWeightingScheme) {
                CorrelatedMutations.this.sequenceProfile = CorrelatedMutations.this.computeProfile(weighting);
                CorrelatedMutations.this.recentWeightingScheme = weighting;
            }
            this.posI = CorrelatedMutations.this.seqResMap[seqPosI];
            this.posJ = CorrelatedMutations.this.seqResMap[seqPosJ];

            final double[] weightsI = weighting.getWeights(this.posI);
            final byte[] columnI = CorrelatedMutations.this.msaAsBytes[this.posI];
            final byte[] columnJ = CorrelatedMutations.this.msaAsBytes[this.posJ];
            final double[] weightsJ = weighting.getWeights(this.posJ);
            CorrelatedMutations.jbcl_logger.fine(String.format("Joint mutation distribution for sequence positions: %d,%d (MSA columns: %d,%d)", seqPosI, seqPosJ, this.posI, this.posJ));

            // Each sequence contributes the geometric mean of its weights at the two columns.
            final int nSequences = CorrelatedMutations.this.alignedSequences.length;
            for (int seq = 0; seq < nSequences; seq++) {
                this.observedCounts[columnI[seq]][columnJ[seq]] += Math.sqrt(weightsI[seq] * weightsJ[seq]);
            }

            // The profile is queried with sequence positions, exactly as the caller supplied them.
            final SequenceProfile profile = CorrelatedMutations.this.sequenceProfile;
            for (int a = 0; a < ALPHABET_SIZE; a++) {
                final double probabilityA = profile.getProbability(seqPosI, a);
                for (int b = 0; b < ALPHABET_SIZE; b++) {
                    this.correctedCounts[a][b] = this.observedCounts[a][b] - (probabilityA * profile.getProbability(seqPosJ, b));
                }
            }
        }

        /**
         * Counts residue pairs (gap index excluded) whose corrected count reaches a threshold.
         *
         * @param d inclusive lower bound on the corrected count
         * @return number of cells in the 20x20 sub-table with a corrected count {@code >= d}
         */
        public int countCorrelatedPairs(double d) {
            int hits = 0;
            for (int a = 0; a < AA_TYPES; a++) {
                for (int b = 0; b < AA_TYPES; b++) {
                    if (this.correctedCounts[a][b] >= d) {
                        hits++;
                    }
                }
            }
            return hits;
        }

        /**
         * Sums the corrected counts over the 20x20 sub-table (gap index excluded).
         *
         * @return the sum of all corrected counts for residue indices 0..19
         */
        public double sumCorrectedCounts() {
            double total = 0.0d;
            for (int a = 0; a < AA_TYPES; a++) {
                for (int b = 0; b < AA_TYPES; b++) {
                    total += this.correctedCounts[a][b];
                }
            }
            return total;
        }

        /**
         * Finds the residue pair with the largest strictly positive corrected count.
         *
         * @return the pair of residue indices; {@code (-1, -1)} when no cell of the 20x20
         *         sub-table is greater than zero, so callers must check before using the
         *         values as indices
         */
        public TwoTuple<Integer, Integer> bestCorrectedCounts() {
            double best = 0.0d;
            int bestA = -1;
            int bestB = -1;
            for (int a = 0; a < AA_TYPES; a++) {
                for (int b = 0; b < AA_TYPES; b++) {
                    if (this.correctedCounts[a][b] > best) {
                        best = this.correctedCounts[a][b];
                        bestA = a;
                        bestB = b;
                    }
                }
            }
            return Tuple.tuple(Integer.valueOf(bestA), Integer.valueOf(bestB));
        }

        /**
         * Returns the real parts of the eigenvalues of the full 21x21 corrected-count table.
         * The decomposition is computed on first use and reused afterwards.
         *
         * @return the array produced by {@code EigenvalueDecomposition.getRealEigenvalues()}
         */
        public double[] eigenvalues() {
            if (this.eig == null) {
                this.eig = new Matrix(this.correctedCounts).eig();
            }
            return this.eig.getRealEigenvalues();
        }
    }

    /**
     * Creates the object from aligned sequences and converts them to residue indices.
     *
     * <p>The number of alignment columns is taken from the first sequence. A shorter row is
     * padded with the gap index (previously such cells were left at 0 and silently counted
     * as residue index 0); a longer row is rejected.
     *
     * @param sequences   aligned sequences, one string per row
     * @param targetIndex passed unchanged to the superclass constructor; presumably the row
     *                    index of the sequence of interest - TODO confirm against
     *                    {@code WeightedSequenceProfile}
     * @throws IllegalArgumentException if a row is longer than the first row
     */
    public CorrelatedMutations(String[] sequences, int targetIndex) {
        super(sequences, targetIndex);
        this.recentWeightingScheme = null;
        this.sequenceProfile = null;
        final int nColumns = sequences[0].length();
        this.msaAsBytes = new byte[nColumns][sequences.length];
        for (int row = 0; row < sequences.length; row++) {
            final String sequence = sequences[row];
            if (sequence.length() > nColumns) {
                throw new IllegalArgumentException(String.format("Aligned sequence %d has %d columns but the first sequence has only %d", row, sequence.length(), nColumns));
            }
            for (int column = 0; column < nColumns; column++) {
                this.msaAsBytes[column][row] = column < sequence.length() ? encode(sequence.charAt(column)) : GAP_INDEX;
            }
        }
    }

    /**
     * Maps an alignment symbol to an index usable with the 21x21 count tables.
     *
     * @param symbol one character of an aligned sequence
     * @return the id reported by {@code MonomersFactory.getId} when it lies in {@code 0..20};
     *         otherwise (gap symbol, unknown symbol, or out-of-range id) the gap index
     */
    private static byte encode(char symbol) {
        if (symbol == SequenceWeighting.gapSymbol) {
            return GAP_INDEX;
        }
        final int id = MonomersFactory.getId(symbol);
        // Anything outside the table bounds would break JointAADistribution later on.
        return (id < 0 || id >= ALPHABET_SIZE) ? GAP_INDEX : (byte) id;
    }

    /**
     * Convenience factory for {@link JointAADistribution}.
     *
     * @param i                 first sequence position
     * @param i2                second sequence position
     * @param sequenceWeighting weighting scheme supplying per-sequence weights
     * @return a new joint distribution bound to this object
     */
    public final JointAADistribution jointAADistribution(int i, int i2, SequenceWeighting sequenceWeighting) {
        return new JointAADistribution(i, i2, sequenceWeighting);
    }

    /**
     * Reads an alignment in ClustalW format and selects a sequence by name.
     *
     * <p>Side effect: sets the static {@code SequenceWeighting.gapSymbol} to {@code '-'}.
     * When several names contain {@code str2}, the last one wins. When none does, index
     * {@code -1} is passed on and a warning is logged.
     *
     * <p>NOTE(review): the reader opened here is not closed by this method; confirm whether
     * {@code ClustalW.readStrings} closes it.
     *
     * @param str  name of the alignment file
     * @param str2 substring identifying the name of the sequence of interest
     * @return the object built from all sequences found in the file
     * @throws FileNotFoundException declared by the original signature
     * @throws IOException           declared by the original signature
     */
    public static final CorrelatedMutations fromClustalW(String str, String str2) throws FileNotFoundException, IOException {
        SequenceWeighting.gapSymbol = '-';
        final TwoTuple<String, String>[] entries = ClustalW.readStrings(GZipAwareBufferedReader.getReader(str));
        final String[] sequences = new String[entries.length];
        int targetIndex = -1;
        for (int k = 0; k < entries.length; k++) {
            sequences[k] = entries[k].second;
            if (entries[k].first.contains(str2)) {
                targetIndex = k;
            }
        }
        jbcl_logger.fine("Got " + sequences.length + " sequences from a ClustalW file: " + str);
        if (targetIndex < 0) {
            jbcl_logger.warning("No sequence name containing '" + str2 + "' was found in " + str + "; using sequence index -1");
        }
        return new CorrelatedMutations(sequences, targetIndex);
    }

    /**
     * Builds the object from the first string column of a table read from a file.
     *
     * <p>NOTE(review): despite the name, the input is read through {@code DataTable.fromBuffer}
     * and only column 0 is used; confirm the expected file layout. The reader opened here is
     * not closed by this method.
     *
     * @param str name of the input file
     * @param i   passed unchanged to the constructor as its second argument
     * @return the object built from the sequences found in the file
     * @throws FileNotFoundException declared by the original signature
     * @throws IOException           declared by the original signature
     */
    public static final CorrelatedMutations fromFasta(String str, int i) throws FileNotFoundException, IOException {
        return new CorrelatedMutations(DataTable.fromBuffer(GZipAwareBufferedReader.getReader(str)).getStringColumn(0), i);
    }

    /**
     * Command-line driver: writes the sequence profile and reports position pairs that show
     * at least two residue pairs with a corrected count of 0.05 or more.
     *
     * <p>Arguments: ClustalW file, sequence name, profile output file, optional second
     * argument for {@code ClusterWeighting} (default 0.8), optional PDB file. Positions are
     * skipped when the gap probability exceeds 0.5 or the most probable monomer exceeds 0.9;
     * only pairs at least three positions apart are examined. Results go to standard error.
     *
     * @param strArr command-line arguments as described above
     * @throws FileNotFoundException declared by the original signature
     * @throws IOException           declared by the original signature
     */
    public static void main(String[] strArr) throws FileNotFoundException, IOException {
        if (strArr.length < 3) {
            System.err.println("Usage: CorrelatedMutations <clustalw_file> <sequence_name> <profile_output_file> [cluster_weighting_parameter (default 0.8)] [pdb_file]");
            return;
        }
        final String alignmentFile = strArr[0];
        final String sequenceName = strArr[1];
        final String profileFile = strArr[2];
        final double weightingParameter = strArr.length > 3 ? Double.parseDouble(strArr[3]) : 0.8d;
        // NOTE(review): residues are later indexed by sequence position; confirm the structure matches the sequence.
        final Residue[] residuesArray = strArr.length > 4 ? new PDB(strArr[4]).getStructure().getResiduesArray() : null;

        final CorrelatedMutations mutations = fromClustalW(alignmentFile, sequenceName);
        final ClusterWeighting clusterWeighting = new ClusterWeighting(mutations.alignedSequences, weightingParameter);
        final SequenceProfile profile = mutations.computeProfile(clusterWeighting);
        FlatSequenceProfile.write(profile, profileFile);

        for (int i = 0; i < mutations.theSequence.length; i++) {
            if (profile.getProbability(i, 20) > 0.5d) {
                continue;
            }
            final Monomer mostProbableI = profile.getMostProbableMonomer(i);
            if (profile.getProbability(i, mostProbableI) > 0.9d) {
                continue;
            }
            // Only partners at least three positions before i are considered.
            for (int j = 0; j <= i - 3; j++) {
                if (profile.getProbability(j, 20) > 0.5d) {
                    continue;
                }
                final Monomer mostProbableJ = profile.getMostProbableMonomer(j);
                if (profile.getProbability(j, mostProbableJ) > 0.9d) {
                    continue;
                }
                final JointAADistribution joint = mutations.jointAADistribution(i, j, clusterWeighting);
                final int correlatedPairs = joint.countCorrelatedPairs(0.05d);
                if (correlatedPairs < 2) {
                    continue;
                }
                final double sumCorrected = joint.sumCorrectedCounts();
                // With at least two cells >= 0.05 the best pair always exists, so the indices are never -1 here.
                final TwoTuple<Integer, Integer> best = joint.bestCorrectedCounts();
                final int bestI = best.first.intValue();
                final int bestJ = best.second.intValue();
                final char bestCodeI = MonomersFactory.get(bestI).oneLetterCode;
                final char bestCodeJ = MonomersFactory.get(bestJ).oneLetterCode;
                if (residuesArray == null) {
                    System.err.printf("%3d %c %3d %c %c %c %f %f\n", Integer.valueOf(i), Character.valueOf(mutations.theSequence.getEntity(i).oneLetterCode), Integer.valueOf(j), Character.valueOf(mutations.theSequence.getEntity(j).oneLetterCode), Character.valueOf(bestCodeI), Character.valueOf(bestCodeJ), Double.valueOf(joint.correctedCounts[bestI][bestJ]), Double.valueOf(joint.observedCounts[bestI][bestJ]));
                } else {
                    System.err.printf("%3d %c %3d %c %c %c %f %f %f  %c %c %f %d\n", Integer.valueOf(i), Character.valueOf(mutations.theSequence.getEntity(i).oneLetterCode), Integer.valueOf(j), Character.valueOf(mutations.theSequence.getEntity(j).oneLetterCode), Character.valueOf(bestCodeI), Character.valueOf(bestCodeJ), Double.valueOf(joint.correctedCounts[bestI][bestJ]), Double.valueOf(joint.observedCounts[bestI][bestJ]), Double.valueOf(MinDistance.calculateValue(residuesArray[i].getAtomsArray(), residuesArray[j].getAtomsArray())), Character.valueOf(mostProbableI.oneLetterCode), Character.valueOf(mostProbableJ.oneLetterCode), Double.valueOf(sumCorrected), Integer.valueOf(correlatedPairs));
                }
            }
        }
    }
}
