Skip to content

Commit

Permalink
refactor: retriever module
Browse files Browse the repository at this point in the history
  • Loading branch information
Autoparallel committed Nov 30, 2024
1 parent 4f7b1f1 commit 6d95dc6
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 267 deletions.
2 changes: 1 addition & 1 deletion crates/learner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ pub const IACR_CONFIG: &str = include_str!("../config/retrievers/iacr.toml");
pub mod prelude {
pub use crate::{
database::DatabaseInstruction, error::LearnerError, resource::Resource,
retriever::ResponseProcessor,
retriever::response::ResponseProcessor,
};
}

Expand Down
8 changes: 8 additions & 0 deletions crates/learner/src/resource/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ mod shared;
pub use paper::*;
pub use shared::*;

/// Identifies which kind of resource is being described.
///
/// The serde derives carry no customizing attributes, so the variants use
/// serde's default enum representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResourceType {
/// Built-in resource type for papers.
Paper,
/// Built-in resource type for books.
Book,
// We can add more built-in types as needed
/// A user-defined resource type supplied via configuration.
///
/// NOTE(review): the `PathBuf` is presumably the location of the file that
/// defines the custom type; confirm against the code that consumes it.
Custom(PathBuf),
}

/// Core trait that defines the behavior of a resource in the system.
///
/// This trait provides a common interface for all resource types, whether they are
Expand Down
105 changes: 105 additions & 0 deletions crates/learner/src/retriever/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use response::ResponseFormat;

use super::*;

/// Configuration for a specific paper source retriever.
///
/// This struct defines how to interact with a particular paper source's API,
/// including URL patterns, optional request headers, and response parsing rules.
///
/// `pattern` must contain at least one capture group: `extract_identifier`
/// returns the text matched by capture group 1, not the whole match.
///
/// # Examples
///
/// Example TOML configuration:
///
/// ```toml
/// name = "arxiv"
/// resource_type = "Paper"
/// base_url = "http://export.arxiv.org/api/query"
/// pattern = "^(\\d{4}\\.\\d{4,5})$"
/// source = "arxiv"
/// endpoint_template = "http://export.arxiv.org/api/query?id_list={identifier}"
///
/// [response_format]
/// type = "xml"
/// strip_namespaces = true
///
/// [response_format.field_maps]
/// title = { path = "entry/title" }
/// abstract = { path = "entry/summary" }
/// publication_date = { path = "entry/published" }
/// authors = { path = "entry/author/name" }
/// ```
#[derive(Debug, Clone, Deserialize)]
pub struct RetrieverConfig {
/// Name of this retriever configuration
pub name: String,
/// The type of resource this retriever should yield
pub resource_type: ResourceType,
/// Base URL for API requests
pub base_url: String,
/// Regex pattern for matching and extracting paper identifiers.
///
/// Deserialized through `deserialize_regex`. Capture group 1 is what
/// `extract_identifier` returns.
#[serde(deserialize_with = "deserialize_regex")]
pub pattern: Regex,
/// Source identifier for papers from this retriever
pub source: String,
/// Template for constructing API endpoint URLs.
///
/// Every occurrence of the literal `{identifier}` is replaced with the
/// extracted identifier when a request is built.
pub endpoint_template: String,
/// Format and parsing configuration for API responses
pub response_format: ResponseFormat,
/// Optional HTTP headers for API requests.
///
/// Defaults to an empty map when the key is absent from the configuration.
#[serde(default)]
pub headers: HashMap<String, String>,
}

impl RetrieverConfig {
    /// Extracts the canonical identifier from an input string.
    ///
    /// Uses the configured regex pattern to extract the standardized
    /// identifier from various input formats (URLs, DOIs, etc.). The
    /// identifier is the text matched by capture group 1 of the pattern, so a
    /// pattern without a capture group never yields an identifier.
    ///
    /// # Arguments
    ///
    /// * `input` - Input string containing a paper identifier
    ///
    /// # Errors
    ///
    /// Returns `LearnerError::InvalidIdentifier` if the input doesn't match
    /// the pattern or capture group 1 did not participate in the match.
    pub fn extract_identifier<'a>(&self, input: &'a str) -> Result<&'a str> {
        self.pattern
            .captures(input)
            .and_then(|cap| cap.get(1))
            .map(|m| m.as_str())
            .ok_or(LearnerError::InvalidIdentifier)
    }

    /// Fetches a paper from this retriever's API and parses the response.
    ///
    /// The identifier is extracted from `input` with [`Self::extract_identifier`],
    /// substituted for `{identifier}` in the endpoint template, and requested
    /// with any configured headers attached. The response body is handed to the
    /// processor selected by `response_format`. The returned paper's `source`
    /// and `source_identifier` are overwritten with this retriever's source
    /// name and the extracted identifier.
    ///
    /// # Arguments
    ///
    /// * `input` - Input string containing a paper identifier
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `input` does not yield an identifier (`LearnerError::InvalidIdentifier`)
    /// - the request fails or the server answers with a 4xx/5xx status
    /// - the response body cannot be read or processed
    pub async fn retrieve_paper(&self, input: &str) -> Result<Paper> {
        let identifier = self.extract_identifier(input)?;
        let url = self.endpoint_template.replace("{identifier}", identifier);

        debug!("Fetching from {} via: {}", self.name, url);

        let client = reqwest::Client::new();
        let mut request = client.get(&url);

        // Add any configured headers
        for (key, value) in &self.headers {
            request = request.header(key, value);
        }

        // Surface HTTP failures as request errors. Without this check an error
        // page would be handed to the response processor and reported as a
        // confusing parse failure.
        let response = request.send().await?.error_for_status()?;
        let data = response.bytes().await?;

        trace!("{} response: {}", self.name, String::from_utf8_lossy(&data));

        let response_processor = match &self.response_format {
            ResponseFormat::Xml(config) => config as &dyn ResponseProcessor,
            ResponseFormat::Json(config) => config as &dyn ResponseProcessor,
        };
        let mut paper = response_processor.process_response(&data).await?;
        paper.source = self.source.clone();
        paper.source_identifier = identifier.to_string();
        Ok(paper)
    }
}
Loading

0 comments on commit 6d95dc6

Please sign in to comment.