/*
 * Decompiled with CFR 0.152.
 */
package edu.msu.cme.rdp.classifier.train.validation.crossvalidate;

import edu.msu.cme.rdp.classifier.train.LineageSequence;
import edu.msu.cme.rdp.classifier.train.LineageSequenceParser;
import edu.msu.cme.rdp.classifier.train.validation.DecisionMaker;
import edu.msu.cme.rdp.classifier.train.validation.HierarchyTree;
import edu.msu.cme.rdp.classifier.train.validation.StatusCount;
import edu.msu.cme.rdp.classifier.train.validation.Taxonomy;
import edu.msu.cme.rdp.classifier.train.validation.TreeFactory;
import edu.msu.cme.rdp.classifier.train.validation.ValidClassificationResultFacade;
import edu.msu.cme.rdp.classifier.train.validation.ValidationClassificationResult;
import edu.msu.cme.rdp.classifier.train.validation.crossvalidate.RdmSelectTaxon;
import edu.msu.cme.rdp.readseq.utils.ResampleSeqFile;
import edu.msu.cme.rdp.readseq.utils.orientation.GoodWordIterator;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Runs one hold-out validation round of the classifier: a subset of the
 * sequences in a training file is set aside as a test set, a tree is trained on
 * the remaining sequences, every test sequence is classified, and per-rank
 * {@link StatusCount} tallies are accumulated for every bootstrap cutoff from
 * 0 to {@value #MAX_BOOTSTRAP}. A tab-separated report is written to the
 * output file.
 */
public class CrossValidate {
    /** Highest bootstrap cutoff tallied; one rank-to-count table exists per cutoff in 0..MAX_BOOTSTRAP. */
    private static final int MAX_BOOTSTRAP = 100;

    /**
     * Selects a test set, trains on the remaining sequences, classifies the test
     * sequences and writes the report to {@code out_file}.
     *
     * @param tax_file            taxonomy file handed to {@link TreeFactory}
     * @param source_file         sequence file; read once for training and once for testing
     * @param out_file            report destination; opened (and therefore truncated) before any work starts
     * @param rdmSelectedRank     when {@code null} the test IDs come from
     *                            {@code ResampleSeqFile.randomSelectSeq}, otherwise from
     *                            {@code RdmSelectTaxon.randomSelectTaxon} with this rank
     * @param fraction            passed to the selector; presumably the share of
     *                            sequences/taxa to hold out (TODO confirm against the selectors)
     * @param partialLength       when non-null, test sequences are classified via
     *                            {@code getPartialSeqIteratorbyGoodBases(partialLength)};
     *                            when {@code null} the full sequence string is used
     * @param useSeed             forwarded unchanged to {@code DecisionMaker.getBestClasspath}
     * @param min_bootstrap_words forwarded unchanged to {@code DecisionMaker.getBestClasspath}
     *                            and echoed in the report
     * @return a list indexed by bootstrap cutoff (0..{@value #MAX_BOOTSTRAP}); each element
     *         maps a rank name to its {@link StatusCount}
     * @throws IOException              if reading the inputs or writing the report fails
     * @throws IllegalArgumentException if the trained tree has no node at the factory's lowest rank
     */
    @SuppressWarnings("rawtypes")
    public ArrayList<HashMap> runTest(File tax_file, File source_file, File out_file, String rdmSelectedRank, float fraction, Integer partialLength, boolean useSeed, int min_bootstrap_words) throws IOException {
        BufferedWriter outWriter = new BufferedWriter(new FileWriter(out_file));
        // try/finally so the output handle is released on every exit path,
        // including the IllegalArgumentException thrown below.
        try {
            Set<String> selectedTestSeqIDs = rdmSelectedRank == null
                    ? ResampleSeqFile.randomSelectSeq(source_file, fraction)
                    : RdmSelectTaxon.randomSelectTaxon(tax_file, source_file, fraction, rdmSelectedRank);

            TreeFactory factory = this.setup(tax_file, source_file, selectedTestSeqIDs);
            DecisionMaker dm = new DecisionMaker(factory);

            HashMap<String, HierarchyTree> genusNodeMap = new HashMap<String, HierarchyTree>();
            factory.getRoot().getNodeMap(factory.getLowestRank(), genusNodeMap);
            if (genusNodeMap.isEmpty()) {
                throw new IllegalArgumentException("\nThere is no node in GENUS level!");
            }

            HashMap<String, HashSet<String>> rankNodeMap = this.buildRankNodeMap(factory);
            ArrayList<HashMap> statusCountList = this.newStatusCountList(factory);

            int totalTest = 0;
            int totalSeq = 0;
            LineageSequenceParser parser = new LineageSequenceParser(source_file);
            try {
                while (parser.hasNext()) {
                    ++totalSeq;
                    LineageSequence pSeq = parser.next();
                    // Only held-out, non-empty sequences are classified.
                    if (!selectedTestSeqIDs.contains(pSeq.getSeqName()) || pSeq.getSeqString().length() == 0) {
                        continue;
                    }
                    GoodWordIterator wordIterator = partialLength != null
                            ? pSeq.getPartialSeqIteratorbyGoodBases(partialLength)
                            : new GoodWordIterator(pSeq.getSeqString());
                    if (wordIterator == null || wordIterator.getNumofWords() == 0) {
                        continue;
                    }
                    List<ValidationClassificationResult> result = dm.getBestClasspath(wordIterator, genusNodeMap, useSeed, min_bootstrap_words);
                    ValidClassificationResultFacade resultFacade = new ValidClassificationResultFacade(pSeq, result);
                    this.compareClassificationResult(factory, resultFacade, rankNodeMap, statusCountList);
                    ++totalTest;
                }
            } finally {
                parser.close();
            }

            this.writeReportHeader(outWriter, tax_file, source_file, partialLength, rdmSelectedRank,
                    min_bootstrap_words, totalSeq - selectedTestSeqIDs.size(), totalTest);
            outWriter.write(this.calErrorRate(statusCountList));
            return statusCountList;
        } finally {
            outWriter.close();
        }
    }

    /** Collects, for every rank known to the factory, the names of the tree nodes at that rank. */
    private HashMap<String, HashSet<String>> buildRankNodeMap(TreeFactory factory) {
        HashMap<String, HashSet<String>> rankNodeMap = new HashMap<String, HashSet<String>>();
        for (String rank : factory.getRankSet()) {
            ArrayList<HierarchyTree> nodeList = new ArrayList<HierarchyTree>();
            factory.getRoot().getNodeList(rank, nodeList);
            rankNodeMap.put(rank, this.getnodeNameSet(nodeList));
        }
        return rankNodeMap;
    }

    /** Creates one rank-to-StatusCount table (all counts fresh) for each bootstrap cutoff 0..MAX_BOOTSTRAP. */
    @SuppressWarnings("rawtypes")
    private ArrayList<HashMap> newStatusCountList(TreeFactory factory) {
        ArrayList<HashMap> statusCountList = new ArrayList<HashMap>(MAX_BOOTSTRAP + 1);
        for (int b = 0; b <= MAX_BOOTSTRAP; ++b) {
            HashMap<String, StatusCount> statusCountMap = new HashMap<String, StatusCount>();
            for (String rank : factory.getRankSet()) {
                statusCountMap.put(rank, new StatusCount());
            }
            statusCountList.add(statusCountMap);
        }
        return statusCountList;
    }

    /** Writes the run-parameter lines that precede the error-rate table in the report. */
    private void writeReportHeader(BufferedWriter outWriter, File tax_file, File source_file, Integer partialLength, String rdmSelectedRank, int min_bootstrap_words, int trainingSetSize, int testSetSize) throws IOException {
        outWriter.write("taxon file\t" + tax_file.getName() + "\ntrain sequence file\t" + source_file.getName() + "\n");
        outWriter.write("word size\t" + GoodWordIterator.getWordsize() + "\n");
        outWriter.write("minimum number of words for bootstrap\t" + min_bootstrap_words + "\n");
        if (partialLength != null) {
            outWriter.write("length\t" + partialLength + "\n");
        } else {
            outWriter.write("length\tfull\n");
        }
        if (rdmSelectedRank == null) {
            outWriter.write("selectedRank\tNA\n");
        } else {
            outWriter.write("selectedRank\t" + rdmSelectedRank + "\n");
        }
        outWriter.write("trainingset size\t" + trainingSetSize + "\n");
        outWriter.write("testset size\t" + testSetSize + "\n");
    }

    /**
     * Builds the training tree from every sequence in {@code source_file} whose
     * name is NOT in {@code selectedTestSeqIDs}, then computes the word priors.
     *
     * @throws IOException if the taxonomy or sequence file cannot be read
     */
    private TreeFactory setup(File tax_file, File source_file, Set<String> selectedTestSeqIDs) throws IOException {
        FileReader taxReader = new FileReader(tax_file);
        TreeFactory factory;
        try {
            factory = new TreeFactory(taxReader);
        } finally {
            // NOTE(review): assumes TreeFactory fully consumes the reader inside its
            // constructor; confirm. Closing an already-closed Reader is a no-op.
            taxReader.close();
        }
        LineageSequenceParser parser = new LineageSequenceParser(source_file);
        try {
            while (parser.hasNext()) {
                LineageSequence pSeq = parser.next();
                if (selectedTestSeqIDs.contains(pSeq.getSeqName())) {
                    continue; // held out for testing
                }
                factory.addSequence(pSeq);
            }
        } finally {
            parser.close();
        }
        factory.calculateWordPrior();
        return factory;
    }

    /** Returns the set of names of the given tree nodes. */
    private HashSet<String> getnodeNameSet(ArrayList<HierarchyTree> genusNodeList) {
        HashSet<String> nodeNameSet = new HashSet<String>();
        for (HierarchyTree t : genusNodeList) {
            nodeNameSet.add(t.getName());
        }
        return nodeNameSet;
    }

    /**
     * Views the count table stored for one bootstrap cutoff with its real element types.
     * The public API exposes raw {@code HashMap}s; every table in the list is created
     * by this class as {@code HashMap<String, StatusCount>}, so the cast is safe for
     * lists produced by {@link #runTest}.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static HashMap<String, StatusCount> countsAt(ArrayList<HashMap> statusCountList, int bootstrap) {
        return (HashMap<String, StatusCount>) statusCountList.get(bootstrap);
    }

    /**
     * Updates the per-cutoff tallies for one classified test sequence.
     *
     * <p>The labeled taxon for each rank is looked up by walking the sequence's
     * ancestor list from index 1, resolving each entry through the factory with
     * the previous taxon's id as parent. For every rank assignment in the result
     * whose rank has both a labeled taxon and a known node-name set:
     * <ul>
     *   <li>if the labeled taxon's name IS among the training-tree node names for
     *       that rank, cutoffs {@code <= bootstrap} count a TP and higher cutoffs a FN;</li>
     *   <li>otherwise cutoffs {@code <= bootstrap} count a FP and higher cutoffs a TN.</li>
     * </ul>
     * {@code bootstrap} is the vote fraction times 100, truncated toward zero. Values
     * outside the table range are effectively clamped instead of raising
     * {@code IndexOutOfBoundsException}.
     */
    @SuppressWarnings("rawtypes")
    private void compareClassificationResult(TreeFactory factory, ValidClassificationResultFacade resultFacade, HashMap<String, HashSet<String>> rankNodeMap, ArrayList<HashMap> statusCountList) throws IOException {
        HashMap<String, Taxonomy> labeledTaxonMap = new HashMap<String, Taxonomy>();
        Taxonomy rootTaxon = factory.getRoot().getTaxonomy();
        labeledTaxonMap.put(rootTaxon.getHierLevel(), rootTaxon);
        int pid = rootTaxon.getTaxID();
        for (int i = 1; i < resultFacade.getAncestors().size(); ++i) {
            Taxonomy tax = factory.getTaxonomy(resultFacade.getSeqName(), resultFacade.getAncestors().get(i), pid, i);
            labeledTaxonMap.put(tax.getHierLevel(), tax);
            pid = tax.getTaxID();
        }

        List<ValidationClassificationResult> hitList = resultFacade.getRankAssignment();
        for (ValidationClassificationResult curRankResult : hitList) {
            String curRank = curRankResult.getBestClass().getTaxonomy().getHierLevel();
            Taxonomy matchingRankTaxon = labeledTaxonMap.get(curRank);
            if (matchingRankTaxon == null) {
                continue;
            }
            HashSet<String> nodeNameSet = rankNodeMap.get(curRank);
            if (nodeNameSet == null) {
                continue;
            }
            int bootstrap = (int) (curRankResult.getNumOfVotes() * 100.0f);
            boolean labeledTaxonInTraining = nodeNameSet.contains(matchingRankTaxon.getName());
            for (int b = 0; b < statusCountList.size(); ++b) {
                StatusCount count = countsAt(statusCountList, b).get(curRank);
                boolean passesCutoff = b <= bootstrap;
                if (labeledTaxonInTraining) {
                    if (passesCutoff) {
                        count.incNumTP(1);
                    } else {
                        count.incNumFN(1);
                    }
                } else if (passesCutoff) {
                    count.incNumFP(1);
                } else {
                    count.incNumTN(1);
                }
            }
        }
    }

    /**
     * Renders the tallies as a tab-separated table: one header row, then one row per
     * bootstrap cutoff holding {@code 1 - specificity} and sensitivity for every rank
     * whose name does not start with {@code "sub"}.
     *
     * <p>The rank columns are taken from the first table and reused for every row so
     * the values always line up with the header; a row lacking an entry for a header
     * rank gets {@code NA} in both cells. An empty list yields the header lines only.
     *
     * @param statusCountList tables indexed by bootstrap cutoff, each mapping a rank
     *                        name to its {@link StatusCount}
     * @return the formatted table, beginning with a blank line
     */
    @SuppressWarnings("rawtypes")
    public String calErrorRate(ArrayList<HashMap> statusCountList) {
        StringBuilder ret = new StringBuilder();
        ret.append("\nbootstrap\t1-Specificity\tSensitivity\n");
        ret.append("bootstrap");

        List<String> reportedRanks = new ArrayList<String>();
        if (!statusCountList.isEmpty()) {
            for (String rank : countsAt(statusCountList, 0).keySet()) {
                if (rank.startsWith("sub")) {
                    continue;
                }
                reportedRanks.add(rank);
                ret.append("\t").append(rank).append("_Spec\t").append(rank).append("_Sens");
            }
        }
        ret.append("\n");

        for (int b = 0; b < statusCountList.size(); ++b) {
            ret.append(b);
            HashMap<String, StatusCount> rowCounts = countsAt(statusCountList, b);
            for (String rank : reportedRanks) {
                StatusCount st = rowCounts.get(rank);
                if (st == null) {
                    ret.append("\tNA\tNA");
                    continue;
                }
                double se = st.calSensitivity();
                double sp = st.calSpecificity();
                ret.append("\t").append(1.0 - sp).append("\t").append(se);
            }
            ret.append("\n");
        }
        return ret.toString();
    }
}

