package com.antfin.rayag.myUDF;
import com.antgroup.geaflow.common.type.primitive.IntegerType;import com.antgroup.geaflow.common.type.primitive.StringType;import com.antgroup.geaflow.dsl.common.algo.AlgorithmRuntimeContext;import com.antgroup.geaflow.dsl.common.algo.AlgorithmUserFunction;import com.antgroup.geaflow.dsl.common.data.RowEdge;import com.antgroup.geaflow.dsl.common.data.RowVertex;import com.antgroup.geaflow.dsl.common.data.impl.ObjectRow;import com.antgroup.geaflow.dsl.common.data.impl.types.IntVertex;import com.antgroup.geaflow.dsl.common.function.Description;import com.antgroup.geaflow.dsl.common.types.StructType;import com.antgroup.geaflow.dsl.common.types.TableField;import com.antgroup.geaflow.model.graph.edge.EdgeDirection;
import java.util.ArrayList;import java.util.Iterator;import java.util.List;
@Description(name = "khop", description = "built-in udga for KHop")public class KHop implements AlgorithmUserFunction<Object, Integer> {
private AlgorithmRuntimeContext<Object, Integer> context; private int srcId = 1; private int k = 1;
@Override public void init(AlgorithmRuntimeContext<Object, Integer> context, Object[] parameters) { this.context = context; if (parameters.length > 2) { throw new IllegalArgumentException( "Only support zero or more arguments, false arguments " + "usage: func([alpha, [convergence, [max_iteration]]])"); } if (parameters.length > 0) { srcId = Integer.parseInt(String.valueOf(parameters[0])); } if (parameters.length > 1) { k = Integer.parseInt(String.valueOf(parameters[1])); } }
@Override public void process(RowVertex vertex, Iterator<Integer> messages) { List<RowEdge> outEdges = new ArrayList<>(context.loadEdges(EdgeDirection.OUT)); //第一轮迭代将所有顶点初始化,目标点的K值初始化为0,并向邻点发送消息,其他点的K值初始化为Integer.MAX_VALUE if (context.getCurrentIterationId() == 1L) { if(srcId == (int) vertex.getId()) { sendMessageToNeighbors(outEdges, 1); context.updateVertexValue(ObjectRow.create(0)); context.take(ObjectRow.create(vertex.getId(), 0)); }else{ context.updateVertexValue(ObjectRow.create(Integer.MAX_VALUE)); } } else if (context.getCurrentIterationId() <= k+1) { int currentK = (int) vertex.getValue().getField(0, IntegerType.INSTANCE); //如果当前顶点收到消息,并且K值为Integer.MAX_VALUE(没有被遍历到),则本轮应该修改K值,并向邻边发消息 if(messages.hasNext() && currentK == Integer.MAX_VALUE){ Integer currK = messages.next(); //将当前顶点写出 context.take(ObjectRow.create(vertex.getId(), currK)); //更新当前顶点的K值 context.updateVertexValue(ObjectRow.create(currK)); //向邻点发消息 sendMessageToNeighbors(outEdges, currK+1); } } }
//设置输出类型 @Override public StructType getOutputType() { return new StructType( new TableField("id", IntegerType.INSTANCE, false), new TableField("k", IntegerType.INSTANCE, false) ); }
private void sendMessageToNeighbors(List<RowEdge> outEdges, Integer message) { for (RowEdge rowEdge : outEdges) { context.sendMessage(rowEdge.getTargetId(), message); } }}
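// 评论 ("Comments") — leftover label from the page this code was copied from; kept as a comment so the file compiles.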