Skip to content

Commit

Permalink
fix: include universal patterns in bindings (#431)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgante authored Jul 22, 2024
1 parent 3de3c32 commit 3eb6cbf
Show file tree
Hide file tree
Showing 12 changed files with 74 additions and 29 deletions.
3 changes: 2 additions & 1 deletion crates/cli/src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ impl<'b> RichPattern<'b> {
) -> Result<CompilationResult> {
let lang = language.unwrap_or_default();
#[cfg(not(feature = "ai_builtins"))]
let injected_builtins: Option<BuiltIns> = None;
let injected_builtins: Option<BuiltIns> =
marzano_core::built_in_functions::get_ai_placeholder_functions();
#[cfg(feature = "ai_builtins")]
let injected_builtins = Some(ai_builtins::ai_builtins::get_ai_built_in_functions());

Expand Down
4 changes: 2 additions & 2 deletions crates/cli_bin/tests/snapshots/apply__output_jsonl.snap

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ variables:
scopedName: 0_3_$absolute_filename
ranges: []
- name: $body
scopedName: 12_0_$body
scopedName: 13_0_$body
ranges:
- start:
line: 1
Expand All @@ -29,7 +29,7 @@ variables:
startByte: 15
endByte: 20
- name: $match
scopedName: 12_1_$match
scopedName: 13_1_$match
ranges: []
sourceFile: "`function () { $body }`"
parsedPattern: "[..]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ variables:
scopedName: 0_3_$absolute_filename
ranges: []
- name: $match
scopedName: 12_0_$match
scopedName: 13_0_$match
ranges: []
- name: $body
scopedName: 13_0_$body
scopedName: 14_0_$body
ranges:
- start:
line: 1
Expand All @@ -32,7 +32,7 @@ variables:
startByte: 32
endByte: 37
- name: $args
scopedName: 13_1_$args
scopedName: 14_1_$args
ranges:
- start:
line: 1
Expand All @@ -43,7 +43,7 @@ variables:
startByte: 23
endByte: 28
- name: $body
scopedName: 14_0_$body
scopedName: 15_0_$body
ranges:
- start:
line: 1
Expand All @@ -54,7 +54,7 @@ variables:
startByte: 63
endByte: 68
- name: $args
scopedName: 14_1_$args
scopedName: 15_1_$args
ranges:
- start:
line: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ variables:
scopedName: 0_3_$absolute_filename
ranges: []
- name: $match
scopedName: 13_0_$match
scopedName: 14_0_$match
ranges: []
sourceFile: "engine marzano(0.1)\nlanguage js\n\nfunction adder() js {\n console.log(\"We are in JavaScript now!\");\n return 10 % 3\n}\n\n`x` => adder()"
parsedPattern: "[..]"
Expand Down
4 changes: 2 additions & 2 deletions crates/cli_bin/tests/snapshots/parse__parses_grit_file.snap
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ variables:
scopedName: 0_3_$absolute_filename
ranges: []
- name: $msg
scopedName: 12_0_$msg
scopedName: 13_0_$msg
ranges:
- start:
line: 3
Expand All @@ -29,7 +29,7 @@ variables:
startByte: 26
endByte: 30
- name: $match
scopedName: 12_1_$match
scopedName: 13_1_$match
ranges: []
sourceFile: "language js\n\n`console.log($msg)`\n"
parsedPattern: "[..]"
Expand Down
23 changes: 23 additions & 0 deletions crates/core/src/built_in_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,26 @@ fn length_fn<'a>(
None => Err(anyhow!("length argument must be a list or string")),
}
}

pub fn get_ai_placeholder_functions() -> Option<BuiltIns> {
Some(
vec![
BuiltInFunction::new(
"llm_chat",
vec!["model", "messages", "pattern"],
Box::new(ai_fn_placeholder),
),
BuiltInFunction::new("embedding", vec!["target"], Box::new(ai_fn_placeholder)),
]
.into(),
)
}

fn ai_fn_placeholder<'a>(
_args: &'a [Option<Pattern<MarzanoQueryContext>>],
_context: &'a MarzanoContext<'a>,
_state: &mut State<'a, MarzanoQueryContext>,
_logs: &mut AnalysisLogs,
) -> Result<MarzanoResolvedPattern<'a>> {
bail!("AI features are not supported in your GritQL distribution. Please upgrade to the Enterprise version to use AI features.")
}
3 changes: 2 additions & 1 deletion crates/core/src/pattern_compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,8 @@ pub(crate) fn filter_libs(
foreign_functions: foreign_file,
} = defs_to_filenames(libs, parser, tree.root_node())?;
let mut filtered: BTreeMap<String, String> = BTreeMap::new();
// gross but necessary due to running these patterns befor and after each file

// gross but necessary due to running these patterns befor and after each file
let mut stack: Vec<Tree> = if will_autowrap {
let before_each_file = "before_each_file()";
let before_tree =
Expand Down Expand Up @@ -545,6 +545,7 @@ pub(crate) fn filter_libs(
}
}
}

Ok(filtered.into_iter().collect_vec())
}

Expand Down
9 changes: 7 additions & 2 deletions crates/core/src/problem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,16 @@ impl Problem {
}

pub fn definitions(&self) -> StaticDefinitions<'_, MarzanoQueryContext> {
StaticDefinitions::new(
let mut defs = StaticDefinitions::new(
&self.pattern_definitions,
&self.predicate_definitions,
&self.function_definitions,
)
);
// We use the first 3 indexes for auto-wrap stuff in production
if self.pattern_definitions.len() >= 3 {
defs.skippable_indexes = vec![0, 1, 2];
}
defs
}
}

Expand Down
7 changes: 7 additions & 0 deletions crates/grit-pattern-matcher/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ pub struct StaticDefinitions<'a, Q: QueryContext> {
pattern_definitions: &'a [PatternDefinition<Q>],
predicate_definitions: &'a [PredicateDefinition<Q>],
function_definitions: &'a [GritFunctionDefinition<Q>],
/// Pattern indexes we should skip during analysis (before_each_file / after_each_file)
pub skippable_indexes: Vec<usize>,
}

impl<'a, Q: QueryContext> StaticDefinitions<'a, Q> {
Expand All @@ -87,10 +89,14 @@ impl<'a, Q: QueryContext> StaticDefinitions<'a, Q> {
pattern_definitions,
predicate_definitions,
function_definitions,
skippable_indexes: vec![],
}
}

pub fn get_pattern(&self, index: usize) -> Option<&PatternDefinition<Q>> {
if self.skippable_indexes.contains(&index) {
return None;
}
self.pattern_definitions.get(index)
}

Expand All @@ -109,6 +115,7 @@ impl<'a, Q: QueryContext> Default for StaticDefinitions<'a, Q> {
pattern_definitions: &[],
predicate_definitions: &[],
function_definitions: &[],
skippable_indexes: vec![],
}
}
}
32 changes: 20 additions & 12 deletions crates/gritmodule/src/patterns_directory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl PatternsDirectory {
}
}

pub fn get_language_directory(&self, lang: PatternLanguage) -> &BTreeMap<String, String> {
fn get_language_directory(&self, lang: PatternLanguage) -> &BTreeMap<String, String> {
match lang {
PatternLanguage::JavaScript => &self.java_script,
PatternLanguage::TypeScript => &self.type_script,
Expand Down Expand Up @@ -147,6 +147,10 @@ impl PatternsDirectory {
}
}

fn get_universal(&self) -> &BTreeMap<String, String> {
self.get_language_directory(PatternLanguage::Universal)
}

#[tracing::instrument]
fn get_language_and_universal_directory(
&self,
Expand All @@ -157,9 +161,7 @@ impl PatternsDirectory {
};
let lang_library = self.get_language_directory(language);
let mut lang_library = lang_library.to_owned();
let universal = self
.get_language_directory(PatternLanguage::Universal)
.to_owned();
let universal = self.get_universal().to_owned();
let count = lang_library.len() + universal.len();
lang_library.extend(universal);
if count != lang_library.len() {
Expand All @@ -176,12 +178,6 @@ impl PatternsDirectory {
self.get_language_and_universal_directory(language)
}

fn get_language_directory_from_name(&self, name: &str) -> Option<&BTreeMap<String, String>> {
self.pattern_to_language
.get(name)
.map(|l| self.get_language_directory(*l))
}

// imo we should check if name matches [a-z][a-z0-9]*
// as currently a pattern with no language header and an invalid pattern are
// both treated as js patterns when the latter should be a not found error
Expand All @@ -195,9 +191,21 @@ impl PatternsDirectory {
Ok(LanguageLibrary::new(language, library))
}

fn get_language_directory_from_name(&self, name: &str) -> Option<&BTreeMap<String, String>> {
self.pattern_to_language
.get(name)
.map(|l| self.get_language_directory(*l))
}

pub fn get(&self, name: &str) -> Option<&String> {
self.get_language_directory_from_name(name)
.and_then(|d| d.get(name))
if let Some(dir) = self.get_language_directory_from_name(name) {
if let Some(pattern) = dir.get(name) {
return Some(pattern);
}
} else if let Some(pattern) = self.get_universal().get(name) {
return Some(pattern);
}
None
}

// do we want to do an overriding insert?
Expand Down
2 changes: 1 addition & 1 deletion crates/lsp/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ pub fn check_intersection(range1: &Range, range2: &Range) -> bool {

pub(crate) fn get_ai_built_in_functions_for_feature() -> Option<BuiltIns> {
#[cfg(not(feature = "ai_builtins"))]
return None;
return marzano_core::built_in_functions::get_ai_placeholder_functions();
#[cfg(feature = "ai_builtins")]
return Some(ai_builtins::ai_builtins::get_ai_built_in_functions());
}
Expand Down

0 comments on commit 3eb6cbf

Please sign in to comment.