package com.antfin.rayag.myUDF;
import com.antgroup.geaflow.common.type.primitive.IntegerType;
import com.antgroup.geaflow.common.type.primitive.StringType;
import com.antgroup.geaflow.dsl.common.algo.AlgorithmRuntimeContext;
import com.antgroup.geaflow.dsl.common.algo.AlgorithmUserFunction;
import com.antgroup.geaflow.dsl.common.data.RowEdge;
import com.antgroup.geaflow.dsl.common.data.RowVertex;
import com.antgroup.geaflow.dsl.common.data.impl.ObjectRow;
import com.antgroup.geaflow.dsl.common.data.impl.types.IntVertex;
import com.antgroup.geaflow.dsl.common.function.Description;
import com.antgroup.geaflow.dsl.common.types.StructType;
import com.antgroup.geaflow.dsl.common.types.TableField;
import com.antgroup.geaflow.model.graph.edge.EdgeDirection;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
 * K-hop neighborhood expansion implemented as a GeaFlow graph algorithm (UDGA).
 *
 * <p>Starting from a single source vertex, the traversal follows outgoing edges one hop per
 * iteration. Every vertex reached within at most {@code k} hops is emitted once, together with
 * its hop distance. The source itself is emitted with distance 0.
 *
 * <p>Usage: {@code khop([srcId, [k]])}. Both arguments default to 1.
 *
 * <p>Each vertex value is a single integer field: the vertex's hop distance, or
 * {@link Integer#MAX_VALUE} while the vertex has not been reached. Messages carry the hop
 * distance the receiving vertex is reached at.
 */
@Description(name = "khop", description = "built-in udga for KHop")
public class KHop implements AlgorithmUserFunction<Object, Integer> {

    /** Vertex value marking a vertex that has not been reached yet. */
    private static final int UNVISITED = Integer.MAX_VALUE;

    private AlgorithmRuntimeContext<Object, Integer> context;

    /** Id of the vertex the traversal starts from. */
    private int srcId = 1;

    /** Maximum number of hops to expand from the source. */
    private int k = 1;

    /**
     * Reads the optional {@code srcId} and {@code k} arguments.
     *
     * @param context runtime context used for vertex updates, messaging and output
     * @param parameters up to two arguments: {@code [srcId, [k]]}
     * @throws IllegalArgumentException if more than two arguments are given or {@code k < 0}
     * @throws NumberFormatException if an argument is not an integer
     */
    @Override
    public void init(AlgorithmRuntimeContext<Object, Integer> context, Object[] parameters) {
        this.context = context;
        if (parameters.length > 2) {
            throw new IllegalArgumentException(
                "Only support zero, one or two arguments, false arguments "
                    + "usage: func([srcId, [k]])");
        }
        if (parameters.length > 0) {
            srcId = Integer.parseInt(String.valueOf(parameters[0]));
        }
        if (parameters.length > 1) {
            k = Integer.parseInt(String.valueOf(parameters[1]));
        }
        if (k < 0) {
            throw new IllegalArgumentException("k must be non-negative, but was " + k);
        }
    }

    /**
     * Runs one iteration for a single vertex.
     *
     * <p>Iteration 1 initializes every vertex. The source gets distance 0, is emitted, and
     * (if {@code k > 0}) messages its out-neighbors. Iterations 2..k+1 let a not-yet-reached
     * vertex that received a message adopt the distance carried by the message, emit itself,
     * and forward distance + 1 when that is still within {@code k} hops.
     */
    @Override
    public void process(RowVertex vertex, Iterator<Integer> messages) {
        long iterationId = context.getCurrentIterationId();
        if (iterationId == 1L) {
            if (isSource(vertex)) {
                // Only the source starts the expansion; with k == 0 nobody else is reachable.
                if (k > 0) {
                    sendMessageToNeighbors(1);
                }
                context.updateVertexValue(ObjectRow.create(0));
                context.take(ObjectRow.create(vertex.getId(), 0));
            } else {
                context.updateVertexValue(ObjectRow.create(UNVISITED));
            }
        } else if (iterationId <= k + 1L) { // long arithmetic: k + 1 must not overflow
            int currentK = (int) vertex.getValue().getField(0, IntegerType.INSTANCE);
            // First time this vertex receives a message: it is reached in this round.
            if (messages.hasNext() && currentK == UNVISITED) {
                int hops = messages.next();
                context.take(ObjectRow.create(vertex.getId(), hops));
                context.updateVertexValue(ObjectRow.create(hops));
                // Vertices at distance k are the frontier limit; messages beyond it are wasted.
                if (hops < k) {
                    sendMessageToNeighbors(hops + 1);
                }
            }
        }
    }

    /** Declares the output schema: one row per reached vertex, {@code (id, k)}. */
    @Override
    public StructType getOutputType() {
        return new StructType(
            new TableField("id", IntegerType.INSTANCE, false),
            new TableField("k", IntegerType.INSTANCE, false)
        );
    }

    /**
     * Tells whether {@code vertex} is the configured source.
     *
     * <p>Numeric ids of any width are compared by value. Other id types fall back to comparing
     * their string form with {@code srcId}.
     */
    private boolean isSource(RowVertex vertex) {
        Object id = vertex.getId();
        if (id instanceof Number) {
            return ((Number) id).longValue() == srcId;
        }
        return String.valueOf(srcId).equals(String.valueOf(id));
    }

    /**
     * Sends {@code message} to the target of every outgoing edge of the current vertex.
     * Edges are loaded only here, so iterations that send nothing do not pay for loading them.
     */
    private void sendMessageToNeighbors(int message) {
        for (RowEdge edge : context.loadEdges(EdgeDirection.OUT)) {
            context.sendMessage(edge.getTargetId(), message);
        }
    }
}
// 评论 ("comments"): stray label left over from the page this file was copied from; not code.