Code Overview


Source Code

GitHub Link: https://github.com/flowjo-lakes/Deep-Learning-Plugin


Referenced Libraries

-FlowJo Plugin API https://www.gitbook.com/book/flowjollc/flowjo-plugin-developers-guide

-BatchEffectRemoval repository https://github.com/ushaham/BatchEffectRemoval


Java Methods

getScriptFile(File absolutePath)

Locates the Python script packaged in the deepLearning.jar file and copies it into the given folder on the local file system, where it can be executed.

 /**
  * Extracts the bundled {@code python/train_MMD_ResNet.py} resource into {@code absolutePath}
  * so the Python interpreter can run it from the local file system.
  *
  * <p>The extracted file is cached in {@code gScriptFile}: once a copy succeeds, later calls
  * return the cached file without copying again, even if a different folder is passed.
  *
  * @param absolutePath folder to copy the script into
  * @return the extracted script file, or {@code null} if the resource is missing from the
  *     class path or could not be copied
  */
 private File getScriptFile(File absolutePath) {
        if (gScriptFile != null) {
            return gScriptFile;
        }
        // try-with-resources closes the resource stream (a null resource is simply skipped).
        try (InputStream scriptStream = this.getClass().getClassLoader()
                .getResourceAsStream("python/train_MMD_ResNet.py")) {
            if (scriptStream == null) {
                System.out.println("Script not found");
                return null;
            }
            File scriptFile = new File(absolutePath, "train_MMD_ResNet.py");
            FileUtil.copyStreamToFile(scriptStream, scriptFile);
            gScriptFile = scriptFile;
            // Only report success once the copy has actually completed.
            System.out.println("Script found");
        } catch (Exception exception) {
            // FileUtil's exception types are not visible here, so catch broadly but keep the cause.
            System.out.println("Could not copy script to " + absolutePath + ": " + exception);
        }
        return gScriptFile;
    }
executePython(File outputFolder)

Executes the Python script with the parameters entered in the GUI and pipes the script's output to the debug console.

    /**
     * Runs {@code train_MMD_ResNet.py} with the epoch count and file paths gathered from the
     * GUI, echoing everything the script prints (stdout and stderr) to {@code System.out}.
     *
     * <p>Arguments are passed to the script in this order: number of epochs, source path,
     * target path, output folder, result name.
     *
     * @param outputFolder folder the script is extracted into; also passed to the script as
     *     its output folder argument
     * @return {@code true} only if the script was found, ran to completion and exited with
     *     status 0; {@code false} otherwise
     */
    public boolean executePython(File outputFolder)
    {
        File myPythonFile = getScriptFile(outputFolder);
        System.out.println("Trying to execute python script....\n");
        if (myPythonFile == null)
        {
            return false;
        }

        // Pass each argument separately: Runtime.exec(String) splits its command on whitespace,
        // which breaks any path that contains a space.
        ProcessBuilder builder = new ProcessBuilder(
                "python",
                myPythonFile.getAbsolutePath(),
                String.valueOf(numEpochs),
                String.valueOf(sourcePath),
                String.valueOf(targetPath),
                String.valueOf(outputFolder),
                String.valueOf(resultName));
        // Merge stderr into stdout so error output is shown too. Merging also means the script
        // cannot block on a full stderr pipe that nobody reads.
        builder.redirectErrorStream(true);
        try
        {
            Process proc = builder.start();
            System.out.println("Working.....");
            // The script is not fed any input; close its stdin right away.
            proc.getOutputStream().close();

            // Relay the script's output line by line until it closes the stream.
            try (BufferedReader reader =
                    new BufferedReader(new InputStreamReader(proc.getInputStream())))
            {
                String line;
                while ((line = reader.readLine()) != null)
                {
                    System.out.println(line);
                }
            }

            int exitCode = proc.waitFor();
            if (exitCode != 0)
            {
                System.out.println("Python script failed with exit code " + exitCode);
                return false;
            }
            System.out.println("Execution successful!\n");
            return true;
        }
        catch (InterruptedException e)
        {
            // Restore the interrupt flag so callers can still observe the interruption.
            Thread.currentThread().interrupt();
            e.printStackTrace();
            return false;
        }
        catch (IOException e)
        {
            e.printStackTrace();
            return false;
        }
    }
promptForOptions(SElement fcmlQueryElement, List<String> parameterNames)

SeqGeq interface method. This method is called when the plugin is initialized on a population. It displays a GUI in which the user selects genes and the number of epochs the deep learning algorithm uses to iterate over the data. The dialog is shown only once, because the user's input is reused for the second .csv file.

/**
 * Shows the plugin's options dialog. The dialog has an epoch-count field (5 - 600) and a
 * gene/parameter selection panel.
 *
 * <p>The dialog is only shown while the plugin is in the {@code empty} state. In any other
 * state this returns {@code true} immediately, so the choices made the first time are reused.
 *
 * @param fcmlQueryElement element used to look up the parameter-set manager
 * @param parameterNames   parameters to pre-select in the panel; the caller's list is cleared
 *                         once the pre-selection has been applied
 * @return {@code true} if the dialog was skipped or the user pressed OK; {@code false} if no
 *         parameter-set manager is available or the user cancelled
 */
@Override
    public boolean promptForOptions(SElement fcmlQueryElement, List<String> parameterNames)
    {
        // Only the first (initializing) invocation shows the dialog.
        if (state != pluginState.empty) {
            return true;
        }
        ParameterSetMgrInterface setManager = PluginHelper.getParameterSetMgr(fcmlQueryElement);
        if (setManager == null) {
            return false;
        }

        List<Object> dialogItems = new ArrayList<>();
        // An empty label goes first, above the instructions.
        dialogItems.add(new FJLabel());

        FJLabel instructions = new FJLabel();
        instructions.setText("<html><body>"
                + "Enter the number of Epochs and select the <br>genes you want included in the resultant csv file"
                + "</body></html>");
        dialogItems.add(instructions);

        // Epoch entry row: label and ranged field share the same tooltip.
        String epochTip = "A higher number of Epochs will result in more accurately trained data but takes longer.";
        FJLabel epochLabel = new FJLabel("Number of Epochs (5 - 600) ");
        epochLabel.setToolTipText(epochTip);
        RangedIntegerTextField epochField = new RangedIntegerTextField(5, 600);
        epochField.setInt(numEpochs);
        epochField.setToolTipText(epochTip);
        GuiFactory.setSizes(epochField, new Dimension(50, 25));
        dialogItems.add(SwingUtil.hbox(Box.createHorizontalGlue(), epochLabel, epochField,
                Box.createHorizontalGlue()));

        // Parameter selection panel, fixed at 300 x 500.
        ParameterSelectionPanel selectionPanel = new ParameterSelectionPanel(setManager,
                eParameterSelectionMode.WithSetsAndParameters, true, false, false, true);
        Dimension panelSize = new Dimension(300, 500);
        selectionPanel.setMaximumSize(panelSize);
        selectionPanel.setMinimumSize(panelSize);
        selectionPanel.setPreferredSize(panelSize);
        selectionPanel.setSelectedParameters(parameterNames);
        parameterNames.clear();
        dialogItems.add(selectionPanel);

        int choice = JOptionPane.showConfirmDialog(null, dialogItems.toArray(), "Deep Learning Plugin",
                JOptionPane.OK_CANCEL_OPTION, JOptionPane.PLAIN_MESSAGE, null);
        if (choice != JOptionPane.OK_OPTION) {
            return false;
        }

        // OK pressed: remember the selection, always including the "CellId" column.
        fParameters.addAll(selectionPanel.getParameterSelection());
        if (!fParameters.contains("CellId")) {
            fParameters.add("CellId");
        }
        numEpochs = epochField.getInt();
        return true;
    }
setElement(SElement element)

SeqGeq interface method. This method is called automatically by SeqGeq and is used to retrieve stored XML data across plugin invocations.

/**
 * Restores state written by {@code getElement()}: the selected parameters, the epoch count,
 * the plugin state and, when that state is {@code learned}, the source path and result name.
 *
 * <p>If {@code element} has no {@code Parameters} child, nothing at all is restored. This
 * includes the epoch count and the state.
 *
 * @param element element previously produced by {@code getElement()}
 */
@Override
    public void setElement(SElement element)
    {
        SElement savedParameters = element.getChild("Parameters");
        if (savedParameters != null) {
            fParameters.clear();
            for (SElement savedParameter : savedParameters.getChildren()) {
                fParameters.add(savedParameter.getString("name"));
            }

            numEpochs = element.getInt("numEpochs");
            state = pluginState.valueOf(element.getString("state"));
            // Paths from the first invocation are only saved once the plugin has "learned".
            if (state == pluginState.learned) {
                sourcePath = element.getString("sourcePath");
                resultName = element.getString("resultName");
            }
        }
    }
getElement()

SeqGeq interface method. This method is called automatically by SeqGeq and is used to save XML data across plugin invocations to be later retrieved by setElement().

@Override
    public SElement getElement() 
    {        
        SElement result = new SElement(getName());
        // store the parameters the user selected
        if (!fParameters.isEmpty()) {
            SElement elem = new SElement("Parameters");
            result.addContent(elem);
            for (String pName : fParameters) {
                SElement e = new SElement("P");
                e.setString("name", pName);
                elem.addContent(e);
            }
        }
        result.setInt("numEpochs", numEpochs);
        result.setString("state", state.toString());
        if (state == pluginState.learned)
        {
            result.setString("resultName", resultName.toString());
            result.setString("sourcePath", sourcePath.toString());
        }
        return result;
invokeAlgorithm(SElement fcmlElem, File sampleFile, File outputFolder)

SeqGeq interface method. This method is called automatically when the plugin is invoked. On the first invocation, the method locally stores the path of the selected (source) sample and derives the name of the result file from it. On the second invocation, the executePython() method is called, which runs the deep learning algorithm. If executePython() succeeds, the resultant .csv file is imported into the user's current workspace.

/**
 * Entry point that SeqGeq calls each time the plugin runs on a sample.
 *
 * <p>First call (state {@code empty}): the sample is recorded as the source data set, and the
 * result name is derived from its file name. The state then moves to {@code learned}.
 *
 * <p>Second call (state {@code learned}): the sample is recorded as the target data set, and
 * the state moves to {@code ready}. The Python script is then run; if it succeeds, the
 * calibrated CSV is loaded into the workspace.
 *
 * <p>In the {@code ready} state nothing happens.
 *
 * @param fcmlElem     element passed through to {@code FJPluginHelper.loadSamplesIntoWorkspace}
 * @param sampleFile   sample file handed over by SeqGeq (its name is expected to end in
 *                     {@code .csv..ExtNode.csv}, which is stripped to form the result name)
 * @param outputFolder folder the script is extracted into and its result is read from
 * @return a freshly created results object; this method never populates it
 */
@Override
    public ExternalAlgorithmResults invokeAlgorithm(SElement fcmlElem, File sampleFile, File outputFolder) {
        SeqGeqExternalAlgorithmResults results = new SeqGeqExternalAlgorithmResults();
        if (state == pluginState.empty)
        {
            // initial plugin call: this sample is the source data set
            sourcePath = sampleFile.getAbsolutePath();
            String fileName = sampleFile.getName().replace(' ', '_');
            resultName = fileName.replace(".csv..ExtNode.csv", "");
            state = pluginState.learned;
        }
        else if (state == pluginState.learned)
        {
            // second call: this sample is the target data set
            targetPath = sampleFile.getAbsolutePath();
            // the ready state prevents users from continuing to use the plugin outside
            // of its intended use
            state = pluginState.ready;
            outputPath = outputFolder.getAbsolutePath();

            if (executePython(outputFolder)) {
                // NOTE(review): presumably must match the file name the Python script writes
                // (realResultPath there); keep the two in sync.
                File resultFile = new File(outputFolder, resultName + "_" + numEpochs + "E_DL.csv");
                // Import the resultant CSV file into the current workspace
                FJPluginHelper.loadSamplesIntoWorkspace(fcmlElem, new String[] {resultFile.getPath()});
            } else {
                System.out.println("Deep learning script did not succeed; no result was loaded.");
            }
        }
        return results;
    }

Python Functions

The original Python deep learning implementation can be found in the BatchEffectRemoval repository, linked under "Referenced Libraries" at the top of this chapter. The file is named "train_MMD_ResNet.py". A few changes were necessary to adapt it to the needs of the plugin:

  • Replaced hard-coded values with command-line arguments
  • Removed the option to denoise .csv files
  • Removed the line that saves a .h5 weights file
  • Added a line that saves the calibrated input source file to the file system
  • Wrote formatCSV() function (described below)
formatCSV( path )

Formats the input CSV file by stripping the column headers and row identifiers so that the deep learning script may recognize the file. The column headers and row identifiers are added back at the end of the script.

####################
# CSV reformatting #
####################
def formatCSV( path ):
    """Strip the header row and the first column from a CSV so the network can read it.

    The header row is stored in the global ``columnHeaders``. The first cell of every data
    row is stored, as a one-element list, in the global ``rowIdentifiers``. Both are
    re-attached to the calibrated data at the end of the script.

    The stripped data is written to ``Result.csv`` in the same folder as ``path``.

    Args:
        path: path to the input CSV file.

    Returns:
        The path of the written ``Result.csv``.
    """
    global columnHeaders, rowIdentifiers

    # Build sibling paths from the input's directory. The old ``path + "/../Result.csv"``
    # treats the CSV file itself as a directory, which fails with NotADirectoryError on
    # macOS/Linux (it only worked where ".." is resolved lexically, e.g. Windows).
    result = os.path.join(os.path.dirname(path), "Result.csv")

    rowIdentifiers = []
    # One pass over the input; no temporary swap file is needed.
    with open(path, newline='') as fp_in, open(result, "w", encoding="UTF-8") as fp_out:
        reader = csv.reader(fp_in, delimiter=',')
        writer = csv.writer(fp_out, lineterminator='\n')
        columnHeaders = next(reader)
        for row in reader:
            rowIdentifiers.append([row[0]])
            writer.writerow(row[1:])

    # At this point the file is ready to be processed by the deep learning algorithm
    return result

    #
    #the rest of the script
    #

######################################
# Reformat csv back to SeqGeq format #
######################################

# Read back the calibrated data (no header row, no row identifiers).
with open(resultantFile, 'r', newline='') as f:
    dataRows = [row for row in csv.reader(f)]

# Every saved row identifier needs a data row; extra trailing data rows are ignored.
if len(dataRows) < len(rowIdentifiers):
    raise ValueError(
        "calibrated data has %d rows but %d row identifiers were saved"
        % (len(dataRows), len(rowIdentifiers)))

# Write headers, row identifiers and data straight to the final result file.
with open(realResultPath, 'w') as outfile:
    wr = csv.writer(outfile, lineterminator='\n', delimiter=',')
    wr.writerow(columnHeaders)
    for identifier, dataRow in zip(rowIdentifiers, dataRows):
        # Same numeric-text cleanup as before ("1.5e+00" -> "1.5", "2e+01" -> "2e01").
        # It is applied only to data cells, so headers and cell IDs containing "+"
        # (e.g. "CD4+") are no longer mangled.
        cleaned = [cell.replace("e+00", "").replace("+", "") for cell in dataRow]
        wr.writerow(identifier + cleaned)

#clean up extraneous files
os.remove(sourcePath)
os.remove(resultantFile)

results matching ""

    No results matching ""