08-12-2020 04:57 AM
After cloning this repo:
and initializing the schema on my Neo4j database, I wanted to create the tree with this command:
CALL com.maxdemarzi.decision_tree.create('credit', '/my/own/path/training.csv', '/my/own/path/answers.csv', 0.02)
Shortly afterwards I receive this error message:
Failed to invoke procedure ' com.maxdemarzi.decision_tree.create ': Caused by: java.lang.IllegalArgumentException: Must hint overloaded method: toArray
If I call the same procedure again, it throws this error:
Failed to invoke procedure "com.maxdemarzi.decision_tree.create": Caused by: java.lang.NoClassDefFoundError: Could not initialize class clojure.java.api.Clojure
This is the code which most likely causes the error:
package com.maxdemarzi;

import clojure.java.api.Clojure;
import clojure.lang.*;
import com.maxdemarzi.results.StringResult;
import com.maxdemarzi.schema.Labels;
import com.maxdemarzi.schema.RelationshipTypes;
import com.opencsv.CSVIterator;
import com.opencsv.CSVReader;
import org.jblas.DoubleMatrix;
import org.neo4j.graphdb.*;
import org.neo4j.helpers.collection.Pair;
import org.neo4j.logging.Log;
import org.neo4j.procedure.*;

import java.io.*;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class DecisionTreeCreator {

    // This field declares that we need a GraphDatabaseService
    // as context when any procedure in this class is invoked
    @Context
    public GraphDatabaseService db;

    // This gives us a log instance that outputs messages to the
    // standard log, normally found under `data/log/console.log`
    @Context
    public Log log;

    @Procedure(name = "com.maxdemarzi.decision_tree.create", mode = Mode.WRITE)
    @Description("CALL com.maxdemarzi.decision_tree.create(tree, data, answers, threshold) - create tree")
    public Stream<StringResult> create(@Name("tree") String tree, @Name("data") String data,
                                       @Name("answers") String answers, @Name("threshold") Double threshold ) {
        long start = System.nanoTime();

        CSVIterator trainingAnswers = getCsvIterator(answers);
        CSVIterator trainingData = getCsvIterator(data);

        Set<Double> answerSet = new HashSet<>();
        HashMap<Double, Node> answerMap = new HashMap<>();
        ArrayList<Double> answerList = new ArrayList<>();

        while (trainingAnswers.hasNext()) {
            String[] nextLine = trainingAnswers.next();
            for (String item : nextLine) {
                answerList.add(Double.parseDouble(item));
            }
        }

        answerSet.addAll(answerList);
        for (Double value : answerSet) {
            Node answerNode = db.createNode(Labels.Answer);
            answerNode.setProperty("id", value);
            answerMap.put(value, answerNode);
        }

        HashMap<String, Node> nodes = new HashMap<>();
        String[] headers = trainingData.next();
        for(int i = 0; i < headers.length; i++) {
            Node parameter = db.findNode(Labels.Parameter, "name", headers[i]);
            if (parameter == null) {
                parameter = db.createNode(Labels.Parameter);
                parameter.setProperty("name", headers[i]);
                parameter.setProperty("type", "double");
                parameter.setProperty("prompt", "What is " + headers[i] + "?");
            }
            nodes.put(headers[i], parameter);
        }

        double[][] array = new double[answerList.size()][1 + headers.length];
        for (int r = 0; r < answerList.size(); r++) {
            array[r][0] = answerList.get(r);
            String[] columns = trainingData.next();
            for (int c = 0; c < columns.length; c++) {
                array[r][1 + c] = Double.parseDouble(columns[c]);
            }
        }

        DoubleMatrix fullData = new DoubleMatrix(array);
        fullData = fullData.transpose();
        int featuresCount = fullData.rows;
        int[] rowIndices = IntStream.range(0, featuresCount).toArray();
        DoubleMatrix X = fullData.getRows(rowIndices);

        /* Import clojure core. */
        final IFn require = Clojure.var("clojure.core", "require");
        require.invoke(Clojure.read("DecisionStream"));

        /* Invoke Clojure trainDStream function. */
        final IFn trainFunc = Clojure.var("DecisionStream", "trainDStream");
        HashMap dStreamM = new HashMap<>((PersistentArrayMap) trainFunc.invoke(X, rowIndices, threshold));

        Node treeNode = db.createNode(Labels.Tree);
        treeNode.setProperty("id", tree);
        deepLinkMap(db, answerMap, nodes, headers, treeNode, RelationshipTypes.HAS, dStreamM, true);

        return Stream.of(new StringResult("Tree created in " + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - start) + " seconds"));
    }

    static void deepLinkMap(GraphDatabaseService db, HashMap<Double, Node> answerMap, HashMap<String, Node> nodes, String[] headers, Node parent, RelationshipType relType, HashMap nestedMap, boolean leftSide) {
        RelationshipType leftType = RelationshipTypes.IS_FALSE;
        RelationshipType rightType = RelationshipTypes.IS_TRUE;

        // We are at a Leaf
        if (nestedMap.size() == 2) {
            Keyword labelCounts = (Keyword)nestedMap.keySet().toArray()[0];
            Collection<PersistentVector> values = (Collection) nestedMap.get(labelCounts);
            for (PersistentVector vector : values) {
                Node answer = answerMap.get(vector.get(0));
                Double weight = (Double)vector.get(1);
                Relationship rel = parent.createRelationshipTo(answer, relType);
                rel.setProperty("weight", weight);
            }
        } else {
            Node rule;
            String key = getKey(headers, nestedMap);
            String[] keyParts = key.split("-");
            String feature = keyParts[0];
            String threshold = keyParts[1];
            String parentFeature = (String)parent.getProperty("parameter_names", "");

            // Only valid for IS_FALSE same feature children nodes
            if(feature.equals(parentFeature) && leftSide) {
                rule = parent;
                ArrayList<Pair<String, String>> options = new ArrayList<>();
                String[] values;
                if (rule.hasProperty("values")){
                    values = (String[])rule.getProperty("values");
                } else {
                    String previousThreshold = ((String)rule.getProperty("expression")).split(" > ")[1];
                    values = new String[]{previousThreshold};
                }
                ArrayList<String> thresholds = new ArrayList<>(Arrays.asList(values));
                thresholds.add(threshold);

                options.add(Pair.of(feature + " > " + thresholds.get(0), "\"IS_TRUE\""));
                for (int i = 1; i < thresholds.size(); i++) {
                    options.add(Pair.of(feature + " <= " + thresholds.get(i - 1) + " && " + feature + " > " + thresholds.get(i), "\"OPTION_" + options.size() + "\""));
                    if (thresholds.size() - i == 1) {
                        options.add(Pair.of(feature + " <= " + thresholds.get(i - 1) + " && " + feature + " <= " + thresholds.get(i), "\"OPTION_" + options.size() + "\""));
                    }
                }

                rightType = RelationshipType.withName("OPTION_" + (options.size() - 2));
                leftType = RelationshipType.withName("OPTION_" + (options.size() - 1));

                rule.setProperty("values", thresholds.toArray(new String[]{}));
                rule.removeProperty("expression");

                StringBuilder script = new StringBuilder();
                for (Pair<String, String> pair : options) {
                    script.append(" if (").append(pair.first()).append(") { return ").append(pair.other()).append(";} ");
                }
                script.append("return \"NONE\";");
                rule.setProperty("script", script.toString());
                nodes.put(key, rule);
            } else {
                if (nodes.containsKey(key)) {
                    rule = nodes.get(key);
                } else {
                    rule = db.createNode(Labels.Rule);
                    rule.setProperty("expression", feature + " > " + threshold);
                    rule.setProperty("parameter_names", feature);
                    rule.setProperty("parameter_types", "double");
                    nodes.put(key, rule);
                    Node parameter = nodes.get(feature);
                    rule.createRelationshipTo(parameter, RelationshipTypes.REQUIRES);
                }
                parent.createRelationshipTo(rule, relType);
            }

            Symbol left = (Symbol) nestedMap.keySet().toArray()[2];
            HashMap leftMap = new HashMap<>((PersistentArrayMap) ((Atom) nestedMap.get(left)).deref());
            String leftKey = getKey(headers, leftMap);
            if (nodes.keySet().contains(leftKey)) {
                Node leftNode = nodes.get(leftKey);
                rule.createRelationshipTo(leftNode, leftType);
            } else {
                deepLinkMap(db, answerMap, nodes, headers, rule, leftType, leftMap, true);
            }

            Symbol right = (Symbol) nestedMap.keySet().toArray()[3];
            HashMap rightMap = new HashMap<>((PersistentArrayMap) ((Atom) nestedMap.get(right)).deref());
            String rightKey = getKey(headers, rightMap);
            if (nodes.keySet().contains(rightKey)) {
                Node rightNode = nodes.get(rightKey);
                rule.createRelationshipTo(rightNode, rightType);
            } else {
                deepLinkMap(db, answerMap, nodes, headers, rule, rightType, rightMap, false);
            }
        }
    }

    private static String getKey(String[] headers, HashMap map) {
        Double threshold = -1.0;
        int featureId = -1;
        Double leftThreshold = -1.0;
        int leftFeatureId = -1;
        Double rightThreshold = -1.0;
        int rightFeatureId = -1;
        String feature = "leaf";
        String leftFeature = "leaf";
        String rightFeature = "leaf";

        if (map.size() > 2) {
            Symbol thresholdSymbol = (Symbol) map.keySet().toArray()[0];
            Symbol featureIdSymbol = (Symbol) map.keySet().toArray()[1];
            threshold = (Double) map.get(thresholdSymbol);
            featureId = Math.toIntExact((Long) map.get(featureIdSymbol));
            feature = headers[featureId];

            Symbol left = (Symbol) map.keySet().toArray()[2];
            HashMap leftMap = new HashMap<>((PersistentArrayMap) ((Atom) map.get(left)).deref());
            if (leftMap.size() > 2) {
                Symbol leftThresholdSymbol = (Symbol) leftMap.keySet().toArray()[0];
                Symbol leftFeatureIdSymbol = (Symbol) leftMap.keySet().toArray()[1];
                leftThreshold = (Double) leftMap.get(leftThresholdSymbol);
                leftFeatureId = Math.toIntExact((Long) leftMap.get(leftFeatureIdSymbol));
                leftFeature = headers[leftFeatureId];
            }

            Symbol right = (Symbol) map.keySet().toArray()[3];
            HashMap rightMap = new HashMap<>((PersistentArrayMap) ((Atom) map.get(right)).deref());
            if (rightMap.size() > 2) {
                Symbol rightThresholdSymbol = (Symbol) rightMap.keySet().toArray()[0];
                Symbol rightFeatureIdSymbol = (Symbol) rightMap.keySet().toArray()[1];
                rightThreshold = (Double) rightMap.get(rightThresholdSymbol);
                rightFeatureId = Math.toIntExact((Long) rightMap.get(rightFeatureIdSymbol));
                rightFeature = headers[rightFeatureId];
            }
        }

        return feature + "-" + threshold + "-" + leftFeature + "-" + leftThreshold + "-" + rightFeature + "-" + rightThreshold;
    }

    private CSVIterator getCsvIterator(String file) {
        CSVIterator records = null;
        try {
            records = new CSVIterator(new CSVReader(new FileReader(file)));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            log.error("DecisionTreeCreator - File not found: " + file);
        } catch (IOException e) {
            e.printStackTrace();
            log.error("DecisionTreeCreator - IO Exception: " + file);
        }
        return records;
    }
}
The CSV files contain 100.000k rows with 10 columns of data.
What could cause these errors?
08-12-2020 09:25 AM
First, you should always include the full exception and trace when asking for help with errors, and the code with line numbers.
That said, I can help.
JDK 11+ changed toArray(): Collection gained an additional overload, toArray(IntFunction), so a call that was unambiguous on JDK 8 can become ambiguous. To get around this, you'll need to instantiate the array and pass it as an argument to the function.
String[] arr = new String[0];
String[] result = object.toArray(arr);
String somevar = result[0]; // what happens when the array is empty?
// Or, following your code pattern:
String somevar = object.toArray(new String[0])[0];
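As a self-contained illustration of the pattern above, here is a minimal sketch that passes a typed array to toArray() and also handles the empty case raised in the comment. The class ToArrayDemo and the method firstOrNull are hypothetical names for this example only; they are not part of the decision-tree project.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;

// Hypothetical helper for illustration; not part of the decision-tree project.
final class ToArrayDemo {

    // Returns the first element, or null when the collection is empty.
    static String firstOrNull(Collection<String> items) {
        // Passing a typed array explicitly selects the toArray(T[]) overload.
        String[] result = items.toArray(new String[0]);
        // Guard against indexing into an empty array.
        return result.length == 0 ? null : result[0];
    }

    public static void main(String[] args) {
        System.out.println(firstOrNull(Arrays.asList("left", "right")));
        System.out.println(firstOrNull(Collections.<String>emptyList()));
    }
}
```

The length check avoids the ArrayIndexOutOfBoundsException that a bare `[0]` would throw on an empty collection.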
However, that shouldn't be a problem, because you should be using JDK 8, not 11. Neo4j 3.x requires it, while I believe 4.x uses 11.
If you update your project dependencies to use JDK 8, you shouldn't encounter that error... and you'll avoid many other errors that will likely arise using JDK 11 with Neo4j 3.x.
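To confirm which Java version a JVM is actually running, a quick check along these lines may help. This is only a sketch: JavaVersionCheck and majorVersion are hypothetical names, and it relies on the standard java.specification.version system property, which reports "1.8" on JDK 8 and "11" on JDK 11. Note that this checks the JVM you launch it with, which is not necessarily the one your Neo4j server uses.

```java
// Hypothetical helper for illustration; not part of Neo4j or the decision-tree project.
final class JavaVersionCheck {

    // Converts a java.specification.version value to a major version:
    // "1.8" -> 8, "11" -> 11.
    static int majorVersion(String specVersion) {
        String[] parts = specVersion.split("\\.");
        int first = Integer.parseInt(parts[0]);
        // Versions up to Java 8 use the legacy "1.x" scheme.
        return (first == 1 && parts.length > 1) ? Integer.parseInt(parts[1]) : first;
    }

    public static void main(String[] args) {
        int major = majorVersion(System.getProperty("java.specification.version"));
        System.out.println("Running on Java " + major);
        if (major != 8) {
            System.out.println("Warning: Neo4j 3.x expects Java 8");
        }
    }
}
```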