刚才 内联的 spirv_group 函数 给出了 整个工作组 在所有组内的唯一序号,但是我们怎么去区分组内的线程 id 呢,所以我们还需要一个 spirv_local 函数。
let local = self.builder.variable(ty_pt, None, StorageClass::Input, None); self.builder.decorate(wgid, Decoration::BuiltIn, vec![rspirv::dr::Operand::BuiltIn(spirv::BuiltIn::LocalInvocationId)]); self.builtins.push(local);
复制代码
我们设定 builtin 的 1 号元素 为 local_id
然后使用下面的内联函数获取
"spirv_local"=> { let ty = SpirvType::Var(Type::Vec(Rc::new(Type::U32), 3)); let ty_id = self.get_type(ty.clone()); let id = self.builder.load(ty_id, None, self.builtins[1], None, None)?; self.add_interface(self.builtins[1]); self.set(ret, Var{id, ty}); }
复制代码
编译的时候,指定 这个 shader 每个 group 的线程数,一般而言在桌面显卡上 这个数值不超过 1024
在移动显卡上 最好不超过 256 如果比较新的 移动显卡 可以放宽到 512
我们使用下面的代码 编译 一个组内 256 x 256 的 Shader
println!("{:?}", b.import_module(Module::from(compiler), [32, 32, 1]));
复制代码
我们将 Zeta 的代码改成这样 一个工作组内 按照 x, y 的唯一坐标设置缓冲区对应位置。
pub struct Param { id: [u32; 1024], }
pub fn main(param: Param) { //let id = spirv_group(); let id = spirv_local(); let z = id[1] * 32 + id[0]; param.id[z] = z; }
复制代码
编译之前,需要实现乘法和加法
我们要是实现两个符号的运算 s1 Symbol s2 Symbol
两个符号是我们在 spirv 生成器中管理的值,包括了类型信息 和 SSA 分配的 ID 我们以后 实现循环和分支的时候,会详细解释 SSA 的概念
根据常量和变量 不同 我们定义 符号的的值如下
#[derive(Debug, Default, Clone)]pub struct Var { pub id : u32, pub ty : SpirvType,}
#[derive(Debug, Clone)]pub struct Const { pub id : u32, pub ty: SpirvType, pub val : Dynamic,}
复制代码
拿到一个符号之后 我们首先要取出实际的 SSA ID,所以如果是指针类型 我们要 装载到一个 SSA 中
代码如下
pub fn get_const(&self, s: &Symbol)-> Option<Const> { if let Symbol::Const(idx) = s { Some(self.consts[*idx].clone()) } else { None } }
pub fn get_var(&mut self, var: Var)-> (u32, Type, u32) { match var.ty { SpirvType::Pointer(ty, _)=> { let ty_id = self.get_type(SpirvType::Var(ty.clone())); let id = self.builder.load(ty_id, None, var.id, None, None).unwrap(); (ty_id, ty, id) } SpirvType::Var(ty)=> { let ty_id = self.get_type(SpirvType::Var(ty.clone())); (var.id, ty, ty_id) } _ => panic!("Invalid type") } }
复制代码
然后就是另外一个大问题了,不同类型的 SSA ID 如何运算的
我们当然可以简单的约定 只使用平台的 比如说 u32 f32 但是符号怎么办,所以还是必须进行复杂的类型转换
评论