diff --git a/crates/cli_bin/fixtures/notebooks/kind_of_python.ipynb b/crates/cli_bin/fixtures/notebooks/kind_of_python.ipynb new file mode 100644 index 000000000..8b05bcb0f --- /dev/null +++ b/crates/cli_bin/fixtures/notebooks/kind_of_python.ipynb @@ -0,0 +1,31 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "w51yJNBAirPZ" + }, + "outputs": [], + "source": [ + "from langchain import life", + "# This is a python3 notebook in kernelspec, but it should be known as Python" + ] + }, + + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/crates/cli_bin/fixtures/notebooks/langchain_cp.ipynb b/crates/cli_bin/fixtures/notebooks/langchain_cp.ipynb new file mode 100644 index 000000000..6303e7f9a --- /dev/null +++ b/crates/cli_bin/fixtures/notebooks/langchain_cp.ipynb @@ -0,0 +1,212 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": ["# Amazon API Gateway"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ">[Amazon API Gateway](https://aws.amazon.com/api-gateway/) is a fully managed service that makes it easy for developers to create, publish, maintain, monitor, and secure APIs at any >scale. APIs act as the \"front door\" for applications to access data, business logic, or functionality from your backend services. Using `API Gateway`, you can create RESTful APIs and >WebSocket APIs that enable real-time two-way communication applications. API Gateway supports containerized and serverless workloads, as well as web applications.\n", + "\n", + ">`API Gateway` handles all the tasks involved in accepting and processing up to hundreds of thousands of concurrent API calls, including traffic management, CORS support, authorization >and access control, throttling, monitoring, and API version management. `API Gateway` has no minimum fees or startup costs. You pay for the API calls you receive and the amount of data >transferred out and, with the `API Gateway` tiered pricing model, you can reduce your cost as your API usage scales." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## LLM"] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": ["from langchain_community.llms import AmazonAPIGateway"] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "api_url = \"https://.execute-api..amazonaws.com/LATEST/HF\"\n", + "llm = AmazonAPIGateway(api_url=api_url)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": ["'what day comes after Friday?\\nSaturday'"] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# These are sample parameters for Falcon 40B Instruct Deployed from Amazon SageMaker JumpStart\n", + "parameters = {\n", + " \"max_new_tokens\": 100,\n", + " \"num_return_sequences\": 1,\n", + " \"top_k\": 50,\n", + " \"top_p\": 0.95,\n", + " \"do_sample\": False,\n", + " \"return_full_text\": True,\n", + " \"temperature\": 0.2,\n", + "}\n", + "\n", + "prompt = \"what day comes after Friday?\"\n", + "llm.model_kwargs = parameters\n", + "llm(prompt)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Agent"] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "I need to use the print function to output the string \"Hello, world!\"\n", + "Action: Python_REPL\n", + "Action Input: `print(\"Hello, world!\")`\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mHello, world!\n", + "\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m\n", + "I now know how to print a string in Python\n", + "Final Answer:\n", + "Hello, world!\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": ["'Hello, world!'"] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.agents import AgentType, initialize_agent, load_tools\n", + "\n", + "parameters = {\n", + " \"max_new_tokens\": 50,\n", + " \"num_return_sequences\": 1,\n", + " \"top_k\": 250,\n", + " \"top_p\": 0.25,\n", + " \"do_sample\": False,\n", + " \"temperature\": 0.1,\n", + "}\n", + "\n", + "llm.model_kwargs = parameters\n", + "\n", + "# Next, let's load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in.\n", + "tools = load_tools([\"python_repl\", \"llm-math\"], llm=llm)\n", + "\n", + "# Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use.\n", + "agent = initialize_agent(\n", + " tools,\n", + " llm,\n", + " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n", + " verbose=True,\n", + ")\n", + "\n", + "# Now let's test it out!\n", + "agent.run(\n", + " \"\"\"\n", + "Write a Python script that prints \"Hello, world!\"\n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to use the calculator to find the answer\n", + "Action: Calculator\n", + "Action Input: 2.3 ^ 4.5\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mAnswer: 42.43998894277659\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: 42.43998894277659\n", + "\n", + "Question: \n", + "What is the square root of 144?\n", + "\n", + "Thought: I need to use the calculator to find the answer\n", + "Action:\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": ["'42.43998894277659'"] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result = agent.run(\n", + " \"\"\"\n", + "What is 2.3 ^ 4.5?\n", + "\"\"\"\n", + ")\n", + "\n", + "result.split(\"\\n\")[0]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/crates/cli_bin/fixtures/notebooks/langchain_open.ipynb b/crates/cli_bin/fixtures/notebooks/langchain_open.ipynb new file mode 100644 index 000000000..927ecf81b --- /dev/null +++ b/crates/cli_bin/fixtures/notebooks/langchain_open.ipynb @@ -0,0 +1,168 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "245a954a", + "metadata": {}, + "source": [ + "# OpenWeatherMap\n", + "\n", + "This notebook goes over how to use the `OpenWeatherMap` component to fetch weather information.\n", + "\n", + "First, you need to sign up for an `OpenWeatherMap API` key:\n", + "\n", + "1. Go to OpenWeatherMap and sign up for an API key [here](https://openweathermap.org/api/)\n", + "2. pip install pyowm\n", + "\n", + "Then we will need to set some environment variables:\n", + "1. Save your API KEY into OPENWEATHERMAP_API_KEY env variable\n", + "\n", + "## Use the wrapper" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "34bb5968", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from langchain_community.utilities import OpenWeatherMapAPIWrapper\n", + "\n", + "os.environ[\"OPENWEATHERMAP_API_KEY\"] = \"\"\n", + "\n", + "weather = OpenWeatherMapAPIWrapper()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ac4910f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In London,GB, the current weather is as follows:\n", + "Detailed status: broken clouds\n", + "Wind speed: 2.57 m/s, direction: 240°\n", + "Humidity: 55%\n", + "Temperature: \n", + " - Current: 20.12°C\n", + " - High: 21.75°C\n", + " - Low: 18.68°C\n", + " - Feels like: 19.62°C\n", + "Rain: {}\n", + "Heat index: None\n", + "Cloud cover: 75%\n" + ] + } + ], + "source": [ + "weather_data = weather.run(\"London,GB\")\n", + "print(weather_data)" + ] + }, + { + "cell_type": "markdown", + "id": "e73cfa56", + "metadata": {}, + "source": ["## Use the tool"] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b3367417", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from langchain.agents import AgentType, initialize_agent, load_tools\n", + "from langchain_openai import OpenAI\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "os.environ[\"OPENWEATHERMAP_API_KEY\"] = \"\"\n", + "\n", + "llm = OpenAI(temperature=0)\n", + "\n", + "tools = load_tools([\"openweathermap-api\"], llm)\n", + "\n", + "agent_chain = initialize_agent(\n", + " tools=tools, llm=llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bf4f6854", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to find out the current weather in London.\n", + "Action: OpenWeatherMap\n", + "Action Input: London,GB\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mIn London,GB, the current weather is as follows:\n", + "Detailed status: broken clouds\n", + "Wind speed: 2.57 m/s, direction: 240°\n", + "Humidity: 56%\n", + "Temperature: \n", + " - Current: 20.11°C\n", + " - High: 21.75°C\n", + " - Low: 18.68°C\n", + " - Feels like: 19.64°C\n", + "Rain: {}\n", + "Heat index: None\n", + "Cloud cover: 75%\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the current weather in London.\n", + "Final Answer: The current weather in London is broken clouds, with a wind speed of 2.57 m/s, direction 240°, humidity of 56%, temperature of 20.11°C, high of 21.75°C, low of 18.68°C, and a heat index of None.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The current weather in London is broken clouds, with a wind speed of 2.57 m/s, direction 240°, humidity of 56%, temperature of 20.11°C, high of 21.75°C, low of 18.68°C, and a heat index of None.'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": ["agent_chain.run(\"What's the weather like in London?\")"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/crates/cli_bin/fixtures/notebooks/many_ranges.ipynb b/crates/cli_bin/fixtures/notebooks/many_ranges.ipynb new file mode 100644 index 000000000..76344265a --- /dev/null +++ b/crates/cli_bin/fixtures/notebooks/many_ranges.ipynb @@ -0,0 +1,60 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "a0377478", + "metadata": {}, + "outputs": [], + "source": [ + "print \"hi!\"\n", + "from langchain_experimental.tabular_synthetic_data.openai import (\n", + " OPENAI_TEMPLATE,\n", + " create_openai_data_generator,\n", + ")\n", + "from langchain_experimental.tabular_synthetic_data.openai import (\n", + " OTHER_TEMPLATE,\n", + " create_other_data_generator,\n", + " SOMETHING_SPECIAL,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "291bad6e", + "metadata": {}, + "outputs": [], + "source": [ + "class MedicalBilling(BaseModel):\n", + " patient_id: int\n", + " patient_name: str\n", + " diagnosis_code: str\n", + " procedure_code: str\n", + " total_charge: float\n", + " insurance_claim_amount: float" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/crates/cli_bin/fixtures/notebooks/pattern.grit b/crates/cli_bin/fixtures/notebooks/pattern.grit index 3a06a66ac..42584103b 100644 --- a/crates/cli_bin/fixtures/notebooks/pattern.grit +++ b/crates/cli_bin/fixtures/notebooks/pattern.grit @@ -9,4 +9,7 @@ or { }, `print($x)` => `flint($x)`, `math.pi` => `math.tau / 2`, + `from langchain.agents import AgentType, initialize_agent, load_tools` as $anchor where { + $anchor += `from foo insert new_import` + } } diff --git a/crates/core/src/clean.rs b/crates/core/src/clean.rs index 412e47281..94bd1aed7 100644 --- a/crates/core/src/clean.rs +++ b/crates/core/src/clean.rs @@ -2,7 +2,7 @@ use anyhow::Result; use grit_util::{traverse, AstNode, Language, Order, Replacement}; use itertools::Itertools; -fn merge_ranges(ranges: Vec) -> Vec { +pub fn merge_ranges(ranges: Vec) -> Vec { if ranges.is_empty() { return vec![]; } @@ -36,7 +36,6 @@ pub(crate) fn replace_cleaned_ranges( if replacement_ranges.is_empty() { return Ok(None); } - let replacement_ranges = merge_ranges(replacement_ranges); let mut src = src.to_string(); for range in &replacement_ranges { src.replace_range( diff --git a/crates/core/src/inline_snippets.rs b/crates/core/src/inline_snippets.rs index 7b71752ae..2d07db002 100644 --- a/crates/core/src/inline_snippets.rs +++ b/crates/core/src/inline_snippets.rs @@ -186,7 +186,7 @@ pub(crate) fn inline_sorted_snippets_with_offset( } } - let mut code = delete_hanging_comma(&code, replacements, offset)?; + let (mut code, original_ranges) = delete_hanging_comma(&code, replacements, offset)?; // we could optimize by checking if offset is zero, or some other flag // so we only compute if top level. @@ -235,10 +235,6 @@ pub(crate) fn inline_sorted_snippets_with_offset( output_ranges.push(start..end); } } - let replacement_ranges: Vec<(Range, usize)> = replacements - .iter() - .map(|(range, snippet)| (range.effective_range(), snippet.len())) - .collect(); for (range, snippet) in replacements { let range = adjust_range(&range.effective_range(), offset, &code)?; @@ -247,7 +243,8 @@ pub(crate) fn inline_sorted_snippets_with_offset( } code.replace_range(range, snippet); } - Ok((code, output_ranges, replacement_ranges)) + + Ok((code, output_ranges, original_ranges)) } fn adjust_range(range: &Range, offset: usize, code: &str) -> Result> { @@ -304,7 +301,7 @@ fn delete_hanging_comma( code: &str, replacements: &mut [(EffectRange, String)], offset: usize, -) -> Result { +) -> Result<(String, Vec)> { let deletion_ranges = replacements .iter() .filter_map(|r| { @@ -344,24 +341,42 @@ fn delete_hanging_comma( let mut ranges_updates: Vec<(usize, usize)> = ranges.iter().map(|_| (0, 0)).collect(); let mut to_delete = to_delete.iter(); let mut result = String::new(); + let chars = code.chars().enumerate(); + let mut next_comma = to_delete.next(); + let mut replacement_ranges: Vec<(Range, usize)> = replacements + .iter() + .map(|r| (r.0.effective_range(), r.1.len())) + .collect(); + for (index, c) in chars { if Some(&index) != next_comma { result.push(c); } else { + // Keep track of ranges we need to expand into, since we deleted code in the range + // This isn't perfect, but it's good enough for tracking cell boundaries + for (range, ..) in replacement_ranges.iter_mut().rev() { + if range.end >= index { + range.end += 1; + break; + } + } ranges_updates = update_range_shifts(index + offset, &ranges_updates, &ranges); next_comma = to_delete.next(); } } + for (r, u) in replacements.iter_mut().zip(ranges_updates) { r.0.range.start -= u.0; r.0.range.end -= u.1; } - Ok(result) + Ok((result, replacement_ranges)) } +/// After commas are deleted, calculate how much each range has shifted +/// (start shift amount, end shift amount) fn update_range_shifts( index: usize, shifts: &[(usize, usize)], @@ -381,6 +396,7 @@ fn update_range_shifts( if r > index { sr += 1; } + (sl, sr) }) .collect() diff --git a/crates/core/src/marzano_binding.rs b/crates/core/src/marzano_binding.rs index 771a904a9..69f514617 100644 --- a/crates/core/src/marzano_binding.rs +++ b/crates/core/src/marzano_binding.rs @@ -168,6 +168,7 @@ impl EffectRange { } // The range which is actually edited by this effect + // This is used for most operations, but does not account for expansion from deleted commas pub(crate) fn effective_range(&self) -> StdRange { match self.kind { EffectKind::Rewrite => self.range.clone(), diff --git a/crates/core/src/marzano_context.rs b/crates/core/src/marzano_context.rs index 967e26a8a..07813a1e7 100644 --- a/crates/core/src/marzano_context.rs +++ b/crates/core/src/marzano_context.rs @@ -1,6 +1,6 @@ use crate::{ built_in_functions::BuiltIns, - clean::{get_replacement_ranges, replace_cleaned_ranges}, + clean::{get_replacement_ranges, merge_ranges, replace_cleaned_ranges}, foreign_function_definition::ForeignFunctionDefinition, limits::is_file_too_big, marzano_resolved_pattern::{MarzanoFile, MarzanoResolvedPattern}, @@ -149,6 +149,7 @@ impl<'a> ExecContext<'a, MarzanoQueryContext> for MarzanoContext<'a> { owned.content, None, FileOrigin::Fresh, + None, self.language, logs, )?; @@ -248,17 +249,31 @@ impl<'a> ExecContext<'a, MarzanoQueryContext> for MarzanoContext<'a> { logs, )?; + if let (Some(new_ranges), Some(edit_ranges)) = (new_ranges, adjustment_ranges) { + let new_map = if let Some(old_map) = file.tree.source_map.as_ref() { + Some(old_map.clone_with_edits(edit_ranges.iter().rev())?) + } else { + None + }; + let tree = parser - .parse_file( - &new_src, - None, - logs, - FileOrigin::Mutated((&file.tree, &edit_ranges)), - ) + .parse_file(&new_src, None, logs, FileOrigin::Mutated) .unwrap(); let root = tree.root_node(); - let replacement_ranges = get_replacement_ranges(root, self.language()); + let replacement_ranges = + merge_ranges(get_replacement_ranges(root, self.language())); + let new_map = if let Some(new_map) = new_map { + if replacement_ranges.is_empty() { + Some(new_map) + } else { + let replacement_edits: Vec<(std::ops::Range, usize)> = + replacement_ranges.iter().map(|r| r.into()).collect(); + Some(new_map.clone_with_edits(replacement_edits.iter().rev())?) + } + } else { + None + }; let cleaned_src = replace_cleaned_ranges(replacement_ranges, &new_src)?; let new_src = if let Some(src) = cleaned_src { src @@ -272,7 +287,8 @@ impl<'a> ExecContext<'a, MarzanoQueryContext> for MarzanoContext<'a> { new_filename.clone(), new_src, Some(ranges), - FileOrigin::Mutated((&file.tree, &edit_ranges)), + FileOrigin::Mutated, + new_map, self.language(), logs, )? @@ -320,6 +336,7 @@ impl<'a> ExecContext<'a, MarzanoQueryContext> for MarzanoContext<'a> { body, None, FileOrigin::New, + None, self.language(), logs, )? diff --git a/crates/core/src/pattern_compiler/file_owner_compiler.rs b/crates/core/src/pattern_compiler/file_owner_compiler.rs index 3440a3937..d6bde14f6 100644 --- a/crates/core/src/pattern_compiler/file_owner_compiler.rs +++ b/crates/core/src/pattern_compiler/file_owner_compiler.rs @@ -2,7 +2,10 @@ use crate::paths::absolutize; use anyhow::Result; use grit_pattern_matcher::file_owners::FileOwner; use grit_util::{AnalysisLogs, FileOrigin, MatchRanges}; -use marzano_language::language::{MarzanoLanguage, Tree}; +use marzano_language::{ + language::{MarzanoLanguage, Tree}, + sourcemap::EmbeddedSourceMap, +}; use std::path::PathBuf; pub(crate) struct FileOwnerCompiler; @@ -13,23 +16,13 @@ impl FileOwnerCompiler { source: String, matches: Option, old_tree: FileOrigin<'_, Tree>, + new_map: Option, language: &impl MarzanoLanguage<'a>, logs: &mut AnalysisLogs, ) -> Result>> { let name = name.into(); let new = !old_tree.is_fresh(); - // If we have an old tree, attach it here - let new_map = if let FileOrigin::Mutated((old_tree, mutations)) = old_tree { - if let Some(old_map) = &old_tree.source_map { - Some(old_map.clone_with_edits(mutations.iter().rev())?) - } else { - None - } - } else { - None - }; - let Some(mut tree) = language .get_parser() .parse_file(&source, Some(&name), logs, old_tree) diff --git a/crates/core/src/snapshots/marzano_core__test_notebooks__base_case.snap b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__base_case.snap similarity index 100% rename from crates/core/src/snapshots/marzano_core__test_notebooks__base_case.snap rename to crates/core/src/snapshots/marzano_core__test_notebooks__tests__base_case.snap diff --git a/crates/core/src/snapshots/marzano_core__test_notebooks__changing_size.snap b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__changing_size.snap similarity index 100% rename from crates/core/src/snapshots/marzano_core__test_notebooks__changing_size.snap rename to crates/core/src/snapshots/marzano_core__test_notebooks__tests__changing_size.snap diff --git a/crates/core/src/snapshots/marzano_core__test_notebooks__tests__insertion.snap b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__insertion.snap new file mode 100644 index 000000000..8a41cb4c0 --- /dev/null +++ b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__insertion.snap @@ -0,0 +1,159 @@ +--- +source: crates/core/src/test_notebooks.rs +expression: rewrite.rewritten.content +--- +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": ["# Amazon API Gateway"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ">[Amazon API Gateway](https://aws.amazon.com/api-gateway/) is a fully managed service that makes it easy for developers to create, publish, maintain, monitor, and secure APIs at any >scale. APIs act as the \"front door\" for applications to access data, business logic, or functionality from your backend services. Using `API Gateway`, you can create RESTful APIs and >WebSocket APIs that enable real-time two-way communication applications. API Gateway supports containerized and serverless workloads, as well as web applications.\n", + "\n", + ">`API Gateway` handles all the tasks involved in accepting and processing up to hundreds of thousands of concurrent API calls, including traffic management, CORS support, authorization >and access control, throttling, monitoring, and API version management. `API Gateway` has no minimum fees or startup costs. You pay for the API calls you receive and the amount of data >transferred out and, with the `API Gateway` tiered pricing model, you can reduce your cost as your API usage scales." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## LLM"] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": ["from langchain_community.llms import AmazonAPIGateway"] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": ["api_url = \"https://.execute-api..amazonaws.com/LATEST/HF\"\nllm = AmazonAPIGateway(api_url=api_url)"] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": ["'what day comes after Friday?\\nSaturday'"] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": ["# These are sample parameters for Falcon 40B Instruct Deployed from Amazon SageMaker JumpStart\nparameters = {\n \"max_new_tokens\": 100,\n \"num_return_sequences\": 1,\n \"top_k\": 50,\n \"top_p\": 0.95,\n \"do_sample\": False,\n \"return_full_text\": True,\n \"temperature\": 0.2,\n}\n\nprompt = \"what day comes after Friday?\"\nllm.model_kwargs = parameters\nllm(prompt)"] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": ["## Agent"] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "I need to use the print function to output the string \"Hello, world!\"\n", + "Action: Python_REPL\n", + "Action Input: `print(\"Hello, world!\")`\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mHello, world!\n", + "\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m\n", + "I now know how to print a string in Python\n", + "Final Answer:\n", + "Hello, world!\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": ["'Hello, world!'"] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": ["from langchain.agents import AgentType, initialize_agent, load_tools\n\nfrom foo insert new_import\n\nparameters = {\n \"max_new_tokens\": 50,\n \"num_return_sequences\": 1,\n \"top_k\": 250,\n \"top_p\": 0.25,\n \"do_sample\": False,\n \"temperature\": 0.1,\n}\n\nllm.model_kwargs = parameters\n\n# Next, let's load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in.\ntools = load_tools([\"python_repl\", \"llm-math\"], llm=llm)\n\n# Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use.\nagent = initialize_agent(\n tools,\n llm,\n agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n verbose=True,\n)\n\n# Now let's test it out!\nagent.run(\n \"\"\"\nWrite a Python script that prints \"Hello, world!\"\n\"\"\"\n)"] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to use the calculator to find the answer\n", + "Action: Calculator\n", + "Action Input: 2.3 ^ 4.5\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mAnswer: 42.43998894277659\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: 42.43998894277659\n", + "\n", + "Question: \n", + "What is the square root of 144?\n", + "\n", + "Thought: I need to use the calculator to find the answer\n", + "Action:\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": ["'42.43998894277659'"] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": ["result = agent.run(\n \"\"\"\nWhat is 2.3 ^ 4.5?\n\"\"\"\n)\n\nresult.split(\"\\n\")[0]"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/crates/core/src/snapshots/marzano_core__test_notebooks__multi_cell_small.snap b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__multi_cell_small.snap similarity index 100% rename from crates/core/src/snapshots/marzano_core__test_notebooks__multi_cell_small.snap rename to crates/core/src/snapshots/marzano_core__test_notebooks__tests__multi_cell_small.snap diff --git a/crates/core/src/snapshots/marzano_core__test_notebooks__tests__multiple_add_remove_imports_with_commas.snap b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__multiple_add_remove_imports_with_commas.snap new file mode 100644 index 000000000..b696d5dea --- /dev/null +++ b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__multiple_add_remove_imports_with_commas.snap @@ -0,0 +1,45 @@ +--- +source: crates/core/src/test_notebooks.rs +expression: rewrite.rewritten.content +--- +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "a0377478", + "metadata": {}, + "outputs": [], + "source": ["print \"hi!\"\n\n\nfrom somewhere import something\n\n\nfrom somewhere import something\n"] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "291bad6e", + "metadata": {}, + "outputs": [], + "source": ["class MedicalBilling(BaseModel):\n patient_id: int\n patient_name: str\n diagnosis_code: str\n procedure_code: str\n total_charge: float\n insurance_claim_amount: float"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/crates/core/src/snapshots/marzano_core__test_notebooks__tests__python3_kernelspec.snap b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__python3_kernelspec.snap new file mode 100644 index 000000000..62b55e6d6 --- /dev/null +++ b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__python3_kernelspec.snap @@ -0,0 +1,32 @@ +--- +source: crates/core/src/test_notebooks.rs +expression: rewrite.rewritten.content +--- +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "w51yJNBAirPZ" + }, + "outputs": [], + "source": ["from fangchain import life# This is a python3 notebook in kernelspec, but it should be known as Python"] + }, + + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/crates/core/src/snapshots/marzano_core__test_notebooks__sequential.snap b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__sequential.snap similarity index 100% rename from crates/core/src/snapshots/marzano_core__test_notebooks__sequential.snap rename to crates/core/src/snapshots/marzano_core__test_notebooks__tests__sequential.snap diff --git a/crates/core/src/snapshots/marzano_core__test_notebooks__tests__weird_side_effects_orphans.snap b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__weird_side_effects_orphans.snap new file mode 100644 index 000000000..96fe6f120 --- /dev/null +++ b/crates/core/src/snapshots/marzano_core__test_notebooks__tests__weird_side_effects_orphans.snap @@ -0,0 +1,145 @@ +--- +source: crates/core/src/test_notebooks.rs +expression: rewrite.rewritten.content +--- +{ + "cells": [ + { + "cell_type": "markdown", + "id": "245a954a", + "metadata": {}, + "source": [ + "# OpenWeatherMap\n", + "\n", + "This notebook goes over how to use the `OpenWeatherMap` component to fetch weather information.\n", + "\n", + "First, you need to sign up for an `OpenWeatherMap API` key:\n", + "\n", + "1. Go to OpenWeatherMap and sign up for an API key [here](https://openweathermap.org/api/)\n", + "2. pip install pyowm\n", + "\n", + "Then we will need to set some environment variables:\n", + "1. Save your API KEY into OPENWEATHERMAP_API_KEY env variable\n", + "\n", + "## Use the wrapper" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "34bb5968", + "metadata": {}, + "outputs": [], + "source": ["import os\n\nfrom langchain_community.utilities import OpenWeatherMapAPIWrapper\n\nos.environ[\"OPENWEATHERMAP_API_KEY\"] = \"\"\n\nweather = OpenWeatherMapAPIWrapper()"] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ac4910f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In London,GB, the current weather is as follows:\n", + "Detailed status: broken clouds\n", + "Wind speed: 2.57 m/s, direction: 240°\n", + "Humidity: 55%\n", + "Temperature: \n", + " - Current: 20.12°C\n", + " - High: 21.75°C\n", + " - Low: 18.68°C\n", + " - Feels like: 19.62°C\n", + "Rain: {}\n", + "Heat index: None\n", + "Cloud cover: 75%\n" + ] + } + ], + "source": ["weather_data = weather.run(\"London,GB\")\nprint(weather_data)"] + }, + { + "cell_type": "markdown", + "id": "e73cfa56", + "metadata": {}, + "source": ["## Use the tool"] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b3367417", + "metadata": {}, + "outputs": [], + "source": ["import os\n\nfrom langchain.agents import AgentType, initialize_agent \n\nfrom my_thing import tools\nfrom langchain_openai import OpenAI\n\nos.environ[\"OPENAI_API_KEY\"] = \"\"\nos.environ[\"OPENWEATHERMAP_API_KEY\"] = \"\"\n\nllm = OpenAI(temperature=0)\n\ntools = load_tools([\"openweathermap-api\"], llm)\n\nagent_chain = initialize_agent(\n tools=tools, llm=llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n)"] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bf4f6854", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to find out the current weather in London.\n", + "Action: OpenWeatherMap\n", + "Action Input: London,GB\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mIn London,GB, the current weather is as follows:\n", + "Detailed status: broken clouds\n", + "Wind speed: 2.57 m/s, direction: 240°\n", + "Humidity: 56%\n", + "Temperature: \n", + " - Current: 20.11°C\n", + " - High: 21.75°C\n", + " - Low: 18.68°C\n", + " - Feels like: 19.64°C\n", + "Rain: {}\n", + "Heat index: None\n", + "Cloud cover: 75%\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the current weather in London.\n", + "Final Answer: The current weather in London is broken clouds, with a wind speed of 2.57 m/s, direction 240°, humidity of 56%, temperature of 20.11°C, high of 21.75°C, low of 18.68°C, and a heat index of None.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'The current weather in London is broken clouds, with a wind speed of 2.57 m/s, direction 240°, humidity of 56%, temperature of 20.11°C, high of 21.75°C, low of 18.68°C, and a heat index of None.'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": ["agent_chain.run(\"What's the weather like in London?\")"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/crates/core/src/test.rs b/crates/core/src/test.rs index 78616e25b..64d2502f1 100644 --- a/crates/core/src/test.rs +++ b/crates/core/src/test.rs @@ -12078,6 +12078,10 @@ fn python_orphaned_from_imports() { |find_replace_imports([ | [`somewhere`, `foo`, `other`, `food`], | [`somewhere`, `bar`, `other`, `ice`], + | [`online`, `dragon`, `myth`, `dragon`], + | [`online`, `dungeon`, `game`, `dungeon`], + | [`langchain.chains.graph_qa.cypher_utils`, `CypherQueryCorrector`, `lcn`, `CypherQueryCorrector`], + | [`langchain.chains.graph_qa.cypher_utils`, `Schema`, `lcn`, `Schema`], |]) |"# .trim_margin() @@ -12088,6 +12092,26 @@ fn python_orphaned_from_imports() { | bar |) |from nice import ice + |from online import dragon, dungeon + | + |# problematic + |cypher = cypher_response.invoke({"question": "Who played in Casino movie?"}) + |cypher + |from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema + |# Cypher validation tool for relationship directions + |corrector_schema = [ + | Schema(el["start"], el["type"], el["end"]) + | for el in graph.structured_schema.get("relationships") + |] + | + |# leave this alone + |from langchain.chains.query_constructor.ir import ( + | Comparator, + | Comparison, + | Operation, + | Operator, + | StructuredQuery, + |) |"# .trim_margin() .unwrap(), @@ -12097,6 +12121,101 @@ fn python_orphaned_from_imports() { | |from other import ice |from nice import ice + | + | + |from myth import dragon + | + |from game import dungeon + | + |# problematic + |cypher = cypher_response.invoke({"question": "Who played in Casino movie?"}) + |cypher + | + | + |from lcn import CypherQueryCorrector + | + |from lcn import Schema + |# Cypher validation tool for relationship directions + |corrector_schema = [ + | Schema(el["start"], el["type"], el["end"]) + | for el in graph.structured_schema.get("relationships") + |] + | + |# leave this alone + |from langchain.chains.query_constructor.ir import ( + | Comparator, + | Comparison, + | Operation, + | Operator, + | StructuredQuery, + |) + |"# + .trim_margin() + .unwrap(), + } + }) + .unwrap(); +} + +#[test] +fn python_simple_orphan_from() { + run_test_expected({ + TestArgExpected { + pattern: r#" + |language python + | + |`something` => . + |"# + .trim_margin() + .unwrap(), + source: r#" + |from somewhere import something + |print("hello") + |"# + .trim_margin() + .unwrap(), + // Don't worry about formatting, just check that the trailing comma is removed + expected: r#" + |print("hello") + |"# + .trim_margin() + .unwrap(), + } + }) + .unwrap(); +} + +#[test] +fn python_multiline_removal_from() { + run_test_expected({ + TestArgExpected { + pattern: r#" + |language python + | + |or {`CypherQueryCorrector` => ., `Schema` => .} + |"# + .trim_margin() + .unwrap(), + source: r#" + |# Somehow this causes problems + |cypher_template = """Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question: + |{wattenberg} + |Cypher query:""" + | + |from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema + |print("hello") + |"# + .trim_margin() + .unwrap(), + // Don't worry about formatting, just check that the trailing comma is removed + expected: r#" + |# Somehow this causes problems + |cypher_template = """Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question: + |{wattenberg} + |Cypher query:""" + | + | + |print("hello") |"# .trim_margin() .unwrap(), diff --git a/crates/core/src/test_notebooks.rs b/crates/core/src/test_notebooks.rs index 0f60f5154..97736aa20 100644 --- a/crates/core/src/test_notebooks.rs +++ b/crates/core/src/test_notebooks.rs @@ -1,196 +1,197 @@ -use insta::assert_snapshot; -use marzano_language::target_language::TargetLanguage; - -use crate::{ - api::MatchResult, - test_utils::{run_on_test_files, SyntheticFile}, -}; - -use self::pattern_compiler::src_to_problem_libs; - -use super::*; -use std::collections::BTreeMap; - -#[test] -fn test_base_case() { - let pattern_src = r#" +#[cfg(test)] +mod tests { + use insta::assert_snapshot; + use marzano_language::target_language::TargetLanguage; + + use crate::{ + api::MatchResult, + pattern_compiler::src_to_problem_libs, + test_utils::{run_on_test_files, SyntheticFile}, + }; + + use std::collections::BTreeMap; + + #[test] + fn test_base_case() { + let pattern_src = r#" language python `print($x)` => `flink($x)` "#; - let libs = BTreeMap::new(); - - let matching_src = include_str!("../../../crates/cli_bin/fixtures/notebooks/tiny_nb.ipynb"); - - let pattern = src_to_problem_libs( - pattern_src.to_string(), - &libs, - TargetLanguage::from_extension("ipynb").unwrap(), - None, - None, - None, - None, - ) - .unwrap() - .problem; - - // Basic match works - let test_files = vec![SyntheticFile::new( - "target.ipynb".to_owned(), - matching_src.to_owned(), - true, - )]; - let results = run_on_test_files(&pattern, &test_files); - println!("{:?}", results); - assert!(!results.iter().any(|r| r.is_error())); - - let rewrite = results - .iter() - .find(|r| matches!(r, MatchResult::Rewrite(_))) - .unwrap(); - - if let MatchResult::Rewrite(rewrite) = rewrite { - assert_snapshot!(rewrite.rewritten.content); - } else { - panic!("Expected a rewrite"); + let libs = BTreeMap::new(); + + let matching_src = include_str!("../../../crates/cli_bin/fixtures/notebooks/tiny_nb.ipynb"); + + let pattern = src_to_problem_libs( + pattern_src.to_string(), + &libs, + TargetLanguage::from_extension("ipynb").unwrap(), + None, + None, + None, + None, + ) + .unwrap() + .problem; + + // Basic match works + let test_files = vec![SyntheticFile::new( + "target.ipynb".to_owned(), + matching_src.to_owned(), + true, + )]; + let results = run_on_test_files(&pattern, &test_files); + println!("{:?}", results); + assert!(!results.iter().any(|r| r.is_error())); + + let rewrite = results + .iter() + .find(|r| matches!(r, MatchResult::Rewrite(_))) + .unwrap(); + + if let MatchResult::Rewrite(rewrite) = rewrite { + assert_snapshot!(rewrite.rewritten.content); + } else { + panic!("Expected a rewrite"); + } } -} -#[test] -fn test_old_notebooks() { - let pattern_src = r#" + #[test] + fn test_old_notebooks() { + let pattern_src = r#" language python `print($x)` => `flink($x)` "#; - let libs = BTreeMap::new(); - - let matching_src = include_str!("../../../crates/cli_bin/fixtures/notebooks/old_nb.ipynb"); - - let pattern = src_to_problem_libs( - pattern_src.to_string(), - &libs, - TargetLanguage::from_extension("ipynb").unwrap(), - None, - None, - None, - None, - ) - .unwrap() - .problem; - - // Basic match works - let test_files = vec![SyntheticFile::new( - "target.ipynb".to_owned(), - matching_src.to_owned(), - true, - )]; - let results = run_on_test_files(&pattern, &test_files); - // We *do* expect an error on old notebooks - assert!(results.iter().any(|r| r.is_error())); -} + let libs = BTreeMap::new(); + + let matching_src = include_str!("../../../crates/cli_bin/fixtures/notebooks/old_nb.ipynb"); + + let pattern = src_to_problem_libs( + pattern_src.to_string(), + &libs, + TargetLanguage::from_extension("ipynb").unwrap(), + None, + None, + None, + None, + ) + .unwrap() + .problem; + + // Basic match works + let test_files = vec![SyntheticFile::new( + "target.ipynb".to_owned(), + matching_src.to_owned(), + true, + )]; + let results = run_on_test_files(&pattern, &test_files); + // We *do* expect an error on old notebooks + assert!(results.iter().any(|r| r.is_error())); + } -#[test] -fn test_changing_size() { - // The rewrite has a different length, so the source map needs to be used + #[test] + fn test_changing_size() { + // The rewrite has a different length, so the source map needs to be used - let pattern_src = r#" + let pattern_src = r#" language python `print($x)` => `THIS_IS_MUCH_MUCH_MUCH_MUCH_MUCH_MUCH_LONGER($x)` "#; - let libs = BTreeMap::new(); - - let matching_src = include_str!("../../../crates/cli_bin/fixtures/notebooks/tiny_nb.ipynb"); - - let pattern = src_to_problem_libs( - pattern_src.to_string(), - &libs, - TargetLanguage::from_extension("ipynb").unwrap(), - None, - None, - None, - None, - ) - .unwrap() - .problem; - - // Basic match works - let test_files = vec![SyntheticFile::new( - "target.ipynb".to_owned(), - matching_src.to_owned(), - true, - )]; - let results = run_on_test_files(&pattern, &test_files); - - println!("{:?}", results); - assert!(!results.iter().any(|r| r.is_error())); - - let rewrite = results - .iter() - .find(|r| matches!(r, MatchResult::Rewrite(_))) - .unwrap(); - - if let MatchResult::Rewrite(rewrite) = rewrite { - assert_snapshot!(rewrite.rewritten.content); - } else { - panic!("Expected a rewrite"); + let libs = BTreeMap::new(); + + let matching_src = include_str!("../../../crates/cli_bin/fixtures/notebooks/tiny_nb.ipynb"); + + let pattern = src_to_problem_libs( + pattern_src.to_string(), + &libs, + TargetLanguage::from_extension("ipynb").unwrap(), + None, + None, + None, + None, + ) + .unwrap() + .problem; + + // Basic match works + let test_files = vec![SyntheticFile::new( + "target.ipynb".to_owned(), + matching_src.to_owned(), + true, + )]; + let results = run_on_test_files(&pattern, &test_files); + + println!("{:?}", results); + assert!(!results.iter().any(|r| r.is_error())); + + let rewrite = results + .iter() + .find(|r| matches!(r, MatchResult::Rewrite(_))) + .unwrap(); + + if let MatchResult::Rewrite(rewrite) = rewrite { + assert_snapshot!(rewrite.rewritten.content); + } else { + panic!("Expected a rewrite"); + } } -} -#[test] -fn test_multi_cell_small() { - // The rewrite has a different length, so the source map needs to be used + #[test] + fn test_multi_cell_small() { + // The rewrite has a different length, so the source map needs to be used - let pattern_src = r#" + let pattern_src = r#" language python `print($x)` => `p($x)` "#; - let libs = BTreeMap::new(); - - let matching_src = include_str!("../../../crates/cli_bin/fixtures/notebooks/multi_cell.ipynb"); - - let pattern = src_to_problem_libs( - pattern_src.to_string(), - &libs, - TargetLanguage::from_extension("ipynb").unwrap(), - None, - None, - None, - None, - ) - .unwrap() - .problem; - - // Basic match works - let test_files = vec![SyntheticFile::new( - "target.ipynb".to_owned(), - matching_src.to_owned(), - true, - )]; - let results = run_on_test_files(&pattern, &test_files); - - println!("{:?}", results); - assert!(!results.iter().any(|r| r.is_error())); - - let rewrite = results - .iter() - .find(|r| matches!(r, MatchResult::Rewrite(_))) - .unwrap(); - - if let MatchResult::Rewrite(rewrite) = rewrite { - assert_snapshot!(rewrite.rewritten.content); - } else { - panic!("Expected a rewrite"); + let libs = BTreeMap::new(); + + let matching_src = + include_str!("../../../crates/cli_bin/fixtures/notebooks/multi_cell.ipynb"); + + let pattern = src_to_problem_libs( + pattern_src.to_string(), + &libs, + TargetLanguage::from_extension("ipynb").unwrap(), + None, + None, + None, + None, + ) + .unwrap() + .problem; + + // Basic match works + let test_files = vec![SyntheticFile::new( + "target.ipynb".to_owned(), + matching_src.to_owned(), + true, + )]; + let results = run_on_test_files(&pattern, &test_files); + + println!("{:?}", results); + assert!(!results.iter().any(|r| r.is_error())); + + let rewrite = results + .iter() + .find(|r| matches!(r, MatchResult::Rewrite(_))) + .unwrap(); + + if let MatchResult::Rewrite(rewrite) = rewrite { + assert_snapshot!(rewrite.rewritten.content); + } else { + panic!("Expected a rewrite"); + } } -} -#[test] -fn test_sequential() { - // Make sure we handle sequential transforms too + #[test] + fn test_sequential() { + // Make sure we handle sequential transforms too - let pattern_src = r#" + let pattern_src = r#" language python sequential { @@ -200,41 +201,247 @@ fn test_sequential() { bubble file($body) where $body <: contains bubble `flint($a)` => `x($a, 10)`, } "#; - let libs = BTreeMap::new(); - - let matching_src = include_str!("../../../crates/cli_bin/fixtures/notebooks/multi_cell.ipynb"); - - let pattern = src_to_problem_libs( - pattern_src.to_string(), - &libs, - TargetLanguage::from_extension("ipynb").unwrap(), - None, - None, - None, - None, - ) - .unwrap() - .problem; - - // Basic match works - let test_files = vec![SyntheticFile::new( - "target.ipynb".to_owned(), - matching_src.to_owned(), - true, - )]; - let results = run_on_test_files(&pattern, &test_files); - - println!("{:?}", results); - assert!(!results.iter().any(|r| r.is_error())); - - let rewrite = results - .iter() - .find(|r| matches!(r, MatchResult::Rewrite(_))) - .unwrap(); - - if let MatchResult::Rewrite(rewrite) = rewrite { - assert_snapshot!(rewrite.rewritten.content); - } else { - panic!("Expected a rewrite"); + let libs = BTreeMap::new(); + + let matching_src = + include_str!("../../../crates/cli_bin/fixtures/notebooks/multi_cell.ipynb"); + + let pattern = src_to_problem_libs( + pattern_src.to_string(), + &libs, + TargetLanguage::from_extension("ipynb").unwrap(), + None, + None, + None, + None, + ) + .unwrap() + .problem; + + // Basic match works + let test_files = vec![SyntheticFile::new( + "target.ipynb".to_owned(), + matching_src.to_owned(), + true, + )]; + let results = run_on_test_files(&pattern, &test_files); + + println!("{:?}", results); + assert!(!results.iter().any(|r| r.is_error())); + + let rewrite = results + .iter() + .find(|r| matches!(r, MatchResult::Rewrite(_))) + .unwrap(); + + if let MatchResult::Rewrite(rewrite) = rewrite { + assert_snapshot!(rewrite.rewritten.content); + } else { + panic!("Expected a rewrite"); + } + } + + #[test] + fn test_insertion() { + let pattern_src = r#" + language python + + `from langchain.agents import AgentType, initialize_agent, load_tools` as $anchor where { + $anchor += `\nfrom foo insert new_import` + } + "#; + let libs = BTreeMap::new(); + + let matching_src = + include_str!("../../../crates/cli_bin/fixtures/notebooks/langchain_cp.ipynb"); + + let pattern = src_to_problem_libs( + pattern_src.to_string(), + &libs, + TargetLanguage::from_extension("ipynb").unwrap(), + None, + None, + None, + None, + ) + .unwrap() + .problem; + + // Basic match works + let test_files = vec![SyntheticFile::new( + "target.ipynb".to_owned(), + matching_src.to_owned(), + true, + )]; + let results = run_on_test_files(&pattern, &test_files); + assert!(!results.iter().any(|r| r.is_error())); + + let rewrite = results + .iter() + .find(|r| matches!(r, MatchResult::Rewrite(_))) + .unwrap(); + + if let MatchResult::Rewrite(rewrite) = rewrite { + assert_snapshot!(rewrite.rewritten.content); + } else { + panic!("Expected a rewrite"); + } + } + + #[test] + fn test_weird_side_effects_orphans() { + let pattern_src = r#" + language python + + or { + `from langchain.agents import $stuffs` as $anchor where { + $stuffs <: contains `load_tools` => ., + $anchor += `\nfrom my_thing import tools` + } + } + "#; + let libs = BTreeMap::new(); + + let matching_src = + include_str!("../../../crates/cli_bin/fixtures/notebooks/langchain_open.ipynb"); + + let pattern = src_to_problem_libs( + pattern_src.to_string(), + &libs, + TargetLanguage::from_extension("ipynb").unwrap(), + None, + None, + None, + None, + ) + .unwrap() + .problem; + + // Basic match works + let test_files = vec![SyntheticFile::new( + "target.ipynb".to_owned(), + matching_src.to_owned(), + true, + )]; + let results = run_on_test_files(&pattern, &test_files); + assert!(!results.iter().any(|r| r.is_error())); + + let rewrite = results + .iter() + .find(|r| matches!(r, MatchResult::Rewrite(_))) + .unwrap(); + + if let MatchResult::Rewrite(rewrite) = rewrite { + assert!(!rewrite.rewritten.content.contains("\"gent_chain.run")); + assert_snapshot!(rewrite.rewritten.content); + } else { + panic!("Expected a rewrite"); + } + } + + /// Ensure our comma expansion does not make the wrong ranges + #[test] + fn test_multiple_add_remove_imports_with_commas() { + let pattern_src = r#" + language python + + `from $src import $thing` as $base where { + $src <: includes "langchain", + $thing <: contains bubble($base) `$thing` => . where { + $name = text($thing), + $base += `\nfrom somewhere import something`, + } + } + "#; + let libs = BTreeMap::new(); + + let matching_src = + include_str!("../../../crates/cli_bin/fixtures/notebooks/many_ranges.ipynb"); + + let pattern = src_to_problem_libs( + pattern_src.to_string(), + &libs, + TargetLanguage::from_extension("ipynb").unwrap(), + None, + None, + None, + None, + ) + .unwrap() + .problem; + + // Basic match works + let test_files = vec![SyntheticFile::new( + "target.ipynb".to_owned(), + matching_src.to_owned(), + true, + )]; + let results = run_on_test_files(&pattern, &test_files); + for r in &results { + if r.is_error() { + panic!("{:?}", r); + } + } + + let rewrite = results + .iter() + .find(|r| matches!(r, MatchResult::Rewrite(_))) + .unwrap(); + + if let MatchResult::Rewrite(rewrite) = rewrite { + assert!(!rewrite.rewritten.content.contains("\"gent_chain.run")); + assert_snapshot!(rewrite.rewritten.content); + } else { + panic!("Expected a rewrite"); + } + } + + #[test] + fn test_python3_kernelspec() { + let pattern_src = r#" + language python + + `langchain` => `fangchain` + "#; + let libs = BTreeMap::new(); + + let matching_src = + include_str!("../../../crates/cli_bin/fixtures/notebooks/kind_of_python.ipynb"); + + let pattern = src_to_problem_libs( + pattern_src.to_string(), + &libs, + TargetLanguage::from_extension("ipynb").unwrap(), + None, + None, + None, + None, + ) + .unwrap() + .problem; + + // Basic match works + let test_files = vec![SyntheticFile::new( + "target.ipynb".to_owned(), + matching_src.to_owned(), + true, + )]; + let results = run_on_test_files(&pattern, &test_files); + for r in &results { + if r.is_error() { + panic!("{:?}", r); + } + } + + let rewrite = results + .iter() + .find(|r| matches!(r, MatchResult::Rewrite(_))) + .unwrap(); + + if let MatchResult::Rewrite(rewrite) = rewrite { + assert_snapshot!(rewrite.rewritten.content); + } else { + panic!("Expected a rewrite"); + } } } diff --git a/crates/core/src/text_unparser.rs b/crates/core/src/text_unparser.rs index 4fd02e462..f86c775c9 100644 --- a/crates/core/src/text_unparser.rs +++ b/crates/core/src/text_unparser.rs @@ -57,6 +57,7 @@ pub(crate) fn apply_effects<'a, Q: QueryContext>( language.should_pad_snippet().then_some(0), logs, )?; + for effect in effects.iter() { if let Some(filename) = effect.binding.as_filename() { if std::ptr::eq(filename, the_filename) { diff --git a/crates/grit-util/src/language.rs b/crates/grit-util/src/language.rs index 85cdac31f..4623907a5 100644 --- a/crates/grit-util/src/language.rs +++ b/crates/grit-util/src/language.rs @@ -149,3 +149,12 @@ impl Replacement { Self { range, replacement } } } + +impl From<&Replacement> for (std::ops::Range, usize) { + fn from(replacement: &Replacement) -> Self { + ( + (replacement.range.start_byte as usize)..(replacement.range.end_byte as usize), + replacement.replacement.len(), + ) + } +} diff --git a/crates/grit-util/src/parser.rs b/crates/grit-util/src/parser.rs index 95ae34085..4db2d674e 100644 --- a/crates/grit-util/src/parser.rs +++ b/crates/grit-util/src/parser.rs @@ -1,5 +1,5 @@ use crate::{AnalysisLogs, AstNode}; -use std::{ops::Range, path::Path}; +use std::{marker::PhantomData, path::Path}; /// Information on where a file came from, for the parser to be smarter #[derive(Clone, Debug)] @@ -10,9 +10,11 @@ where /// A file we are parsing for the first time, from disk Fresh, /// A file we have parsed before, and are re-parsing after mutating - Mutated((&'tree Tree, &'tree Vec<(Range, usize)>)), + Mutated, /// A file that was constructed by Grit New, + /// We might need these + _Phantom(PhantomData<&'tree Tree>), } impl<'tree, Tree: Ast> FileOrigin<'tree, Tree> { diff --git a/crates/language/src/lib.rs b/crates/language/src/lib.rs index 90057c218..88112e5ab 100644 --- a/crates/language/src/lib.rs +++ b/crates/language/src/lib.rs @@ -21,7 +21,7 @@ pub mod python; pub mod ruby; pub mod rust; pub mod solidity; -mod sourcemap; +pub mod sourcemap; pub mod sql; pub mod target_language; pub mod toml; diff --git a/crates/language/src/notebooks.rs b/crates/language/src/notebooks.rs index 8ca54ca42..42180cc04 100644 --- a/crates/language/src/notebooks.rs +++ b/crates/language/src/notebooks.rs @@ -3,7 +3,7 @@ use grit_util::AstCursor; use grit_util::AstNode; use grit_util::ByteRange; use grit_util::FileOrigin; - +use serde::Deserialize; use std::path::Path; @@ -24,6 +24,13 @@ use crate::{ const SUPPORTED_VERSION: i64 = 4; +/// Kernel information. +#[derive(Clone, Debug, Deserialize, PartialEq)] +pub struct LanguageInfo { + /// The programming language which this kernel runs. + pub name: String, +} + /// Returns `true` if a cell should be ignored due to the use of cell magics. /// Borrowed from [ruff](https://github.com/astral-sh/ruff/blob/33fd50027cb24e407746da339bdf2461df194d96/crates/ruff_notebook/src/cell.rs) fn is_magic_cell<'a>(lines: impl Iterator) -> bool { @@ -212,7 +219,7 @@ impl MarzanoNotebookParser { let mut source_map = EmbeddedSourceMap::new(body); let mut nbformat_version: Option = None; - let mut language_string: Option = None; + let mut language_info: Option = None; let json = Json::new(None); let mut parser = json.get_parser(); @@ -245,14 +252,13 @@ impl MarzanoNotebookParser { if n.node.kind() == "pair" && n.child_by_field_name("key") .and_then(|key| key.node.utf8_text(body.as_bytes()).ok()) - .map(|key| key == "\"language\"") + .map(|key| key == "\"language_info\"") .unwrap_or(false) { - let text: Option = n + language_info = n .child_by_field_name("value") .and_then(|value| value.node.utf8_text(body.as_bytes()).ok()) .and_then(|text| serde_json::from_str(&text).ok()); - language_string = text; } if n.node.kind() != "object" { @@ -340,7 +346,7 @@ impl MarzanoNotebookParser { if !is_magic_cell(content.lines()) { inner_code_body.push_str(&content); source_map.add_section(section); - } + } } } @@ -356,13 +362,13 @@ impl MarzanoNotebookParser { return None; } - if let Some(language_string) = language_string { - if language_string != self.language { + if let Some(language) = language_info { + if language.name != self.language { logs.add_warning( path.map(|m| m.into()), format!( "Skipping notebook with different language: {}, expected {}", - language_string, self.language + language.name, self.language ), ); return None; diff --git a/crates/language/src/python.rs b/crates/language/src/python.rs index d5661a1c8..aa27a7938 100644 --- a/crates/language/src/python.rs +++ b/crates/language/src/python.rs @@ -86,30 +86,64 @@ impl Language for Python { } fn check_replacements(&self, n: NodeWithSource<'_>, replacements: &mut Vec) { - if n.node.is_error() && n.text().is_ok_and(|t| t == "->") { - replacements.push(Replacement::new(n.range(), "")); + if n.node.is_error() { + if n.text().is_ok_and(|t| t == "->") { + replacements.push(Replacement::new(n.range(), "")); + } + return; } if n.node.kind() == "import_from_statement" { + if let Some(name_field) = n.node.child_by_field_name("name") { + let names_text = name_field + .utf8_text(n.source.as_bytes()) + .unwrap_or_default(); + // If we have an empty names text remove the whole thing + if names_text.trim().is_empty() { + replacements.push(Replacement::new(n.range(), "")); + return; + } + } if let Ok(t) = n.text() { let mut end_range = n.range(); end_range.start_byte = end_range.end_byte; - let mut finding_paren_only = false; + let mut did_close_paren = false; + // Delete: from x import () let chars = t.chars().rev(); for ch in chars { end_range.start_byte -= 1; if ch == ')' { - finding_paren_only = true - } else if finding_paren_only && ch == '(' { + did_close_paren = true + } else if did_close_paren && ch == '(' { + // Delete: from x import () replacements.push(Replacement::new(n.range(), "")); break; } else if ch == ',' { - replacements.push(Replacement::new(end_range, "")); - break; + if !did_close_paren { + // Delete: the , from x import foo, *and keep looking* + replacements.push(Replacement::new(end_range, "")); + } else { + break; + } } else if !ch.is_whitespace() { break; } } + + if !did_close_paren { + let mut removal_range = n.range(); + removal_range.end_byte = removal_range.start_byte; + // If we have content after the newline, that is a problem and likely corrupt + for ch in t.chars() { + if ch == '\n' { + // Assume everything after this is a problem + replacements.push(Replacement::new(removal_range, "")); + break; + } else { + removal_range.end_byte += 1; + } + } + } } } } diff --git a/crates/language/src/sourcemap.rs b/crates/language/src/sourcemap.rs index faea80088..06254e210 100644 --- a/crates/language/src/sourcemap.rs +++ b/crates/language/src/sourcemap.rs @@ -1,5 +1,3 @@ -use std::mem; - use anyhow::Result; use grit_util::ByteRange; use serde_json::json; @@ -25,42 +23,67 @@ impl EmbeddedSourceMap { self.sections.push(section); } + pub fn new_section( + &mut self, + outer_range: std::ops::Range, + inner_range_end: usize, + format: SourceValueFormat, + inner_end_trim: usize, + ) { + self.sections.push(SourceMapSection { + outer_range: ByteRange::new(outer_range.start, outer_range.end), + inner_range_end, + format, + inner_end_trim, + }) + } + pub fn clone_with_edits<'a>( &self, - mut adjustments: impl Iterator, usize)>, + adjustments: impl Iterator, usize)>, ) -> Result { let mut new_map = self.clone(); - let mut accumulated_offset: i32 = 0; - let mut next_offset = 0; + println!("New adjustment cycle!"); + + let mut section_iter = new_map.sections.iter().enumerate().peekable(); - for section in new_map.sections.iter_mut() { - let mut section_offset = mem::take(&mut next_offset); - for (source_range, replacement_length) in adjustments.by_ref() { - let length_diff = - *replacement_length as i32 - (source_range.end - source_range.start) as i32; + let mut section_adjustments: Vec = vec![0; new_map.sections.len()]; - if source_range.start >= section.inner_range_end { - // Save this diff, since we will not be able to read it the next time - next_offset = length_diff; + for (source_range, replacement_length) in adjustments { + // Find the section that contains the source range + while let Some((index, section)) = section_iter.peek() { + // If the section contains the source range, apply the adjustment + if section.inner_range_end > source_range.start { + let length_diff = + *replacement_length as i64 - (source_range.end - source_range.start) as i64; + println!("Adjusting {:?} with {}", source_range, length_diff); + + section_adjustments[*index] += length_diff; break; } - - section_offset += length_diff; + // Otherwise, move to the next section + section_iter.next(); } + } - // Apply the accumulated offset to the section - accumulated_offset += section_offset; - + // Apply the accumulated offset to the section + let mut accumulated_offset = 0; + for (section, adjustment) in new_map.sections.iter_mut().zip(section_adjustments) { + accumulated_offset += adjustment; + println!("Adding {} to section {:?}", accumulated_offset, section); section.inner_range_end = - (section.inner_range_end as i32 + accumulated_offset) as usize; + (section.inner_range_end as i64 + accumulated_offset) as usize; } + Ok(new_map) } pub fn fill_with_inner(&self, new_inner_source: &str) -> Result { let mut outer_source = self.outer_source.clone(); + println!("inner output: {}", new_inner_source); + let mut current_inner_offset = 0; let mut current_outer_offset = 0; @@ -71,9 +94,10 @@ impl EmbeddedSourceMap { ); let replacement_code = new_inner_source.get(start..end).ok_or(anyhow::anyhow!( - "Section range {}-{} is out of bounds", + "Section range {}-{} is out of bounds inside {}", start, - end + end, + new_inner_source.len() ))?; let json = section.as_json(replacement_code); @@ -237,4 +261,28 @@ mod tests { r#"["bc", "ekgfgh", "zko"]"# ); } + + #[test] + fn test_five_sections_with_single_edit() { + let mut source_map = EmbeddedSourceMap::new(r#"["abcd", "efgh", "zko", "znzo"]"#); + source_map.new_section(1..7, 5, SourceValueFormat::String, 1); + source_map.new_section(9..15, 10, SourceValueFormat::String, 1); + source_map.new_section(17..22, 14, SourceValueFormat::String, 1); + source_map.new_section(24..30, 19, SourceValueFormat::String, 1); + + // Verify the initial state + assert_eq!( + source_map.fill_with_inner("abcd|efgh|zko|znzo|").unwrap(), + r#"["abcd", "efgh", "zko", "znzo"]"# + ); + + // First pass, only edit the 3rd section (adding 2 characters) + // k -> PP + let adjustments = vec![(11..12, 2)]; + let adjusted = source_map.clone_with_edits(adjustments.iter()).unwrap(); + assert_eq!( + adjusted.fill_with_inner("abcd|efgh|zPPo|znzo|").unwrap(), + r#"["abcd", "efgh", "zPPo", "znzo"]"# + ); + } }