
Commit

Choosing input/total tokens automatically based on available VRAM? (#2673)

* Choosing input/total tokens automatically based on available VRAM?

* Update doc.

* Remove generated files.

* Trying to fix non chunking targets.

* Attempt #2

* fix.

* QuantLinear is rocm compatible.

* Much simpler logic after the overhead.

* Updating logic + non flash.

* Revert doc text.

* Simple updates.

* Fix integration mt0 (transformers update).
Narsil authored Oct 28, 2024
1 parent 2e4f4ba commit 0c9b6cd
Showing 14 changed files with 285 additions and 136 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,6 +5,8 @@ router/tokenizer.json

backends/v2/src/client/pb
backends/v3/src/client/pb
backends/client/src/v2/pb
backends/client/src/v3/pb

# ROCm auto-generated files
*.hip
36 changes: 24 additions & 12 deletions backends/client/src/v3/client.rs
@@ -107,20 +107,22 @@ impl Client {
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_input_tokens: Option<u32>,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_total_tokens: Option<u32>,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
) -> Result<(Option<u32>, u32, u32)> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut truncate = max_prefill_tokens - n_tokens;
if let Some(max_input_tokens) = max_input_tokens {
truncate = min(max_input_tokens, truncate);
}

let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
@@ -136,7 +138,7 @@ impl Client {
// been updated to support chunks.

let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
@@ -145,6 +147,12 @@ impl Client {
));
}

let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
max_total_tokens - truncate
} else {
1
};

requests.push(Request {
id: 0,
inputs,
@@ -175,15 +183,15 @@
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
max_new_tokens,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;
n_tokens += truncate;

// Check max_batch_size
if Some(requests.len()) == max_batch_size {
@@ -195,19 +203,23 @@
id: 0,
size: requests.len() as u32,
requests,
max_tokens: max_input_length,
max_tokens: max_input_tokens.unwrap_or(0),
max_blocks: 0,
};

let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_input_tokens,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
Ok((
response.max_supported_total_tokens,
response.max_input_tokens,
response.max_total_tokens,
))
}

/// Generate one token for each request in the given batch
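The warmup hunks above size each synthetic request from whatever limits are actually known: `truncate` consumes the remaining prefill budget (clamped by `max_input_tokens` when it is set), and `max_new_tokens` falls back to a single decode token when no total budget was given, so the server itself can report what it supports. A minimal, self-contained sketch of that selection logic (the function name and the concrete numbers are illustrative, not part of the patch):

```rust
use std::cmp::min;

/// Illustrative only: mirrors the truncate / max_new_tokens selection
/// in the warmup loop above.
fn plan_warmup_request(
    max_input_tokens: Option<u32>,
    max_total_tokens: Option<u32>,
    max_prefill_tokens: u32,
    n_tokens: u32,
) -> (u32, u32) {
    // Fill the remaining prefill budget, clamped by the input limit if one was given.
    let mut truncate = max_prefill_tokens - n_tokens;
    if let Some(max_input_tokens) = max_input_tokens {
        truncate = min(max_input_tokens, truncate);
    }
    // Without a total-token budget, ask for a single decode token and let the
    // server report what it can actually support after warmup.
    let max_new_tokens = match max_total_tokens {
        Some(max_total_tokens) => max_total_tokens - truncate,
        None => 1,
    };
    (truncate, max_new_tokens)
}

fn main() {
    // No user-provided limits, 4096-token prefill budget, nothing consumed yet.
    assert_eq!(plan_warmup_request(None, None, 4096, 0), (4096, 1));
    // Explicit limits: 1024 input tokens out of a 2048-token total budget.
    assert_eq!(plan_warmup_request(Some(1024), Some(2048), 4096, 0), (1024, 1024));
}
```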
18 changes: 13 additions & 5 deletions backends/client/src/v3/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ impl ShardedClient {
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_input_length: Option<u32>,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_total_tokens: Option<u32>,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
) -> Result<(Option<u32>, u32, u32)> {
let futures: Vec<_> = self
.clients
.iter_mut()
@@ -122,8 +122,16 @@
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
.collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;

// Take the minimum value.
// Different shards hold different parts of the vocabulary and might yield
// different available block sizes.
let min = results
.iter()
.min()
.expect("Expect at least 1 warmup result");
Ok(*min)
}

/// Generate one token for each request in the given batch
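Warmup now returns an `(Option<u32>, u32, u32)` tuple per shard, and the sharded client keeps the minimum. Rust compares tuples lexicographically, so the reduction is driven first by `max_supported_total_tokens`. A small illustrative check with invented values:

```rust
fn main() {
    // Each entry: (max_supported_total_tokens, max_input_tokens, max_total_tokens)
    let results: Vec<(Option<u32>, u32, u32)> = vec![
        (Some(42_000), 4096, 8192),
        (Some(40_000), 4096, 8192),
    ];
    // Tuples compare field by field, so the shard reporting the smallest
    // supported total-token count wins.
    let min = results.iter().min().expect("Expect at least 1 warmup result");
    assert_eq!(*min, (Some(40_000), 4096, 8192));
}
```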
36 changes: 24 additions & 12 deletions backends/v3/src/client/grpc_client.rs
@@ -108,20 +108,22 @@ impl Client {
#[instrument(skip_all)]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_input_tokens: Option<u32>,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_total_tokens: Option<u32>,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
) -> Result<(Option<u32>, u32, u32)> {
let mut n_tokens = 0;
let mut requests = Vec::new();
// Create requests
while n_tokens < max_prefill_tokens {
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
let mut truncate = max_prefill_tokens - n_tokens;
if let Some(max_input_tokens) = max_input_tokens {
truncate = min(max_input_tokens, truncate);
}

let mut input_chunks = Vec::new();
input_chunks
.push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
if n_tokens == 0 {
input_chunks.push(
Chunk::Image(Image {
@@ -137,7 +139,7 @@ impl Client {
// been updated to support chunks.

let mut inputs = String::new();
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
if n_tokens == 0 {
// 1 request is enough to test vision heads.
// Sending images on other queries messes up easily with truncation.
@@ -146,6 +148,12 @@ impl Client {
));
}

let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
max_total_tokens - truncate
} else {
1
};

requests.push(Request {
id: 0,
inputs,
@@ -175,15 +183,15 @@
grammar_type: GrammarType::None as i32,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: max_total_tokens - truncate,
max_new_tokens,
stop_sequences: vec![],
ignore_eos_token: true,
}),
prefill_logprobs: true,
top_n_tokens: 20,
adapter_id: None,
});
n_tokens += max_input_length;
n_tokens += truncate;

// Check max_batch_size
if Some(requests.len()) == max_batch_size {
@@ -195,19 +203,23 @@
id: 0,
size: requests.len() as u32,
requests,
max_tokens: max_input_length,
max_tokens: max_input_tokens.unwrap_or(0),
max_blocks: 0,
};

let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_input_tokens,
max_prefill_tokens,
max_total_tokens,
})
.inject_context();
let response = self.stub.warmup(request).await?.into_inner();
Ok(response.max_supported_total_tokens)
Ok((
response.max_supported_total_tokens,
response.max_input_tokens,
response.max_total_tokens,
))
}

/// Generate one token for each request in the given batch
19 changes: 13 additions & 6 deletions backends/v3/src/client/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ impl ShardedClient {
#[instrument(skip(self))]
pub async fn warmup(
&mut self,
max_input_length: u32,
max_input_length: Option<u32>,
max_prefill_tokens: u32,
max_total_tokens: u32,
max_total_tokens: Option<u32>,
max_batch_size: Option<usize>,
) -> Result<Option<u32>> {
) -> Result<(Option<u32>, u32, u32)> {
let futures: Vec<_> = self
.clients
.iter_mut()
@@ -119,12 +119,19 @@
))
})
.collect();
// Take the minimum value
let results = join_all(futures)
.await
.into_iter()
.collect::<Result<Vec<Option<u32>>>>()?;
Ok(results.into_iter().flatten().min())
.collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;

// Take the minimum value.
// Different shards hold different parts of the vocabulary and might yield
// different available block sizes.
let min = results
.iter()
.min()
.expect("Expect at least 1 warmup result");
Ok(*min)
}

/// Generate one token for each request in the given batch
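The same reduction has a useful edge case: `None` orders before `Some` for `Option`, so if any shard cannot infer a supported batch total (the older, non-flash path), the reduced first field is `None` and `connect_backend` below takes the manual fallback. Another illustrative check with invented values:

```rust
fn main() {
    // One flash-attention shard with an inferred limit, one older shard without.
    let results: Vec<(Option<u32>, u32, u32)> = vec![
        (Some(42_000), 4096, 8192),
        (None, 4096, 8192),
    ];
    // None < Some(_), so a single shard without an inferred limit makes the
    // whole warmup result fall back to the manual max-batch-total-tokens path.
    let min = results.iter().min().expect("Expect at least 1 warmup result");
    assert_eq!(*min, (None, 4096, 8192));
}
```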
69 changes: 49 additions & 20 deletions backends/v3/src/lib.rs
@@ -37,12 +37,17 @@ pub struct BackendInfo {
pub attention_impl: String,
#[schema(example = "1")]
pub block_size: u32,

#[schema(example = "30000")]
pub max_input_tokens: usize,
#[schema(example = "32000")]
pub max_total_tokens: usize,
}

#[allow(clippy::too_many_arguments)]
pub async fn connect_backend(
max_input_tokens: usize,
max_total_tokens: usize,
max_input_tokens: Option<usize>,
max_total_tokens: Option<usize>,
master_shard_uds_path: String,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
@@ -51,14 +56,32 @@ pub async fn connect_backend(
max_batch_size: Option<usize>,
) -> Result<(BackendV3, BackendInfo), V3Error> {
// Helper function
let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
let check_max_batch_total_tokens = |(
max_supported_batch_total_tokens,
shard_max_input_tokens,
shard_max_total_tokens,
): (Option<u32>, u32, u32)|
-> Result<(u32, usize, usize), V3Error> {
if let Some(max_input_tokens) = max_input_tokens {
assert_eq!(max_input_tokens as u32, shard_max_input_tokens);
}
if let Some(max_total_tokens) = max_total_tokens {
assert_eq!(max_total_tokens as u32, shard_max_total_tokens);
}
match max_supported_batch_total_tokens {
// Older models do not support automatic max-batch-total-tokens
None => {
let max_batch_total_tokens = max_batch_total_tokens
.unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
16000
.max(shard_max_total_tokens)
.max(max_batch_prefill_tokens),
);
tracing::warn!("Model does not support automatic max batch total tokens");
Ok(max_batch_total_tokens)
Ok((
max_batch_total_tokens,
shard_max_input_tokens as usize,
shard_max_total_tokens as usize,
))
}
// Flash attention models return their max supported total tokens
Some(max_supported_batch_total_tokens) => {
@@ -72,11 +95,15 @@ pub async fn connect_backend(
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(V3Error::NotEnoughMemory(max_total_tokens));
if shard_max_total_tokens > max_supported_batch_total_tokens {
return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize));
}

Ok(max_supported_batch_total_tokens)
Ok((
max_supported_batch_total_tokens,
shard_max_input_tokens as usize,
shard_max_total_tokens as usize,
))
}
}
};
@@ -96,23 +123,25 @@

// Warmup model
tracing::info!("Warming up model");
let max_batch_total_tokens = check_max_batch_total_tokens(
sharded_client
.warmup(
max_input_tokens as u32,
max_batch_prefill_tokens,
max_total_tokens as u32,
max_batch_size,
)
.await
.map_err(V3Error::Warmup)?,
)?;
let answer = sharded_client
.warmup(
max_input_tokens.map(|p| p as u32),
max_batch_prefill_tokens,
max_total_tokens.map(|p| p as u32),
max_batch_size,
)
.await
.map_err(V3Error::Warmup)?;
let (max_batch_total_tokens, max_input_tokens, max_total_tokens) =
check_max_batch_total_tokens(answer)?;
tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);

let backend_info = BackendInfo {
waiting_served_ratio,
max_batch_total_tokens,
max_input_tokens,
max_total_tokens,
max_waiting_tokens,
max_batch_size,
model_device_type: shard_info.device_type.clone(),
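For shards that cannot infer a limit, `check_max_batch_total_tokens` above resolves the batch budget manually: an explicitly configured `max_batch_total_tokens` wins, otherwise the largest of a 16000-token floor, the shard's total-token limit, and the prefill budget. A worked example with illustrative numbers:

```rust
fn main() {
    // No explicit max_batch_total_tokens; the shard reports 8192 total tokens
    // and the prefill budget is 4096 (all numbers illustrative).
    let max_batch_total_tokens: Option<u32> = None;
    let shard_max_total_tokens: u32 = 8192;
    let max_batch_prefill_tokens: u32 = 4096;

    // Same expression as the non-flash fallback in the hunk above.
    let resolved = max_batch_total_tokens.unwrap_or(
        16000
            .max(shard_max_total_tokens)
            .max(max_batch_prefill_tokens),
    );
    assert_eq!(resolved, 16000);
}
```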