Switch to Vec instead of BTreeMap for tensors and their grads #15
keyvank committed Jun 12, 2023
1 parent ad10166 commit 3daece3
Showing 1 changed file with 19 additions and 22 deletions.
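In short: TensorId values are dense indices handed out in allocation order (alloc now returns self.tensors.len() - 1), so a Vec can stand in for the BTreeMap, replacing O(log n) tree lookups with O(1) indexing into contiguous storage. A minimal standalone sketch of the idea, using toy f32 elements rather than the crate's Tensor type:

use std::collections::BTreeMap;

fn main() {
    // Before: ids key into a BTreeMap; O(log n) lookup, one heap node per entry.
    let mut map: BTreeMap<usize, f32> = BTreeMap::new();
    let id = map.len();
    map.insert(id, 1.0);
    assert_eq!(map[&id], 1.0); // lookup goes through a reference to the key

    // After: ids are plain Vec indices; O(1) lookup, contiguous storage.
    let mut vec: Vec<f32> = Vec::new();
    vec.push(1.0);
    let id = vec.len() - 1;
    assert_eq!(vec[id], 1.0); // the usize id indexes directly
}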
src/graph.rs
@@ -25,8 +25,8 @@ unsafe impl Sync for Computation {}
 
 #[derive(Clone)]
 pub struct Graph {
-    tensors: BTreeMap<TensorId, Tensor<f32>>,
-    grads: BTreeMap<TensorId, Tensor<f32>>,
+    tensors: Vec<Tensor<f32>>,
+    grads: Vec<Tensor<f32>>,
     computations: BTreeMap<TensorId, Computation>,
 }

@@ -42,12 +42,12 @@ impl Graph {
         self.alloc(Tensor::<f32>::rand(rng, shape))
     }
     fn alloc(&mut self, t: Tensor<f32>) -> TensorId {
-        let id = self.tensors.len();
-        self.tensors.insert(id, t);
-        id
+        self.grads.push(Tensor::zeros(t.shape()));
+        self.tensors.push(t);
+        self.tensors.len() - 1
     }
     pub fn load<T: TensorOps<f32>>(&mut self, tensor_id: TensorId, tensor: &T) {
-        self.tensors.insert(tensor_id, tensor.view().into());
+        self.tensors[tensor_id] = tensor.view().into();
     }
     pub fn embed<T: TensorOps<usize>>(
         &mut self,
@@ -63,14 +63,16 @@ impl Graph {
         Ok(())
     }
     pub fn load_grad<T: TensorOps<f32>>(&mut self, tensor_id: TensorId, tensor: &T) {
-        self.grads.insert(tensor_id, tensor.view().into());
+        self.grads[tensor_id] = tensor.view().into();
     }
     pub fn zero_grad(&mut self) {
-        self.grads.clear();
+        self.grads.iter_mut().for_each(|t| {
+            t.fill(0.);
+        });
     }
     pub fn add_grad<T: TensorOps<f32>>(&mut self, id: TensorId, add: T) -> Result<(), TensorError> {
         let shape = self.get(id).shape().to_vec();
-        let grad = self.grads.entry(id).or_insert(Tensor::zeros(&shape));
+        let grad = self.grads.get_mut(id).unwrap();
         if add.dim() >= shape.len() {
             for t in add.keep_right(shape.len())?.inners().iter() {
                 *grad = (&*grad + t)?;
@@ -81,10 +81,10 @@
         Ok(())
     }
     pub fn get(&self, id: TensorId) -> &Tensor<f32> {
-        self.tensors.get(&id).expect("Tensor not found!")
+        self.tensors.get(id).expect("Tensor not found!")
     }
     pub fn get_grad(&self, id: TensorId) -> &Tensor<f32> {
-        self.grads.get(&id).expect("Tensor not found!")
+        self.grads.get(id).expect("Tensor not found!")
     }
     pub fn backward_all(
         &mut self,
@@ -97,18 +97,12 @@ impl Graph {
         self.add_grad(id, (&grad * &Tensor::scalar(mean_coeff))?)?;
 
         for (id, comp) in self.computations.clone().iter().rev() {
-            for inp in comp.inps.iter() {
-                let shape = self.get(*inp).shape().to_vec();
-                self.grads
-                    .entry(*inp)
-                    .or_insert(Tensor::<f32>::zeros(&shape));
-            }
             let inps = comp
                 .inps
                 .iter()
-                .map(|id| &self.tensors[id])
+                .map(|id| &self.tensors[*id])
                 .collect::<Vec<_>>();
-            let grad_out = &self.grads[&id];
+            let grad_out = &self.grads[*id];
             let grads = comp.func.grad(&inps, grad_out)?;
             for (id, grad) in comp.inps.clone().into_iter().zip(grads.into_iter()) {
                 self.add_grad(id, grad)?;
@@ -122,10 +118,10 @@ impl Graph {
             let tensors = c
                 .inps
                 .iter()
-                .map(|id| self.tensors.get(id).expect("Tensor not found!"))
+                .map(|id| self.tensors.get(*id).expect("Tensor not found!"))
                 .collect::<Vec<_>>();
             let result = c.func.run(&tensors, training)?;
-            self.tensors.insert(*out, result);
+            self.tensors[*out] = result;
         }
         Ok(())
     }
@@ -136,7 +132,7 @@ impl Graph {
     ) -> Result<TensorId, TensorError> {
         let tensors = tensor_ids
             .iter()
-            .map(|id| self.tensors.get(id).expect("Tensor not found!"))
+            .map(|id| self.tensors.get(*id).expect("Tensor not found!"))
             .collect::<Vec<_>>();
         let out = f.run(&tensors, false)?;
         let child = self.alloc(out);
@@ -158,6 +154,7 @@ impl Graph {
         let (params, grads): (Vec<&mut Tensor<f32>>, Vec<&Tensor<f32>>) = self
            .tensors
            .iter_mut()
+           .enumerate()
            .filter(|(id, _)| params.contains(id))
            .map(|(id, params)| {
                let grad = self.grads.get(id).expect("Tensor not found!");
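One behavioral detail the diff encodes: with the BTreeMap, gradients were created lazily via entry(...).or_insert(...) and zero_grad() simply cleared the map, whereas with the Vec every tensor gets a zero gradient eagerly in alloc and zero_grad() zeroes the existing buffers in place. A toy sketch of that invariant, with illustrative names and scalar "tensors" rather than the crate's API:

// Toy store mirroring the commit's invariant: grads are allocated
// alongside tensors, so one dense id indexes both Vecs.
struct Store {
    tensors: Vec<f32>,
    grads: Vec<f32>,
}

impl Store {
    fn alloc(&mut self, t: f32) -> usize {
        self.grads.push(0.0); // grad buffer allocated up front, same index as the tensor
        self.tensors.push(t);
        self.tensors.len() - 1 // the id is just the dense index
    }
    fn zero_grad(&mut self) {
        // Reset in place; the Vec is never cleared, so every id stays valid.
        self.grads.iter_mut().for_each(|g| *g = 0.0);
    }
}

fn main() {
    let mut s = Store { tensors: Vec::new(), grads: Vec::new() };
    let id = s.alloc(3.0);
    s.grads[id] += 1.0; // accumulate a gradient
    s.zero_grad();
    assert_eq!(s.grads[id], 0.0); // entry still present, just zeroed
}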
