Code Overview
Source Code
GitHub Link: https://github.com/flowjo-lakes/Deep-Learning-Plugin
Referenced Libraries
-FlowJo Plugin API https://www.gitbook.com/book/flowjollc/flowjo-plugin-developers-guide
-BatchEffectRemoval repository https://github.com/ushaham/BatchEffectRemoval
Java Methods
getScriptFile(File absolutePath)
Locates the Python script packaged in the deepLearning.jar file and copies it onto the local file system, where it can be executed.
/**
 * Returns the training script as a file on disk, extracting it on first use.
 *
 * The script is looked up as the class-loader resource
 * "python/train_MMD_ResNet.py" and copied to "train_MMD_ResNet.py" inside
 * the given folder. The resulting file is cached in gScriptFile, so later
 * calls return the cached file without copying again.
 *
 * @param absolutePath folder the script is copied into
 * @return the extracted script file, or null when the resource is missing
 *         or the copy failed
 */
private File getScriptFile(File absolutePath) {
    if (gScriptFile != null)
    {
        return gScriptFile;
    }
    InputStream findScriptPath = this.getClass().getClassLoader().getResourceAsStream("python/train_MMD_ResNet.py");
    if (findScriptPath == null)
    {
        // Previously this case was silent; report it like any other failure.
        System.out.println("Script not found");
        return null;
    }
    try
    {
        File scriptFile = new File(absolutePath, "train_MMD_ResNet.py");
        FileUtil.copyStreamToFile(findScriptPath, scriptFile);
        gScriptFile = scriptFile;
        // Only report success once the copy has actually completed.
        System.out.println("Script found");
    }
    catch (Exception exception)
    {
        System.out.println("Script not found");
        exception.printStackTrace();
    }
    finally
    {
        try
        {
            // Closing an already-closed stream is harmless, so this is safe
            // regardless of whether the copy helper closes it.
            findScriptPath.close();
        }
        catch (IOException ignored)
        {
            // Nothing useful can be done if closing the resource stream fails.
        }
    }
    return gScriptFile;
}
executePython(File outputFolder)
Executes the python script with the parameters given from the GUI. Pipes the output of the python script to the debug console.
/**
 * Runs the extracted Python training script and echoes its output to the
 * console.
 *
 * The script receives, in order: numEpochs, sourcePath, targetPath, the
 * output folder and resultName. Each value is passed as its own argument,
 * so values containing spaces are not split apart.
 *
 * @param outputFolder folder the script is extracted into and which is
 *                     passed to the script as its output location
 * @return true when the script ran and exited with status 0; false when the
 *         script could not be located, could not be started, exited with a
 *         non-zero status, or the wait was interrupted
 */
public boolean executePython(File outputFolder)
{
    try
    {
        File myPythonFile = getScriptFile(outputFolder);
        System.out.println("Trying to execute python script....\n");
        if (myPythonFile == null)
        {
            return false;
        }

        // One argument per value: Runtime.exec(String) tokenises on
        // whitespace and would break any path that contains a space.
        ProcessBuilder builder = new ProcessBuilder(
            "python",
            myPythonFile.getAbsolutePath(),
            String.valueOf(numEpochs),
            String.valueOf(sourcePath),
            String.valueOf(targetPath),
            outputFolder.getPath(),
            String.valueOf(resultName));
        // Merge stderr into stdout so a single reader drains both; an
        // unread stderr pipe can fill up and stall the child process.
        builder.redirectErrorStream(true);

        Process proc = builder.start();
        System.out.println("Working.....");

        //deliver the output from the python file
        BufferedReader br = new BufferedReader(new InputStreamReader(proc.getInputStream()));
        try
        {
            String line;
            while ((line = br.readLine()) != null)
            {
                System.out.println(line);
            }
        }
        finally
        {
            br.close();
        }

        //wait for the process to finish up
        int exitCode = proc.waitFor();
        if (exitCode != 0)
        {
            System.out.println("Python script exited with code " + exitCode);
            return false;
        }
        System.out.println("Execution successful!\n");
        return true;
    }
    catch (InterruptedException e)
    {
        // Restore the interrupt flag so callers can still observe it.
        Thread.currentThread().interrupt();
        e.printStackTrace();
        return false;
    }
    catch (IOException e)
    {
        e.printStackTrace();
        return false;
    }
}
promptForOptions(SElement fcmlQueryElement, List<String> parameterNames)
SeqGeq interface method. This method is called when a plugin is initialized on a population and is used to display a GUI that allows the user to select genes and the number of epochs that the deep learning algorithm uses to iterate over data. This method only runs once because the input from the user is used on the second .csv file as well.
/**
 * Shows the options dialog (epoch count and parameter selection) while the
 * plugin is still in its initial state.
 *
 * @param fcmlQueryElement element used to obtain the parameter set manager
 * @param parameterNames   names pre-selected in the selection panel; the
 *                         list is cleared by this method
 * @return true when the plugin is no longer in the empty state or the user
 *         confirmed the dialog; false when no parameter set manager is
 *         available or the dialog was not confirmed
 */
@Override
public boolean promptForOptions(SElement fcmlQueryElement, List<String> parameterNames)
{
    // Any state other than "empty" means the options were already collected.
    if (state != pluginState.empty)
    {
        return true;
    }
    ParameterSetMgrInterface mgr = PluginHelper.getParameterSetMgr(fcmlQueryElement);
    if (mgr == null)
    {
        return false;
    }

    List<Object> dialogContent = new ArrayList<Object>();

    // First entry is a label that never receives any text.
    dialogContent.add(new FJLabel());

    // Second entry carries the instructions shown to the user.
    FJLabel instructions = new FJLabel();
    dialogContent.add(instructions);
    instructions.setText("<html><body>"
        + "Enter the number of Epochs and select the <br>genes you want included in the resultant csv file"
        + "</body></html>");

    // Epoch entry: label and ranged field share the same tooltip.
    String epochTip = "A higher number of Epochs will result in more accurately trained data but takes longer.";
    FJLabel epochLabel = new FJLabel("Number of Epochs (5 - 600) ");
    epochLabel.setToolTipText(epochTip);
    RangedIntegerTextField epochField = new RangedIntegerTextField(5, 600);
    epochField.setInt(numEpochs);
    epochField.setToolTipText(epochTip);
    GuiFactory.setSizes(epochField, new Dimension(50, 25));
    Box epochRow = SwingUtil.hbox(Box.createHorizontalGlue(), epochLabel, epochField, Box.createHorizontalGlue());
    dialogContent.add(epochRow);

    // Parameter selection panel pinned to a single size.
    ParameterSelectionPanel selector = new ParameterSelectionPanel(mgr,
        eParameterSelectionMode.WithSetsAndParameters, true, false, false, true);
    Dimension fixedSize = new Dimension(300, 500);
    selector.setMaximumSize(fixedSize);
    selector.setMinimumSize(fixedSize);
    selector.setPreferredSize(fixedSize);
    selector.setSelectedParameters(parameterNames);
    parameterNames.clear();
    dialogContent.add(selector);

    int choice = JOptionPane.showConfirmDialog(null, dialogContent.toArray(), "Deep Learning Plugin",
        JOptionPane.OK_CANCEL_OPTION, JOptionPane.PLAIN_MESSAGE, null);
    if (choice != JOptionPane.OK_OPTION)
    {
        return false;
    }

    // Confirmed: record the selection, always including 'CellId'.
    fParameters.addAll(selector.getParameterSelection());
    if (!fParameters.contains("CellId"))
    {
        fParameters.add("CellId");
    }
    numEpochs = epochField.getInt();
    return true;
}
setElement(SElement element)
SeqGeq interface method. This method is called automatically by SeqGeq and is used to retrieve stored XML data across plugin invocations.
/**
 * Restores the plugin's fields from a previously stored element.
 *
 * Nothing is restored when the element has no "Parameters" child.
 * Otherwise the parameter names, numEpochs and state are read back, and
 * sourcePath and resultName are read as well when the state is "learned".
 *
 * @param element element holding the stored plugin data
 */
@Override
public void setElement(SElement element)
{
    SElement storedParameters = element.getChild("Parameters");
    if (storedParameters != null)
    {
        // Replace the current selection with the stored parameter names.
        fParameters.clear();
        for (SElement child : storedParameters.getChildren())
        {
            fParameters.add(child.getString("name"));
        }

        numEpochs = element.getInt("numEpochs");
        state = pluginState.valueOf(element.getString("state"));

        // These two attributes are only read back in the "learned" state.
        if (state == pluginState.learned)
        {
            sourcePath = element.getString("sourcePath");
            resultName = element.getString("resultName");
        }
    }
}
getElement()
SeqGeq interface method. This method is called automatically by SeqGeq and is used to save XML data across plugin invocations to be later retrieved by setElement().
@Override
public SElement getElement()
{
SElement result = new SElement(getName());
// store the parameters the user selected
if (!fParameters.isEmpty()) {
SElement elem = new SElement("Parameters");
result.addContent(elem);
for (String pName : fParameters) {
SElement e = new SElement("P");
e.setString("name", pName);
elem.addContent(e);
}
}
result.setInt("numEpochs", numEpochs);
result.setString("state", state.toString());
if (state == pluginState.learned)
{
result.setString("resultName", resultName.toString());
result.setString("sourcePath", sourcePath.toString());
}
return result;
invokeAlgorithm(SElement fcmlElem, File sampleFile, File outputFolder)
SeqGeq interface method. This method is called automatically when the plugin is invoked. On the first invocation, the method is used to locally store the parameters the user selected in the GUI. On the second invocation, the executePython() method is invoked, which executes the deep learning algorithm. If the executePython() method was successful, the resultant .csv file is imported into the user's current workspace.
/**
 * Advances the plugin through its two-step workflow.
 *
 * In the "empty" state the sample file is recorded as the source and the
 * result name is derived from its file name; the state becomes "learned".
 * In the "learned" state the sample file is recorded as the target, the
 * state becomes "ready", the Python script is executed and, when that
 * succeeds, the resulting CSV file is loaded into the workspace. In any
 * other state nothing is done.
 *
 * @param fcmlElem     element passed on when loading the result file
 * @param sampleFile   sample file for this invocation
 * @param outputFolder folder used for the script and its output
 * @return a newly created, otherwise untouched results object
 */
@Override
public ExternalAlgorithmResults invokeAlgorithm(SElement fcmlElem, File sampleFile, File outputFolder) {
    SeqGeqExternalAlgorithmResults results = new SeqGeqExternalAlgorithmResults();

    // First call: remember the source sample and derive the result name.
    if (state == pluginState.empty)
    {
        sourcePath = sampleFile.getAbsolutePath();
        resultName = sampleFile.getName()
            .replace(' ', '_')
            .replace(".csv..ExtNode.csv", "");
        state = pluginState.learned;
        return results;
    }

    // Second call: remember the target sample and run the script.
    if (state == pluginState.learned)
    {
        targetPath = sampleFile.getAbsolutePath();
        // The ready state prevents users from continuing to use the plugin
        // outside of its intended use.
        state = pluginState.ready;
        outputPath = outputFolder.getAbsolutePath();
        if (executePython(outputFolder))
        {
            // Import the resultant CSV file into the current workspace.
            String importedCsv = outputFolder + "/" + resultName + "_" + numEpochs + "E_" + "DL.csv";
            FJPluginHelper.loadSamplesIntoWorkspace(fcmlElem, new String[]{ importedCsv });
        }
    }
    return results;
}
Python Functions
The original python deep learning implementation file can be found in the BatchEffectRemoval repository, which is linked at the top of this chapter in the "Referenced Libraries". The name of the python file is "train_MMD_ResNet.py". A few changes were necessary to adapt the original file to the needs of the plugin. The changes are as follows:
- Removed hard coded values and replaced with command line argument variables
- Removed the option to denoise .csv files
- Removed the line that saves a .h5 weights file
- Added a line that saves the calibrated input source file to the file system
- Wrote formatCSV() function (described below)
formatCSV( path )
Formats the input CSV file by stripping the column headers and row identifiers so that the deep learning script may recognize the file. The column headers and row identifiers are added back at the end of the script.
####################
# CSV reformatting #
####################
def formatCSV( path ):
    """Strip the header row and first column from a CSV file.

    The stripped data is written to "Result.csv" in the same directory as
    the input file. The removed header row is stored in the module-level
    ``columnHeaders`` and the removed first-column values in the
    module-level ``rowIdentifiers`` (a list of one-element lists), so they
    can be added back later.

    Args:
        path: path of the CSV file to read.

    Returns:
        Path of the written "Result.csv" file.

    Raises:
        ValueError: if the input file contains no header row.
    """
    global columnHeaders
    global rowIdentifiers

    # Derive the output location from the input file's directory. Appending
    # "/../" to a file path uses a regular file as a directory component,
    # which fails with "Not a directory" on POSIX systems.
    result = os.path.join(os.path.dirname(path), "Result.csv")

    # newline='' lets the csv module handle line endings itself.
    with open(path, 'r', newline='') as fp_in:
        reader = csv.reader(fp_in, delimiter=',')

        #get the column headers
        headers = next(reader, None)
        if headers is None:
            raise ValueError("CSV file has no header row: " + str(path))
        columnHeaders = headers
        rowIdentifiers = []

        #split every data row into its identifier and the remaining values
        with open(result, 'w', encoding="UTF-8", newline='') as fp_out:
            writer = csv.writer(fp_out, lineterminator='\n')
            for row in reader:
                if not row:
                    # blank line: nothing to split
                    continue
                rowIdentifiers.append([row[0]])
                writer.writerow(row[1:])

    # At this point the file is ready to be processed by the deep learning algorithm
    return result
#
#the rest of the script
#
######################################
# Reformat csv back to SeqGeq format #
######################################
#read back the resultant data
with open(resultantFile, 'r') as f:
    bareData = f.readlines()
#add the headers and row identifiers back to the data
with open(resultantFile, 'w') as f:
    wr = csv.writer(f, lineterminator='\n', delimiter=',')
    wr.writerow(columnHeaders)
    reader = csv.reader(bareData)
    # each entry of rowIdentifiers is a list, so it can be concatenated
    # directly with the matching parsed data row
    for identifier in rowIdentifiers:
        wr.writerow(identifier + next(reader))
# Copy the file to its final location, removing every "e+00" and then every
# remaining "+" from the whole file text (header row included).
with open(resultantFile, 'r') as infile, open(realResultPath, 'w') as outfile:
    data = infile.read()
    data = data.replace("e+00", "")
    data = data.replace("+", "")
    outfile.write(data)
#clean up extraneous files
os.remove(sourcePath)
os.remove(resultantFile)