From 0c8cb69bcce9c39dc94f8272accf7d87bc872978 Mon Sep 17 00:00:00 2001 From: Morgante Pell Date: Sat, 25 May 2024 14:55:15 -0700 Subject: [PATCH] feat: basic ipython support (#353) --- crates/cli/Cargo.toml | 2 +- crates/cli/src/commands/patterns_list.rs | 2 +- crates/cli/src/github.rs | 2 +- crates/cli/src/result_formatting.rs | 2 +- .../cli_bin/fixtures/notebooks/old_nb.ipynb | 35 + .../cli_bin/fixtures/notebooks/other_nb.ipynb | 40 + .../cli_bin/fixtures/notebooks/pattern.grit | 11 + .../cli_bin/fixtures/notebooks/tiny_nb.ipynb | 39 + crates/cli_bin/fixtures/wikibase_agent.ipynb | 765 ++++++++++++++++++ crates/cli_bin/tests/apply.rs | 42 + .../snapshots/apply__python_in_notebook.snap | 39 + crates/core/src/api.rs | 27 +- crates/core/src/effects_dependency_graph.rs | 464 ----------- crates/core/src/lib.rs | 3 +- crates/core/src/marzano_context.rs | 18 +- crates/core/src/parse.rs | 65 +- .../pattern_compiler/file_owner_compiler.rs | 23 +- crates/core/src/problem.rs | 2 +- ...no_core__parse__tests__other_notebook.snap | 32 + ...o_core__parse__tests__simple_notebook.snap | 26 + ...core__parse__tests__verify_notebook-2.snap | 32 + ...o_core__parse__tests__verify_notebook.snap | 26 + ...rzano_core__test_notebooks__base_case.snap | 39 + crates/core/src/test.rs | 2 +- crates/core/src/test_files.rs | 1 - crates/core/src/test_notebooks.rs | 91 +++ crates/grit-util/src/analysis_logs.rs | 9 + crates/grit-util/src/lib.rs | 2 +- crates/grit-util/src/parser.rs | 32 +- crates/gritmodule/src/fetcher.rs | 2 +- crates/gritmodule/src/markdown.rs | 2 +- crates/language/src/css.rs | 8 +- crates/language/src/js_like.rs | 8 +- crates/language/src/language.rs | 20 +- crates/language/src/lib.rs | 2 + crates/language/src/notebooks.rs | 242 ++++++ crates/language/src/python.rs | 11 +- crates/language/src/sourcemap.rs | 69 ++ crates/language/src/target_language.rs | 6 +- crates/language/src/vue.rs | 1 + crates/wasm-bindings/src/match_pattern.rs | 8 +- 41 files changed, 1738 insertions(+), 514 deletions(-) create mode 100644 crates/cli_bin/fixtures/notebooks/old_nb.ipynb create mode 100644 crates/cli_bin/fixtures/notebooks/other_nb.ipynb create mode 100644 crates/cli_bin/fixtures/notebooks/pattern.grit create mode 100644 crates/cli_bin/fixtures/notebooks/tiny_nb.ipynb create mode 100644 crates/cli_bin/fixtures/wikibase_agent.ipynb create mode 100644 crates/cli_bin/tests/snapshots/apply__python_in_notebook.snap delete mode 100644 crates/core/src/effects_dependency_graph.rs create mode 100644 crates/core/src/snapshots/marzano_core__parse__tests__other_notebook.snap create mode 100644 crates/core/src/snapshots/marzano_core__parse__tests__simple_notebook.snap create mode 100644 crates/core/src/snapshots/marzano_core__parse__tests__verify_notebook-2.snap create mode 100644 crates/core/src/snapshots/marzano_core__parse__tests__verify_notebook.snap create mode 100644 crates/core/src/snapshots/marzano_core__test_notebooks__base_case.snap create mode 100644 crates/core/src/test_notebooks.rs create mode 100644 crates/language/src/notebooks.rs create mode 100644 crates/language/src/sourcemap.rs diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 8fff8a6f8..9d1612a41 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -26,7 +26,7 @@ uuid = { version = "1.1", features = ["v4", "serde"] } tokio = { version = "1", features = ["full"] } chrono = { version = "0.4.26", features = ["serde"] } reqwest = { version = "0.11", features = ["json", "stream"] } -futures-util = "0.3.30" +futures-util = { version = 
"0.3.30" } lazy_static = { version = "1.4.0" } indicatif-log-bridge = { version = "0.2.1" } colored = { version = "2.0.4" } diff --git a/crates/cli/src/commands/patterns_list.rs b/crates/cli/src/commands/patterns_list.rs index 02d65f897..73f68ec1f 100644 --- a/crates/cli/src/commands/patterns_list.rs +++ b/crates/cli/src/commands/patterns_list.rs @@ -5,7 +5,7 @@ use marzano_gritmodule::config::{DefinitionSource, ResolvedGritDefinition}; use crate::{ flags::GlobalFormatFlags, lister::{list_applyables, Listable}, - resolver::{resolve_from_flags_or_cwd}, + resolver::resolve_from_flags_or_cwd, }; use super::list::ListArgs; diff --git a/crates/cli/src/github.rs b/crates/cli/src/github.rs index a083512b5..3cf6b682b 100644 --- a/crates/cli/src/github.rs +++ b/crates/cli/src/github.rs @@ -1,12 +1,12 @@ use crate::analyze::group_checks; use crate::ux::CheckResult; use anyhow::{Context as _, Result}; +use fs_err::OpenOptions; use grit_util::Range; use log::info; use marzano_core::{api::EnforcementLevel, fs::extract_ranges}; use marzano_gritmodule::config::ResolvedGritDefinition; use marzano_gritmodule::utils::extract_path; -use fs_err::OpenOptions; use std::io::prelude::*; fn format_level(level: &EnforcementLevel) -> String { diff --git a/crates/cli/src/result_formatting.rs b/crates/cli/src/result_formatting.rs index 1f0fb4c7b..7d9857e3b 100644 --- a/crates/cli/src/result_formatting.rs +++ b/crates/cli/src/result_formatting.rs @@ -2,6 +2,7 @@ use anyhow::anyhow; use colored::Colorize; use console::style; use core::fmt; +use fs_err::read_to_string; use log::info; use marzano_core::api::{ AllDone, AnalysisLog, CreateFile, DoneFile, FileMatchResult, InputFile, Match, MatchResult, @@ -10,7 +11,6 @@ use marzano_core::api::{ use marzano_core::constants::DEFAULT_FILE_NAME; use marzano_messenger::output_mode::OutputMode; use std::fmt::Display; -use fs_err::read_to_string; use std::{ io::Write, sync::{Arc, Mutex}, diff --git a/crates/cli_bin/fixtures/notebooks/old_nb.ipynb b/crates/cli_bin/fixtures/notebooks/old_nb.ipynb new file mode 100644 index 000000000..1a2508ce3 --- /dev/null +++ b/crates/cli_bin/fixtures/notebooks/old_nb.ipynb @@ -0,0 +1,35 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e3cb542-933d-4bf3-a82b-d9d6395a7832", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": ["print(4)\n", "print(\"3\")\n", "print(5)"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "conda210", + "language": "python", + "name": "conda210" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 2, + "nbformat_minor": 5 +} diff --git a/crates/cli_bin/fixtures/notebooks/other_nb.ipynb b/crates/cli_bin/fixtures/notebooks/other_nb.ipynb new file mode 100644 index 000000000..c441bd984 --- /dev/null +++ b/crates/cli_bin/fixtures/notebooks/other_nb.ipynb @@ -0,0 +1,40 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e3cb542-933d-4bf3-a82b-d9d6395a7832", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import configparser\n", + "\n", + "config = configparser.ConfigParser()\n", + "config.read(\"./secrets.ini\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "conda210", + "language": "python", + "name": "conda210" + }, + "language_info": { + "codemirror_mode": { + "name": 
"ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/crates/cli_bin/fixtures/notebooks/pattern.grit b/crates/cli_bin/fixtures/notebooks/pattern.grit new file mode 100644 index 000000000..bc0652e4f --- /dev/null +++ b/crates/cli_bin/fixtures/notebooks/pattern.grit @@ -0,0 +1,11 @@ +language python + +or { + // Simple replacement + `configparser.ConfigParser()` => `FigParser.Fig()`, + // Cross-cell references + `Tool($args)` => `Fool($args)` where { + $program <: contains `from langchain.agents import $_` + }, + `print($x)` => `flint($x)` +} diff --git a/crates/cli_bin/fixtures/notebooks/tiny_nb.ipynb b/crates/cli_bin/fixtures/notebooks/tiny_nb.ipynb new file mode 100644 index 000000000..214809cac --- /dev/null +++ b/crates/cli_bin/fixtures/notebooks/tiny_nb.ipynb @@ -0,0 +1,39 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e3cb542-933d-4bf3-a82b-d9d6395a7832", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(4)\n", + "print(\"3\")\n", + "print(5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "conda210", + "language": "python", + "name": "conda210" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/crates/cli_bin/fixtures/wikibase_agent.ipynb b/crates/cli_bin/fixtures/wikibase_agent.ipynb new file mode 100644 index 000000000..8d9ef04dc --- /dev/null +++ b/crates/cli_bin/fixtures/wikibase_agent.ipynb @@ -0,0 +1,765 @@ +{ + "nbformat": 4, + "nbformat_minor": 5, + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "5e3cb542-933d-4bf3-a82b-d9d6395a7832", + "metadata": { + "tags": [] + }, + "source": [ + "# Wikibase Agent\n", + "\n", + "This notebook demonstrates a very simple wikibase agent that uses sparql generation. Although this code is intended to work against any\n", + "wikibase instance, we use http://wikidata.org for testing.\n", + "\n", + "If you are interested in wikibases and sparql, please consider helping to improve this agent. 
Look [here](https://github.com/donaldziff/langchain-wikibase) for more details and open questions.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "07d42966-7e99-4157-90dc-6704977dcf1b", + "metadata": { + "tags": [] + }, + "source": ["## Preliminaries"] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "9132f093-c61e-4b8d-abef-91ebef3fc85f", + "metadata": { + "tags": [] + }, + "source": [ + "### API keys and other secrets\n", + "\n", + "We use an `.ini` file, like this: \n", + "```\n", + "[OPENAI]\n", + "OPENAI_API_KEY=xyzzy\n", + "[WIKIDATA]\n", + "WIKIDATA_USER_AGENT_HEADER=argle-bargle\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "99567dfd-05a7-412f-abf0-9b9f4424acbd", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": ["['./secrets.ini']"] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import configparser\n", + "\n", + "config = configparser.ConfigParser()\n", + "config.read(\"./secrets.ini\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "332b6658-c978-41ca-a2be-4f8677fecaef", + "metadata": { + "tags": [] + }, + "source": [ + "### OpenAI API Key\n", + "\n", + "An OpenAI API key is required unless you modify the code below to use another LLM provider." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dd328ee2-33cc-4e1e-aff7-cc0a2e05e2e6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "openai_api_key = config[\"OPENAI\"][\"OPENAI_API_KEY\"]\n", + "import os\n", + "\n", + "os.environ.update({\"OPENAI_API_KEY\": openai_api_key})" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "42a9311b-600d-42bc-b000-2692ef87a213", + "metadata": { + "tags": [] + }, + "source": [ + "### Wikidata user-agent header\n", + "\n", + "Wikidata policy requires a user-agent header. See https://meta.wikimedia.org/wiki/User-Agent_policy. However, at present this policy is not strictly enforced." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "17ba657e-789d-40e1-b4b7-4f29ba06fe79", + "metadata": {}, + "outputs": [], + "source": [ + "wikidata_user_agent_header = (\n", + " None\n", + " if not config.has_section(\"WIKIDATA\")\n", + " else config[\"WIKIDATA\"][\"WIKIDATA_USER_AGENT_HEADER\"]\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "db08d308-050a-4fc8-93c9-8de4ae977ac3", + "metadata": {}, + "source": ["### Enable tracing if desired"] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "77d2da08-fccd-4676-b77e-c0e89bf343cb", + "metadata": {}, + "outputs": [], + "source": [ + "# import os\n", + "# os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n", + "# os.environ[\"LANGCHAIN_SESSION\"] = \"default\" # Make sure this session actually exists." 
+ ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "3dbc5bfc-48ce-4f90-873c-7336b21300c6", + "metadata": {}, + "source": [ + "# Tools\n", + "\n", + "Three tools are provided for this simple agent:\n", + "* `ItemLookup`: for finding the q-number of an item\n", + "* `PropertyLookup`: for finding the p-number of a property\n", + "* `SparqlQueryRunner`: for running a sparql query" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1f801b4e-6576-4914-aa4f-6f4c4e3c7924", + "metadata": { + "tags": [] + }, + "source": [ + "## Item and Property lookup\n", + "\n", + "Item and Property lookup are implemented in a single method, using an elastic search endpoint. Not all wikibase instances have it, but wikidata does, and that's where we'll start." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "42d23f0a-1c74-4c9c-85f2-d0e24204e96a", + "metadata": {}, + "outputs": [], + "source": [ + "def get_nested_value(o: dict, path: list) -> any:\n", + " current = o\n", + " for key in path:\n", + " try:\n", + " current = current[key]\n", + " except KeyError:\n", + " return None\n", + " return current\n", + "\n", + "\n", + "from typing import Optional\n", + "\n", + "import requests\n", + "\n", + "\n", + "def vocab_lookup(\n", + " search: str,\n", + " entity_type: str = \"item\",\n", + " url: str = \"https://www.wikidata.org/w/api.php\",\n", + " user_agent_header: str = wikidata_user_agent_header,\n", + " srqiprofile: str = None,\n", + ") -> Optional[str]:\n", + " headers = {\"Accept\": \"application/json\"}\n", + " if wikidata_user_agent_header is not None:\n", + " headers[\"User-Agent\"] = wikidata_user_agent_header\n", + "\n", + " if entity_type == \"item\":\n", + " srnamespace = 0\n", + " srqiprofile = \"classic_noboostlinks\" if srqiprofile is None else srqiprofile\n", + " elif entity_type == \"property\":\n", + " srnamespace = 120\n", + " srqiprofile = \"classic\" if srqiprofile is None else srqiprofile\n", + " else:\n", + " raise ValueError(\"entity_type must be either 'property' or 'item'\")\n", + "\n", + " params = {\n", + " \"action\": \"query\",\n", + " \"list\": \"search\",\n", + " \"srsearch\": search,\n", + " \"srnamespace\": srnamespace,\n", + " \"srlimit\": 1,\n", + " \"srqiprofile\": srqiprofile,\n", + " \"srwhat\": \"text\",\n", + " \"format\": \"json\",\n", + " }\n", + "\n", + " response = requests.get(url, headers=headers, params=params)\n", + "\n", + " if response.status_code == 200:\n", + " title = get_nested_value(response.json(), [\"query\", \"search\", 0, \"title\"])\n", + " if title is None:\n", + " return f\"I couldn't find any {entity_type} for '{search}'. Please rephrase your request and try again\"\n", + " # if there is a prefix, strip it off\n", + " return title.split(\":\")[-1]\n", + " else:\n", + " return \"Sorry, I got an error. 
Please try again.\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e52060fa-3614-43fb-894e-54e9b75d1e9f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["Q4180017\n"] + } + ], + "source": ["print(vocab_lookup(\"Malin 1\"))"] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b23ab322-b2cf-404e-b36f-2bfc1d79b0d3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": ["P31\n"] + } + ], + "source": [ + "print(vocab_lookup(\"instance of\", entity_type=\"property\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "89020cc8-104e-42d0-ac32-885e590de515", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "I couldn't find any item for 'Ceci n'est pas un q-item'. Please rephrase your request and try again\n" + ] + } + ], + "source": ["print(vocab_lookup(\"Ceci n'est pas un q-item\"))"] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "78d66d8b-0e34-4d3f-a18d-c7284840ac76", + "metadata": {}, + "source": ["## Sparql runner "] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c6f60069-fbe0-4015-87fb-0e487cd914e7", + "metadata": {}, + "source": ["This tool runs sparql - by default, wikidata is used."] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b5b97a4d-2a39-4993-88d9-e7818c0a2853", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from typing import Any, Dict, List\n", + "\n", + "import requests\n", + "\n", + "\n", + "def run_sparql(\n", + " query: str,\n", + " url=\"https://query.wikidata.org/sparql\",\n", + " user_agent_header: str = wikidata_user_agent_header,\n", + ") -> List[Dict[str, Any]]:\n", + " headers = {\"Accept\": \"application/json\"}\n", + " if wikidata_user_agent_header is not None:\n", + " headers[\"User-Agent\"] = wikidata_user_agent_header\n", + "\n", + " response = requests.get(\n", + " url, headers=headers, params={\"query\": query, \"format\": \"json\"}\n", + " )\n", + "\n", + " if response.status_code != 200:\n", + " return \"That query failed. Perhaps you could try a different one?\"\n", + " results = get_nested_value(response.json(), [\"results\", \"bindings\"])\n", + " return json.dumps(results)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "149722ec-8bc1-4d4f-892b-e4ddbe8444c1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'[{\"count\": {\"datatype\": \"http://www.w3.org/2001/XMLSchema#integer\", \"type\": \"literal\", \"value\": \"20\"}}]'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "run_sparql(\"SELECT (COUNT(?children) as ?count) WHERE { wd:Q1339 wdt:P40 ?children . 
}\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "9f0302fd-ba35-4acc-ba32-1d7c9295c898", + "metadata": {}, + "source": ["# Agent"] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "3122a961-9673-4a52-b1cd-7d62fbdf8d96", + "metadata": {}, + "source": ["## Wrap the tools"] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "cc41ae88-2e53-4363-9878-28b26430cb1e", + "metadata": {}, + "outputs": [], + "source": [ + "import re\n", + "from typing import List, Union\n", + "\n", + "from langchain.agents import (\n", + " AgentExecutor,\n", + " AgentOutputParser,\n", + " LLMSingleActionAgent,\n", + " Tool,\n", + ")\n", + "from langchain.chains import LLMChain\n", + "from langchain.prompts import StringPromptTemplate\n", + "from langchain_core.agents import AgentAction, AgentFinish" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "2810a3ce-b9c6-47ee-8068-12ca967cd0ea", + "metadata": {}, + "outputs": [], + "source": [ + "# Define which tools the agent can use to answer user queries\n", + "tools = [\n", + " Tool(\n", + " name=\"ItemLookup\",\n", + " func=(lambda x: vocab_lookup(x, entity_type=\"item\")),\n", + " description=\"useful for when you need to know the q-number for an item\",\n", + " ),\n", + " Tool(\n", + " name=\"PropertyLookup\",\n", + " func=(lambda x: vocab_lookup(x, entity_type=\"property\")),\n", + " description=\"useful for when you need to know the p-number for a property\",\n", + " ),\n", + " Tool(\n", + " name=\"SparqlQueryRunner\",\n", + " func=run_sparql,\n", + " description=\"useful for getting results from a wikibase\",\n", + " ),\n", + "]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "ab0f2778-a195-4a4a-a5b4-c1e809e1fb7b", + "metadata": {}, + "source": ["## Prompts"] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "7bd4ba4f-57d6-4ceb-b932-3cb0d0509a24", + "metadata": {}, + "outputs": [], + "source": [ + "# Set up the base template\n", + "template = \"\"\"\n", + "Answer the following questions by running a sparql query against a wikibase where the p and q items are \n", + "completely unknown to you. You will need to discover the p and q items before you can generate the sparql.\n", + "Do not assume you know the p and q items for any concepts. Always use tools to find all p and q items.\n", + "After you generate the sparql, you should run it. The results will be returned in json. \n", + "Summarize the json results in natural language.\n", + "\n", + "You may assume the following prefixes:\n", + "PREFIX wd: \n", + "PREFIX wdt: \n", + "PREFIX p: \n", + "PREFIX ps: \n", + "\n", + "When generating sparql:\n", + "* Try to avoid \"count\" and \"filter\" queries if possible\n", + "* Never enclose the sparql in back-quotes\n", + "\n", + "You have access to the following tools:\n", + "\n", + "{tools}\n", + "\n", + "Use the following format:\n", + "\n", + "Question: the input question for which you must provide a natural language answer\n", + "Thought: you should always think about what to do\n", + "Action: the action to take, should be one of [{tool_names}]\n", + "Action Input: the input to the action\n", + "Observation: the result of the action\n", + "... 
(this Thought/Action/Action Input/Observation can repeat N times)\n", + "Thought: I now know the final answer\n", + "Final Answer: the final answer to the original input question\n", + "\n", + "Question: {input}\n", + "{agent_scratchpad}\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "7e8d771a-64bb-4ec8-b472-6a9a40c6dd38", + "metadata": {}, + "outputs": [], + "source": [ + "# Set up a prompt template\n", + "class CustomPromptTemplate(StringPromptTemplate):\n", + " # The template to use\n", + " template: str\n", + " # The list of tools available\n", + " tools: List[Tool]\n", + "\n", + " def format(self, **kwargs) -> str:\n", + " # Get the intermediate steps (AgentAction, Observation tuples)\n", + " # Format them in a particular way\n", + " intermediate_steps = kwargs.pop(\"intermediate_steps\")\n", + " thoughts = \"\"\n", + " for action, observation in intermediate_steps:\n", + " thoughts += action.log\n", + " thoughts += f\"\\nObservation: {observation}\\nThought: \"\n", + " # Set the agent_scratchpad variable to that value\n", + " kwargs[\"agent_scratchpad\"] = thoughts\n", + " # Create a tools variable from the list of tools provided\n", + " kwargs[\"tools\"] = \"\\n\".join(\n", + " [f\"{tool.name}: {tool.description}\" for tool in self.tools]\n", + " )\n", + " # Create a list of tool names for the tools provided\n", + " kwargs[\"tool_names\"] = \", \".join([tool.name for tool in self.tools])\n", + " return self.template.format(**kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f97dca78-fdde-4a70-9137-e34a21d14e64", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = CustomPromptTemplate(\n", + " template=template,\n", + " tools=tools,\n", + " # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically\n", + " # This includes the `intermediate_steps` variable because that is needed\n", + " input_variables=[\"input\", \"intermediate_steps\"],\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "12c57d77-3c1e-4cde-9a83-7d2134392479", + "metadata": {}, + "source": ["## Output parser \n", "This is unchanged from langchain docs"] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "42da05eb-c103-4649-9d20-7143a8880721", + "metadata": {}, + "outputs": [], + "source": [ + "class CustomOutputParser(AgentOutputParser):\n", + " def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n", + " # Check if agent should finish\n", + " if \"Final Answer:\" in llm_output:\n", + " return AgentFinish(\n", + " # Return values is generally always a dictionary with a single `output` key\n", + " # It is not recommended to try anything else at the moment :)\n", + " return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n", + " log=llm_output,\n", + " )\n", + " # Parse out the action and action input\n", + " regex = r\"Action: (.*?)[\\n]*Action Input:[\\s]*(.*)\"\n", + " match = re.search(regex, llm_output, re.DOTALL)\n", + " if not match:\n", + " raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n", + " action = match.group(1).strip()\n", + " action_input = match.group(2)\n", + " # Return the action and action input\n", + " return AgentAction(\n", + " tool=action, tool_input=action_input.strip(\" \").strip('\"'), log=llm_output\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d2b4d710-8cc9-4040-9269-59cf6c5c22be", + "metadata": {}, + "outputs": [], + "source": 
["output_parser = CustomOutputParser()"] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "48a758cb-93a7-4555-b69a-896d2d43c6f0", + "metadata": {}, + "source": ["## Specify the LLM model"] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "72988c79-8f60-4b0f-85ee-6af32e8de9c2", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_openai import ChatOpenAI\n", + "\n", + "llm = ChatOpenAI(model=\"gpt-4\", temperature=0)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "95685d14-647a-4e24-ae2c-a8dd1e364921", + "metadata": {}, + "source": ["## Agent and agent executor"] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "13d55765-bfa1-43b3-b7cb-00f52ebe7747", + "metadata": {}, + "outputs": [], + "source": [ + "# LLM chain consisting of the LLM and a prompt\n", + "llm_chain = LLMChain(llm=llm, prompt=prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "b3f7ac3c-398e-49f9-baed-554f49a191c3", + "metadata": {}, + "outputs": [], + "source": [ + "tool_names = [tool.name for tool in tools]\n", + "agent = LLMSingleActionAgent(\n", + " llm_chain=llm_chain,\n", + " output_parser=output_parser,\n", + " stop=[\"\\nObservation:\"],\n", + " allowed_tools=tool_names,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "65740577-272e-4853-8d47-b87784cfaba0", + "metadata": {}, + "outputs": [], + "source": [ + "agent_executor = AgentExecutor.from_agent_and_tools(\n", + " agent=agent, tools=tools, verbose=True\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "66e3d13b-77cf-41d3-b541-b54535c14459", + "metadata": {}, + "source": ["## Run it!"] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "6e97a07c-d7bf-4a35-9ab2-b59ae865c62c", + "metadata": {}, + "outputs": [], + "source": [ + "# If you prefer in-line tracing, uncomment this line\n", + "# agent_executor.agent.llm_chain.verbose = True" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a11ca60d-f57b-4fe8-943e-a258e37463c7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mThought: I need to find the Q number for J.S. Bach.\n", + "Action: ItemLookup\n", + "Action Input: J.S. Bach\u001b[0m\n", + "\n", + "Observation:\u001b[36;1m\u001b[1;3mQ1339\u001b[0m\u001b[32;1m\u001b[1;3mI need to find the P number for children.\n", + "Action: PropertyLookup\n", + "Action Input: children\u001b[0m\n", + "\n", + "Observation:\u001b[33;1m\u001b[1;3mP1971\u001b[0m\u001b[32;1m\u001b[1;3mNow I can query the number of children J.S. Bach had.\n", + "Action: SparqlQueryRunner\n", + "Action Input: SELECT ?children WHERE { wd:Q1339 wdt:P1971 ?children }\u001b[0m\n", + "\n", + "Observation:\u001b[38;5;200m\u001b[1;3m[{\"children\": {\"datatype\": \"http://www.w3.org/2001/XMLSchema#decimal\", \"type\": \"literal\", \"value\": \"20\"}}]\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer.\n", + "Final Answer: J.S. Bach had 20 children.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": ["'J.S. Bach had 20 children.'"] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.run(\"How many children did J.S. 
Bach have?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d0b42a41-996b-4156-82e4-f0651a87ee34", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mThought: To find Hakeem Olajuwon's Basketball-Reference.com NBA player ID, I need to first find his Wikidata item (Q-number) and then query for the relevant property (P-number).\n", + "Action: ItemLookup\n", + "Action Input: Hakeem Olajuwon\u001b[0m\n", + "\n", + "Observation:\u001b[36;1m\u001b[1;3mQ273256\u001b[0m\u001b[32;1m\u001b[1;3mNow that I have Hakeem Olajuwon's Wikidata item (Q273256), I need to find the P-number for the Basketball-Reference.com NBA player ID property.\n", + "Action: PropertyLookup\n", + "Action Input: Basketball-Reference.com NBA player ID\u001b[0m\n", + "\n", + "Observation:\u001b[33;1m\u001b[1;3mP2685\u001b[0m\u001b[32;1m\u001b[1;3mNow that I have both the Q-number for Hakeem Olajuwon (Q273256) and the P-number for the Basketball-Reference.com NBA player ID property (P2685), I can run a SPARQL query to get the ID value.\n", + "Action: SparqlQueryRunner\n", + "Action Input: \n", + "SELECT ?playerID WHERE {\n", + " wd:Q273256 wdt:P2685 ?playerID .\n", + "}\u001b[0m\n", + "\n", + "Observation:\u001b[38;5;200m\u001b[1;3m[{\"playerID\": {\"type\": \"literal\", \"value\": \"o/olajuha01\"}}]\u001b[0m\u001b[32;1m\u001b[1;3mI now know the final answer\n", + "Final Answer: Hakeem Olajuwon's Basketball-Reference.com NBA player ID is \"o/olajuha01\".\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Hakeem Olajuwon\\'s Basketball-Reference.com NBA player ID is \"o/olajuha01\".'" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent_executor.run(\n", + " \"What is the Basketball-Reference.com NBA player ID of Hakeem Olajuwon?\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "05fb3a3e-8a9f-482d-bd54-4c6e60ef60dd", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "conda210", + "language": "python", + "name": "conda210" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + } +} diff --git a/crates/cli_bin/tests/apply.rs b/crates/cli_bin/tests/apply.rs index 2abb178b0..ade3541f0 100644 --- a/crates/cli_bin/tests/apply.rs +++ b/crates/cli_bin/tests/apply.rs @@ -851,6 +851,48 @@ fn basic_python_apply() -> Result<()> { Ok(()) } +#[test] +fn python_in_notebook() -> Result<()> { + // Keep _temp_dir around so that the tempdir is not deleted + let (_temp_dir, dir) = get_fixture("notebooks", false)?; + + // from the tempdir as cwd, run init + run_init(&dir.as_path())?; + + // from the tempdir as cwd, run marzano apply + let mut apply_cmd = get_test_cmd()?; + apply_cmd.current_dir(dir.as_path()); + apply_cmd + .arg("apply") + .arg("--force") + .arg("pattern.grit") + .arg("tiny_nb.ipynb"); + let output = apply_cmd.output()?; + + let stdout = String::from_utf8(output.stdout)?; + println!("stdout: {:?}", stdout); + let stderr = String::from_utf8(output.stderr)?; + println!("stderr: {:?}", stderr); + + assert!( + output.status.success(), + "Command 
didn't finish successfully: {}", + stderr + ); + + // Read back tiny_nb.ipynb + let target_file = dir.join("tiny_nb.ipynb"); + let content: String = fs_err::read_to_string(target_file)?; + + // assert that it matches snapshot + println!("content: {:?}", content); + assert!(content.contains("flint(4)")); + assert!(content.contains("flint(5)")); + assert_snapshot!(content); + + Ok(()) +} + #[test] fn basic_js_in_vue_apply() -> Result<()> { // Keep _temp_dir around so that the tempdir is not deleted diff --git a/crates/cli_bin/tests/snapshots/apply__python_in_notebook.snap b/crates/cli_bin/tests/snapshots/apply__python_in_notebook.snap new file mode 100644 index 000000000..3708d1e81 --- /dev/null +++ b/crates/cli_bin/tests/snapshots/apply__python_in_notebook.snap @@ -0,0 +1,39 @@ +--- +source: crates/cli_bin/tests/apply.rs +expression: content +--- +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e3cb542-933d-4bf3-a82b-d9d6395a7832", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": ["flint(4)\nflint(\"3\")\nflint(5)"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "conda210", + "language": "python", + "name": "conda210" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/crates/core/src/api.rs b/crates/core/src/api.rs index 6a447c288..5292ab866 100644 --- a/crates/core/src/api.rs +++ b/crates/core/src/api.rs @@ -353,6 +353,25 @@ impl EntireFile { byte_ranges: byte_range.map(|r| r.to_owned()), } } + + fn from_file(file: &FileOwner) -> Result { + if let Some(source_map) = &file.tree.source_map { + let outer_source = source_map.fill_with_inner(&file.tree.source)?; + + Ok(Self::file_to_entire_file( + file.name.to_string_lossy().as_ref(), + &outer_source, + // Exclude the matches, since they aren't reliable yet + None, + )) + } else { + Ok(Self::file_to_entire_file( + file.name.to_string_lossy().as_ref(), + file.tree.outer_source(), + file.matches.borrow().byte_ranges.as_ref(), + )) + } + } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] @@ -375,7 +394,7 @@ impl From for MatchResult { impl Rewrite { fn file_to_rewrite<'a>( initial: &FileOwner, - rewrite: &FileOwner, + rewritten_file: &FileOwner, language: &impl MarzanoLanguage<'a>, ) -> Result { let original = if let Some(ranges) = &initial.matches.borrow().input_matches { @@ -388,11 +407,7 @@ impl Rewrite { } else { bail!("cannot have rewrite without matches") }; - let rewritten = EntireFile::file_to_entire_file( - rewrite.name.to_string_lossy().as_ref(), - &rewrite.tree.source, - rewrite.matches.borrow().byte_ranges.as_ref(), - ); + let rewritten = EntireFile::from_file(rewritten_file)?; Ok(Rewrite::new(original, rewritten)) } } diff --git a/crates/core/src/effects_dependency_graph.rs b/crates/core/src/effects_dependency_graph.rs deleted file mode 100644 index c71c568ef..000000000 --- a/crates/core/src/effects_dependency_graph.rs +++ /dev/null @@ -1,464 +0,0 @@ -#![allow(warnings)] -use anyhow::{bail, Result}; -use grit_pattern_matcher::intervals::{ - earliest_deadline_sort, get_top_level_intervals, pop_out_of_range_intervals, Interval, -}; -use std::{ - collections::{HashMap, HashSet}, - vec, -}; - -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct FileInterval { - file: String, - 
start: u32, - end: u32, -} - -impl Interval for FileInterval { - fn interval(&self) -> (u32, u32) { - (self.start, self.end) - } -} - -impl FileInterval { - fn new(file: String, start: u32, end: u32) -> Self { - Self { file, start, end } - } -} - -#[derive(Debug, Clone)] -pub struct EffectInterval { - left: FileInterval, - right: Vec, -} - -impl Interval for EffectInterval { - fn interval(&self) -> (u32, u32) { - (self.left.start, self.left.end) - } -} - -impl EffectInterval { - fn new(left: FileInterval, right: Vec) -> Self { - Self { left, right } - } -} - -pub trait ToEffectInterval { - fn to_effect_interval(&self) -> EffectInterval; -} - -pub struct EffectsInfo { - interval_map: HashMap)>, - file_to_effect_intervals: HashMap>, - rhs_to_lhs: HashMap>, - file_to_sorted_intervals: HashMap>, -} - -// takes a vectore of effects, and returns relevant datastructures -// Map from lhs of effect as interval to effect -// Map from file to EffectIntervals -// Map from rhs intervals to corresponding lhs Interval -fn effects_to_intervals(effects: Vec) -> Result> -where - T: ToEffectInterval + Clone, -{ - let mut interval_map: HashMap)> = HashMap::new(); - let mut file_to_effect_intervals: HashMap> = HashMap::new(); - let mut rhs_to_lhs: HashMap> = HashMap::new(); - let mut file_to_sorted_intervals: HashMap> = HashMap::new(); - for effect in effects { - let effect_interval = effect.to_effect_interval(); - let lhs = effect_interval.left; - let rhs = effect_interval.right; - let file = lhs.file.clone(); - let old = interval_map.insert(lhs.clone(), (effect.clone(), rhs.clone())); - if old.is_some() { - bail!("duplicate lhs interval"); - } - file_to_effect_intervals - .entry(file.to_owned()) - .or_insert_with(std::vec::Vec::new) - .push(effect.clone()); - file_to_sorted_intervals - .entry(file.to_owned()) - .or_insert_with(std::collections::HashSet::new) - .insert(lhs.clone()); - for interval in rhs.clone() { - file_to_sorted_intervals - .entry(file.to_owned()) - .or_insert_with(std::collections::HashSet::new) - .insert(interval.clone()); - } - for interval in rhs { - rhs_to_lhs - .entry(interval) - .or_insert_with(std::vec::Vec::new) - .push(lhs.clone()); - } - } - let mut file_to_sorted_intervals: HashMap> = file_to_sorted_intervals - .into_iter() - .map(|(k, v)| (k, v.into_iter().collect())) - .collect(); - for intervals in file_to_sorted_intervals.values_mut() { - if !earliest_deadline_sort(intervals) { - bail!("effects have overlapping lhs intervals"); - } - } - let res = EffectsInfo { - interval_map, - file_to_effect_intervals, - rhs_to_lhs, - file_to_sorted_intervals, - }; - Ok(res) -} - -// takes a vector of EffectIntervals and returns a vector of EffectIntervals -// whose lhs are not contained in any other lhs -fn filter_top_level_effects(effects: &mut [EffectInterval]) -> Result> { - if !earliest_deadline_sort(effects) { - bail!("effects have overlapping lhs intervals"); - } - Ok(get_top_level_intervals(effects.to_vec())) -} - -fn top_level_effects_from_all_files( - effects: &mut HashMap>, -) -> Result> { - let res: Result>> = effects - .values_mut() - .map(|es| filter_top_level_effects(es)) - .collect(); - let res = res?; - Ok(res.into_iter().flatten().collect()) -} - -pub fn get_effects_order(effects: Vec) -> Result<(EffectsInfo, Vec)> -where - T: ToEffectInterval + Clone, -{ - let info = effects_to_intervals(effects)?; - let lhs_intervals = info.interval_map.keys().cloned().collect::>(); - let by_file = &info.file_to_sorted_intervals; - for (file, sorted) in by_file { - println!("file: {}", 
file); - let to_print = sorted.iter().map(|i| (i.start, i.end)).collect::>(); - println!("intervals: {:?}", to_print) - } - let graph = build_dependency_graph(&lhs_intervals, by_file, &info.rhs_to_lhs); - let linearized = linearize_graph(graph)?; - Ok((info, linearized)) -} - -fn build_dependency_graph( - effects: &[FileInterval], - by_file: &HashMap>, - rhs_to_lhs: &HashMap>, -) -> HashMap> { - let mut map = effects - .iter() - .map(|e| (e.to_owned(), HashSet::new())) - .collect::>>(); - for intervals in by_file.values() { - add_dependencies_for_file(intervals, &mut map, rhs_to_lhs); - } - map -} - -// assumes intervals are already EDS sorted; -fn add_dependencies_for_file( - intervals: &[FileInterval], - map: &mut HashMap>, - rhs_to_lhs: &HashMap>, -) { - let mut lhs_stack: Vec = vec![]; - let mut rhs_stack: Vec = vec![]; - for e in intervals.iter().rev() { - pop_out_of_range_intervals(e, &mut lhs_stack); - pop_out_of_range_intervals(e, &mut rhs_stack); - // ORDER MATTERS HERE - // if a range is both lhs and rhs, then the effects - // corresponding to the rhs depend on the effects corresponding to the lhs - // so pushing onto rhs_stack first ensures that the rhs effects of the interval - // are prosseced in the event that it is also lhs. - if let Some(sources) = rhs_to_lhs.get(e) { - // adds all the effects corresponding which have e on the rhs - // as dependencies of all effects with lhs enclosing e - for lhs in lhs_stack.iter() { - let old = map - .entry(lhs.clone()) - .or_insert_with(std::collections::HashSet::new); - old.extend(sources.clone()); - } - rhs_stack.push(e.clone()); - } - if map.contains_key(e) { - // adds e as a dependency to all effects whose lhs encloses e - for lhs in lhs_stack.iter() { - let old = map - .entry(lhs.clone()) - .or_insert_with(std::collections::HashSet::new); - old.insert(e.clone()); - } - // adds e as a dependency to all effects whose rhs encloses e - for rhs in rhs_stack.iter() { - // should always be true - if let Some(sources) = rhs_to_lhs.get(rhs) { - for source in sources { - let old = map - .entry(source.clone()) - .or_insert_with(std::collections::HashSet::new); - old.insert(e.clone()); - } - } - } - lhs_stack.push(e.clone()); - } - } -} - -fn linearize_graph( - mut dependency_graph: HashMap>, -) -> Result> { - let mut dependants: HashMap> = HashMap::new(); - let mut dependency_free: Vec = vec![]; - let mut linearized: Vec = vec![]; - for (interval, dependencies) in &dependency_graph { - if dependencies.is_empty() { - dependency_free.push(interval.clone()); - } - for dependency in dependencies { - let old = dependants - .entry(dependency.clone()) - .or_insert_with(std::collections::HashSet::new); - old.insert(interval.clone()); - } - } - while let Some(interval) = dependency_free.pop() { - linearized.push(interval.clone()); - if let Some(dependants) = dependants.get(&interval) { - for dependant in dependants { - let old = dependency_graph - .get_mut(dependant) - .expect("dependant not in dependency graph"); - old.remove(&interval); - if old.is_empty() { - dependency_free.push(dependant.clone()); - } - } - } - } - if linearized.len() != dependency_graph.len() { - bail!("dependency graph has a cycle"); - } - Ok(linearized) -} - -#[cfg(test)] -mod tests { - - use super::{ - get_effects_order, linearize_graph, EffectInterval, FileInterval, ToEffectInterval, - }; - use grit_pattern_matcher::intervals::{earliest_deadline_sort, Interval}; - use std::collections::{HashMap, HashSet}; - - type NestedVec = Vec<((u32, u32), Vec<(u32, u32)>)>; - type 
NestedArray = [((u32, u32), Vec<(u32, u32)>)]; - - fn vec_into(v: &[(u32, u32)]) -> Vec { - v.iter().map(interval_to_file).collect() - } - - fn vec_to_set(v: &[(u32, u32)]) -> std::collections::HashSet { - v.iter().map(interval_to_file).collect() - } - - fn interval_to_file(e: &(u32, u32)) -> FileInterval { - FileInterval::new("default".to_owned(), e.0, e.1) - } - - fn nested_vec_to_map( - v: &NestedArray, - ) -> std::collections::HashMap> { - v.iter() - .map(|(lhs, rhs)| (interval_to_file(lhs), vec_into(rhs))) - .collect() - } - - #[allow(dead_code)] - fn map_to_vec(map: &HashMap>) -> NestedVec { - map.iter() - .map(|(lhs, rhs)| (lhs.interval(), rhs.iter().map(|f| f.interval()).collect())) - .collect() - } - - fn assert_map(map: &HashMap>, expected: &NestedVec) { - let expected = nested_vec_to_map(expected) - .iter() - .map(|(k, v)| (k.to_owned(), v.iter().map(|e| e.to_owned()).collect())) - .collect(); - assert_eq!(map, &expected); - } - - #[allow(dead_code)] - fn print_res(res: NestedVec) { - for (lhs, rhs) in res { - println!("{:?} -> {:?}", lhs, rhs); - } - } - - fn init_map_from_vec( - lhs_intervals: &[(u32, u32)], - ) -> HashMap> { - let lhs_intervals = vec_to_set(lhs_intervals); - let mut map = HashMap::new(); - for lhs in lhs_intervals { - map.insert(lhs.clone(), HashSet::new()); - } - map - } - - fn dependency_tester( - intervals: &mut [(u32, u32)], - map: &mut HashMap>, - rhs_to_lhs: &NestedArray, - ) { - let mut intervals = vec_into(intervals); - assert!(earliest_deadline_sort(&mut intervals)); - let rhs_to_lhs = nested_vec_to_map(rhs_to_lhs); - super::add_dependencies_for_file(&intervals, map, &rhs_to_lhs); - } - - #[test] - fn no_nesting_test() { - let intervals = &mut [(0, 1), (2, 3), (4, 5), (6, 7)]; - let lhs_intervals = &[(0, 1), (2, 3), (4, 5), (6, 7)]; - let rhs_to_lhs = &[ - ((0, 1), vec![(2, 3), (4, 5), (6, 7)]), - ((2, 3), vec![(4, 5), (6, 7)]), - ((4, 5), vec![(6, 7)]), - ]; - let mut map = init_map_from_vec(lhs_intervals); - dependency_tester(intervals, &mut map, rhs_to_lhs); - let expected = vec![ - ((6, 7), vec![(0, 1), (2, 3), (4, 5)]), - ((2, 3), vec![(0, 1)]), - ((4, 5), vec![(0, 1), (2, 3)]), - ((0, 1), vec![]), - ]; - assert_map(&map, &expected); - } - - #[derive(Debug, Clone)] - struct EffectIntervalTest { - left: FileInterval, - right: Vec, - } - - impl EffectIntervalTest { - fn new(left: FileInterval, right: Vec) -> Self { - Self { left, right } - } - } - - fn default_file_effect(l: (u32, u32), r: &[(u32, u32)]) -> EffectIntervalTest { - let left = FileInterval::new("default".to_owned(), l.0, l.1); - let right = r - .iter() - .map(|e| FileInterval::new("default".to_owned(), e.0, e.1)) - .collect(); - EffectIntervalTest::new(left, right) - } - - fn default_file_array(intervals: &NestedArray) -> Vec { - intervals - .iter() - .map(|(l, r)| default_file_effect(*l, r)) - .collect() - } - - impl ToEffectInterval for EffectIntervalTest { - fn to_effect_interval(&self) -> EffectInterval { - let left = self.left.to_owned(); - let right = self.right.iter().map(|e| e.to_owned()).collect(); - EffectInterval::new(left, right) - } - } - - #[allow(dead_code)] - fn vec_back(v: &[T]) -> Vec<(u32, u32)> - where - T: super::Interval, - { - v.iter().map(|e| e.interval()).collect() - } - - #[test] - fn nested_intervals_lhs_only_test() { - let intervals = &mut [(0, 1), (2, 5), (0, 2), (2, 4), (3, 4), (1, 2), (2, 3)]; - let lhs_intervals = &[(0, 1), (2, 5), (0, 2), (2, 4), (3, 4), (1, 2), (2, 3)]; - let rhs_to_lhs = &[]; - let mut map = init_map_from_vec(lhs_intervals); - 
dependency_tester(intervals, &mut map, rhs_to_lhs); - let expected = vec![ - ((0, 2), vec![(0, 1), (1, 2)]), - ((2, 4), vec![(2, 3), (3, 4)]), - ((2, 5), vec![(2, 3), (2, 4), (3, 4)]), - ((3, 4), vec![]), - ((1, 2), vec![]), - ((2, 3), vec![]), - ((0, 1), vec![]), - ]; - assert_map(&map, &expected); - } - #[test] - fn nested_intervals_test() { - let intervals = &mut [(0, 1), (2, 5), (0, 2), (2, 4), (3, 4), (1, 2), (2, 3)]; - // if we were to filter out top level intervals, (0, 2) would make (0, 1) and (1, 2) - // and we would expect to remove them from the graph. - let lhs_intervals = &[(0, 1), (1, 2), (0, 2), (2, 3), (3, 4), (2, 5)]; - let rhs_to_lhs = &[ - ((2, 4), vec![(0, 1)]), - ((2, 5), vec![(1, 2)]), - ((3, 4), vec![(2, 3)]), - ]; - let mut map = init_map_from_vec(lhs_intervals); - dependency_tester(intervals, &mut map, rhs_to_lhs); - let expected = vec![ - ((0, 1), vec![(2, 3), (3, 4)]), - ((1, 2), vec![(2, 5), (3, 4), (2, 3)]), - ((0, 2), vec![(0, 1), (1, 2)]), - ((2, 5), vec![(0, 1), (2, 3), (3, 4)]), - ((2, 3), vec![(3, 4)]), - ((3, 4), vec![]), - ]; - assert_map(&map, &expected); - let linear = linearize_graph(map).unwrap(); - assert_eq!( - linear, - vec_into(&[(3, 4), (2, 3), (0, 1), (2, 5), (1, 2), (0, 2)]) - ); - } - - #[test] - fn linearize_effects() { - let effects = default_file_array(&[ - ((2, 5), vec![]), - ((1, 2), vec![(2, 5)]), - ((0, 2), vec![]), - ((3, 4), vec![]), - ((0, 1), vec![(2, 4)]), - ((2, 3), vec![(3, 4)]), - ]); - let (_info, linear) = get_effects_order(effects).unwrap(); - - assert_eq!( - linear, - vec_into(&[(3, 4), (2, 3), (0, 1), (2, 5), (1, 2), (0, 2)]) - ); - } -} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index e3da12b26..96a156bb3 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -6,7 +6,6 @@ pub mod built_in_functions; mod clean; pub mod compact_api; pub mod constants; -mod effects_dependency_graph; mod equivalence; mod foreign_function_definition; pub mod fs; @@ -34,6 +33,8 @@ mod variables; // and here we import it to avoid an unused dependency warning #[cfg(feature = "wasm_core")] use getrandom as _; +#[cfg(test)] +mod test_notebooks; #[cfg(test)] mod test; diff --git a/crates/core/src/marzano_context.rs b/crates/core/src/marzano_context.rs index 39771e730..6f67f6e22 100644 --- a/crates/core/src/marzano_context.rs +++ b/crates/core/src/marzano_context.rs @@ -19,7 +19,7 @@ use grit_pattern_matcher::{ PredicateDefinition, ResolvedPattern, State, }, }; -use grit_util::{AnalysisLogs, Ast, InputRanges, MatchRanges}; +use grit_util::{AnalysisLogs, Ast, FileOrigin, InputRanges, MatchRanges}; use im::vector; use marzano_language::{ language::{MarzanoLanguage, Tree}, @@ -148,7 +148,7 @@ impl<'a> ExecContext<'a, MarzanoQueryContext> for MarzanoContext<'a> { owned.path, owned.content, None, - false, + FileOrigin::Fresh, self.language, logs, )?; @@ -247,8 +247,11 @@ impl<'a> ExecContext<'a, MarzanoQueryContext> for MarzanoContext<'a> { self, logs, )?; + if let Some(new_ranges) = new_ranges { - let tree = parser.parse_file(&new_src, None, logs, true).unwrap(); + let tree = parser + .parse_file(&new_src, None, logs, FileOrigin::Mutated(&file.tree)) + .unwrap(); let root = tree.root_node(); let replacement_ranges = get_replacement_ranges(root, self.language()); let cleaned_src = replace_cleaned_ranges(replacement_ranges, &new_src)?; @@ -260,11 +263,11 @@ impl<'a> ExecContext<'a, MarzanoQueryContext> for MarzanoContext<'a> { let ranges = MatchRanges::new(new_ranges.into_iter().map(|r| r.into()).collect()); - let owned_file 
= FileOwnerCompiler::from_matches( + let rewritten_file = FileOwnerCompiler::from_matches( new_filename.clone(), new_src, Some(ranges), - true, + FileOrigin::Mutated(&file.tree), self.language(), logs, )? @@ -274,7 +277,8 @@ impl<'a> ExecContext<'a, MarzanoQueryContext> for MarzanoContext<'a> { new_filename.to_string_lossy() ) })?; - self.files().push(owned_file); + + self.files().push(rewritten_file); state .files .push_revision(&file_ptr, self.files().last().unwrap()) @@ -310,7 +314,7 @@ impl<'a> ExecContext<'a, MarzanoQueryContext> for MarzanoContext<'a> { name.clone(), body, None, - true, + FileOrigin::New, self.language(), logs, )? diff --git a/crates/core/src/parse.rs b/crates/core/src/parse.rs index 4296118fd..1771a7c87 100644 --- a/crates/core/src/parse.rs +++ b/crates/core/src/parse.rs @@ -16,7 +16,12 @@ pub fn parse_input_file<'a>( let mut parser = lang.get_parser(); let tree = parser - .parse_file(input, Some(path), &mut vec![].into(), false) + .parse_file( + input, + Some(path), + &mut vec![].into(), + grit_util::FileOrigin::Fresh, + ) .context("Parsed tree is empty")?; let input_file_debug_text = to_string_pretty(&tree_sitter_node_to_json( &tree.root_node().node, @@ -29,6 +34,7 @@ pub fn parse_input_file<'a>( syntax_tree: input_file_debug_text, }) } + #[cfg(not(feature = "grit-parser"))] pub fn parse_input_file<'a>( _lang: &impl MarzanoLanguage<'a>, @@ -41,3 +47,60 @@ pub fn parse_input_file<'a>( "enable grit-parser feature flag to parse a grit file" )) } + +#[cfg(test)] +mod tests { + use std::path::Path; + + use grit_util::{traverse, Ast, FileOrigin, Order}; + use insta::assert_snapshot; + use marzano_language::language::MarzanoLanguage; + use marzano_language::target_language::TargetLanguage; + use marzano_util::cursor_wrapper::CursorWrapper; + + fn verify_notebook(source: &str, path: &Path) -> String { + let lang = TargetLanguage::from_string("ipynb", None).unwrap(); + + let mut parser = lang.get_parser(); + let tree = parser + .parse_file(source, Some(path), &mut vec![].into(), FileOrigin::Fresh) + .unwrap(); + + let mut simple_rep = String::new(); + + let cursor = tree.root_node().node.walk(); + for n in traverse(CursorWrapper::new(cursor, source), Order::Pre) { + simple_rep += format!( + "{:, source: String, matches: Option, - new: bool, + old_tree: FileOrigin<'_, Tree>, language: &impl MarzanoLanguage<'a>, logs: &mut AnalysisLogs, ) -> Result>> { let name = name.into(); - let Some(tree) = language + let new = !old_tree.is_fresh(); + + // If we have an old tree, attach it here + let new_map = if let Some(old_tree) = old_tree.original() { + // TODO: avoid this clone + old_tree.source_map.clone() + } else { + None + }; + + let Some(mut tree) = language .get_parser() - .parse_file(&source, Some(&name), logs, new) + .parse_file(&source, Some(&name), logs, old_tree) else { return Ok(None); }; + + if new_map.is_some() { + tree.source_map = new_map; + } + let absolute_path = absolutize(&name)?; Ok(Some(FileOwner { name, diff --git a/crates/core/src/problem.rs b/crates/core/src/problem.rs index 2e63548c0..b660074c4 100644 --- a/crates/core/src/problem.rs +++ b/crates/core/src/problem.rs @@ -19,7 +19,7 @@ use grit_pattern_matcher::{ PredicateDefinition, ResolvedPattern, State, VariableContent, }, }; -use grit_util::{VariableMatch}; +use grit_util::VariableMatch; use im::vector; use log::error; use marzano_language::{language::Tree, target_language::TargetLanguage}; diff --git a/crates/core/src/snapshots/marzano_core__parse__tests__other_notebook.snap 
b/crates/core/src/snapshots/marzano_core__parse__tests__other_notebook.snap new file mode 100644 index 000000000..645f181a5 --- /dev/null +++ b/crates/core/src/snapshots/marzano_core__parse__tests__other_notebook.snap @@ -0,0 +1,32 @@ +--- +source: crates/core/src/parse.rs +expression: "verify_notebook(source, path)" +--- +module | import configparser\n\nconfig = configparser.ConfigParser()\nconfig.read("./secrets.ini") +import_statement | import configparser +import | import +dotted_name | configparser +identifier | configparser +assignment | config = configparser.ConfigParser() +identifier | config += | = +call | configparser.ConfigParser() +attribute | configparser.ConfigParser +identifier | configparser +. | . +identifier | ConfigParser +argument_list | () +( | ( +) | ) +call | config.read("./secrets.ini") +attribute | config.read +identifier | config +. | . +identifier | read +argument_list | ("./secrets.ini") +( | ( +string | "./secrets.ini" +string_start | " +string_content | ./secrets.ini +string_end | " +) | ) diff --git a/crates/core/src/snapshots/marzano_core__parse__tests__simple_notebook.snap b/crates/core/src/snapshots/marzano_core__parse__tests__simple_notebook.snap new file mode 100644 index 000000000..8a8aca0e9 --- /dev/null +++ b/crates/core/src/snapshots/marzano_core__parse__tests__simple_notebook.snap @@ -0,0 +1,26 @@ +--- +source: crates/core/src/parse.rs +expression: "verify_notebook(source, path)" +--- +module | print(4)\nprint("3")\nprint(5) +call | print(4) +identifier | print +argument_list | (4) +( | ( +integer | 4 +) | ) +call | print("3") +identifier | print +argument_list | ("3") +( | ( +string | "3" +string_start | " +string_content | 3 +string_end | " +) | ) +call | print(5) +identifier | print +argument_list | (5) +( | ( +integer | 5 +) | ) diff --git a/crates/core/src/snapshots/marzano_core__parse__tests__verify_notebook-2.snap b/crates/core/src/snapshots/marzano_core__parse__tests__verify_notebook-2.snap new file mode 100644 index 000000000..aa9e7d99e --- /dev/null +++ b/crates/core/src/snapshots/marzano_core__parse__tests__verify_notebook-2.snap @@ -0,0 +1,32 @@ +--- +source: crates/core/src/parse.rs +expression: simple_rep +--- +module | import configparser\n\nconfig = configparser.ConfigParser()\nconfig.read("./secrets.ini") +import_statement | import configparser +import | import +dotted_name | configparser +identifier | configparser +assignment | config = configparser.ConfigParser() +identifier | config += | = +call | configparser.ConfigParser() +attribute | configparser.ConfigParser +identifier | configparser +. | . +identifier | ConfigParser +argument_list | () +( | ( +) | ) +call | config.read("./secrets.ini") +attribute | config.read +identifier | config +. | . 
+identifier | read +argument_list | ("./secrets.ini") +( | ( +string | "./secrets.ini" +string_start | " +string_content | ./secrets.ini +string_end | " +) | ) diff --git a/crates/core/src/snapshots/marzano_core__parse__tests__verify_notebook.snap b/crates/core/src/snapshots/marzano_core__parse__tests__verify_notebook.snap new file mode 100644 index 000000000..e642a0933 --- /dev/null +++ b/crates/core/src/snapshots/marzano_core__parse__tests__verify_notebook.snap @@ -0,0 +1,26 @@ +--- +source: crates/core/src/parse.rs +expression: simple_rep +--- +module | print(4)\nprint("3")\nprint(5) +call | print(4) +identifier | print +argument_list | (4) +( | ( +integer | 4 +) | ) +call | print("3") +identifier | print +argument_list | ("3") +( | ( +string | "3" +string_start | " +string_content | 3 +string_end | " +) | ) +call | print(5) +identifier | print +argument_list | (5) +( | ( +integer | 5 +) | ) diff --git a/crates/core/src/snapshots/marzano_core__test_notebooks__base_case.snap b/crates/core/src/snapshots/marzano_core__test_notebooks__base_case.snap new file mode 100644 index 000000000..0725fa419 --- /dev/null +++ b/crates/core/src/snapshots/marzano_core__test_notebooks__base_case.snap @@ -0,0 +1,39 @@ +--- +source: crates/core/src/test_notebooks.rs +expression: rewrite.rewritten.content +--- +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e3cb542-933d-4bf3-a82b-d9d6395a7832", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": ["flink(4)\nflink(\"3\")\nflink(5)"] + } + ], + "metadata": { + "kernelspec": { + "display_name": "conda210", + "language": "python", + "name": "conda210" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/crates/core/src/test.rs b/crates/core/src/test.rs index 00472422a..78616e25b 100644 --- a/crates/core/src/test.rs +++ b/crates/core/src/test.rs @@ -12052,7 +12052,7 @@ fn python_orphaned_from_imports() { | } else { | $anchor => `from $to_package import $replacement_name` | } - | } + | } | }, | `import $name` as $anchor where { | // Split the name into its constituent parts diff --git a/crates/core/src/test_files.rs b/crates/core/src/test_files.rs index 8caf834d9..8d014589a 100644 --- a/crates/core/src/test_files.rs +++ b/crates/core/src/test_files.rs @@ -425,7 +425,6 @@ fn avoid_unsafe_hoists() { SyntheticFile::new("other.js".to_owned(), matching_src.to_owned(), true), ]; let results = run_on_test_files(&pattern, &test_files); - println!("{:?}", results); // Confirm we have 3 DoneFiles and 1 match assert_eq!(results.len(), 4); assert!(results.iter().any(|r| r.is_match())); diff --git a/crates/core/src/test_notebooks.rs b/crates/core/src/test_notebooks.rs new file mode 100644 index 000000000..b531d7e83 --- /dev/null +++ b/crates/core/src/test_notebooks.rs @@ -0,0 +1,91 @@ +use insta::assert_snapshot; +use marzano_language::target_language::TargetLanguage; + +use crate::{ + api::MatchResult, + test_utils::{run_on_test_files, SyntheticFile}, +}; + +use self::pattern_compiler::src_to_problem_libs; + +use super::*; +use std::collections::BTreeMap; + +#[test] +fn test_base_case() { + let pattern_src = r#" + language python + + `print($x)` => `flink($x)` + "#; + let libs = BTreeMap::new(); + + let matching_src = 
+
+    let pattern = src_to_problem_libs(
+        pattern_src.to_string(),
+        &libs,
+        TargetLanguage::from_extension("ipynb").unwrap(),
+        None,
+        None,
+        None,
+        None,
+    )
+    .unwrap()
+    .problem;
+
+    // Basic match works
+    let test_files = vec![SyntheticFile::new(
+        "target.ipynb".to_owned(),
+        matching_src.to_owned(),
+        true,
+    )];
+    let results = run_on_test_files(&pattern, &test_files);
+    println!("{:?}", results);
+    assert!(!results.iter().any(|r| r.is_error()));
+
+    let rewrite = results
+        .iter()
+        .find(|r| matches!(r, MatchResult::Rewrite(_)))
+        .unwrap();
+
+    if let MatchResult::Rewrite(rewrite) = rewrite {
+        assert_snapshot!(rewrite.rewritten.content);
+    } else {
+        panic!("Expected a rewrite");
+    }
+}
+
+#[test]
+fn test_old_notebooks() {
+    let pattern_src = r#"
+    language python
+
+    `print($x)` => `flink($x)`
+    "#;
+    let libs = BTreeMap::new();
+
+    let matching_src = include_str!("../../../crates/cli_bin/fixtures/notebooks/old_nb.ipynb");
+
+    let pattern = src_to_problem_libs(
+        pattern_src.to_string(),
+        &libs,
+        TargetLanguage::from_extension("ipynb").unwrap(),
+        None,
+        None,
+        None,
+        None,
+    )
+    .unwrap()
+    .problem;
+
+    // Old-format notebooks should be rejected rather than matched
+    let test_files = vec![SyntheticFile::new(
+        "target.ipynb".to_owned(),
+        matching_src.to_owned(),
+        true,
+    )];
+    let results = run_on_test_files(&pattern, &test_files);
+    // We *do* expect an error on old notebooks
+    assert!(results.iter().any(|r| r.is_error()));
+}
diff --git a/crates/grit-util/src/analysis_logs.rs b/crates/grit-util/src/analysis_logs.rs
index 867ead9e5..59b5041b5 100644
--- a/crates/grit-util/src/analysis_logs.rs
+++ b/crates/grit-util/src/analysis_logs.rs
@@ -35,6 +35,15 @@ impl AnalysisLogs {
     pub fn new() -> Self {
         Self(Vec::new())
     }
+
+    pub fn add_warning(&mut self, file: Option<PathBuf>, message: impl Into<String>) {
+        self.0.push(AnalysisLog {
+            level: Some(339),
+            file,
+            message: message.into(),
+            ..AnalysisLog::default()
+        });
+    }
 }
 
 impl Default for AnalysisLogs {
diff --git a/crates/grit-util/src/lib.rs b/crates/grit-util/src/lib.rs
index 086ce0705..1ebfa1241 100644
--- a/crates/grit-util/src/lib.rs
+++ b/crates/grit-util/src/lib.rs
@@ -13,7 +13,7 @@ pub use ast_node::AstNode;
 pub use ast_node_traversal::{traverse, AstCursor, Order};
 pub use code_range::CodeRange;
 pub use language::{GritMetaValue, Language, Replacement};
-pub use parser::{Ast, Parser, SnippetTree};
+pub use parser::{Ast, FileOrigin, Parser, SnippetTree};
 pub use position::Position;
 pub use ranges::{
     ByteRange, FileRange, InputRanges, MatchRanges, Range, RangeWithoutByte, UtilRange,
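
The `add_warning` helper above gives callers a one-line way to attach a non-fatal diagnostic to a file. A minimal sketch of a call site, assuming the `Option<PathBuf>` / `impl Into<String>` signature shown above (the helper and message text are illustrative, not from the patch):

```rust
use std::path::{Path, PathBuf};

use grit_util::AnalysisLogs;

fn warn_unsupported(path: Option<&Path>, logs: &mut AnalysisLogs) {
    // `add_warning` takes an optional file plus anything convertible into a String.
    logs.add_warning(
        path.map(PathBuf::from),
        "unsupported construct encountered; skipping",
    );
}
```
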
diff --git a/crates/grit-util/src/parser.rs b/crates/grit-util/src/parser.rs
index 671b0a244..513fec3eb 100644
--- a/crates/grit-util/src/parser.rs
+++ b/crates/grit-util/src/parser.rs
@@ -1,6 +1,36 @@
 use crate::{AnalysisLogs, AstNode};
 use std::path::Path;
 
+/// Information on where a file came from, for the parser to be smarter
+#[derive(Clone, Debug, Copy)]
+pub enum FileOrigin<'tree, Tree>
+where
+    Tree: Ast,
+{
+    /// A file we are parsing for the first time, from disk
+    Fresh,
+    /// A file we have parsed before, and are re-parsing after mutating
+    Mutated(&'tree Tree),
+    /// A file that was constructed by Grit
+    New,
+}
+
+impl<'tree, Tree: Ast> FileOrigin<'tree, Tree> {
+    /// Is this a file we are parsing for the first time, from outside Grit?
+    pub fn is_fresh(&self) -> bool {
+        matches!(self, FileOrigin::Fresh)
+    }
+
+    /// Get the original tree, if any
+    pub fn original(&self) -> Option<&'tree Tree> {
+        match self {
+            FileOrigin::Fresh => None,
+            FileOrigin::Mutated(tree) => Some(tree),
+            FileOrigin::New => None,
+        }
+    }
+}
+
 pub trait Parser {
     type Tree: Ast;
 
@@ -9,7 +39,7 @@
         body: &str,
         path: Option<&Path>,
         logs: &mut AnalysisLogs,
-        new: bool,
+        origin: FileOrigin<Self::Tree>,
     ) -> Option<Self::Tree>;
 
     fn parse_snippet(
diff --git a/crates/gritmodule/src/fetcher.rs b/crates/gritmodule/src/fetcher.rs
index f506bdd7c..97bc99306 100644
--- a/crates/gritmodule/src/fetcher.rs
+++ b/crates/gritmodule/src/fetcher.rs
@@ -5,8 +5,8 @@ use git2::Repository;
 use regex::Regex;
 use serde::{Deserialize, Serialize};
 
-use lazy_static::lazy_static;
 use fs_err;
+use lazy_static::lazy_static;
 
 use crate::{config::GRIT_MODULE_DIR, searcher::find_git_dir_from, utils::remove_dir_all_safe};
 
diff --git a/crates/gritmodule/src/markdown.rs b/crates/gritmodule/src/markdown.rs
index 6645f84b2..96ba50fa9 100644
--- a/crates/gritmodule/src/markdown.rs
+++ b/crates/gritmodule/src/markdown.rs
@@ -6,6 +6,7 @@ use crate::{
     utils::is_pattern_name,
 };
 use anyhow::{bail, Context, Result};
+use fs_err::OpenOptions;
 use grit_util::{traverse, Ast, Order, Position};
 use marzano_core::analysis::defines_itself;
 use marzano_core::api::EnforcementLevel;
@@ -14,7 +15,6 @@ use marzano_language::language::MarzanoLanguage as _;
 use marzano_util::cursor_wrapper::CursorWrapper;
 use marzano_util::node_with_source::NodeWithSource;
 use marzano_util::rich_path::RichFile;
-use fs_err::OpenOptions;
 use std::io::{Read, Seek, Write};
 use std::path::Path;
 use tokio::io::SeekFrom;
diff --git a/crates/language/src/css.rs b/crates/language/src/css.rs
index c02528f69..0b169c391 100644
--- a/crates/language/src/css.rs
+++ b/crates/language/src/css.rs
@@ -5,7 +5,7 @@ use crate::{
     },
     vue::get_vue_ranges,
 };
-use grit_util::{AnalysisLogs, Language, Parser, SnippetTree};
+use grit_util::{AnalysisLogs, FileOrigin, Language, Parser, SnippetTree};
 use marzano_util::node_with_source::NodeWithSource;
 use std::{path::Path, sync::OnceLock};
 
@@ -119,7 +119,7 @@ impl Parser for MarzanoCssParser {
         body: &str,
         path: Option<&Path>,
         logs: &mut AnalysisLogs,
-        new: bool,
+        old_tree: FileOrigin<'_, Tree>,
     ) -> Option<Tree> {
         if path
             .and_then(Path::extension)
@@ -138,7 +138,7 @@ impl Parser for MarzanoCssParser {
             .ok()?
             .map(|tree| Tree::new(tree, body))
         } else {
-            self.0.parse_file(body, path, logs, new)
+            self.0.parse_file(body, path, logs, old_tree)
        }
     }
 
@@ -229,7 +229,7 @@ defineProps<{
             snippet,
             Some(Path::new("test.vue")),
             &mut vec![].into(),
-            false,
+            FileOrigin::Fresh,
         )
         .unwrap();
         print_node(&tree.root_node().node);
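
The `FileOrigin` argument threaded through `parse_file` above replaces the old `new: bool` flag, so parsers can tell fresh files, re-parses of mutated trees, and Grit-generated files apart. A hedged sketch of the kind of dispatch this enables (the helper below is illustrative, not part of the patch):

```rust
use grit_util::{Ast, FileOrigin};

// Only fully diagnose files that came from outside Grit; mutated or generated
// sources are expected to be transiently invalid, so stay quiet about them.
fn should_log_parse_errors<Tree: Ast>(origin: &FileOrigin<'_, Tree>) -> bool {
    match origin {
        FileOrigin::Fresh => true,
        FileOrigin::Mutated(_) | FileOrigin::New => false,
    }
}
```

This mirrors how `language.rs` below now passes `!old_tree.is_fresh()` where the old boolean used to go.
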
diff --git a/crates/language/src/js_like.rs b/crates/language/src/js_like.rs
index 2bb00ddd0..42eb8bbc5 100644
--- a/crates/language/src/js_like.rs
+++ b/crates/language/src/js_like.rs
@@ -6,7 +6,7 @@ use crate::{
     },
     vue::get_vue_ranges,
 };
-use grit_util::{AnalysisLogs, AstNode, Parser, Replacement, SnippetTree};
+use grit_util::{AnalysisLogs, AstNode, FileOrigin, Parser, Replacement, SnippetTree};
 use marzano_util::node_with_source::NodeWithSource;
 use std::path::Path;
 
@@ -120,7 +120,7 @@ impl Parser for MarzanoJsLikeParser {
         body: &str,
         path: Option<&Path>,
         logs: &mut AnalysisLogs,
-        new: bool,
+        old_tree: FileOrigin<'_, Tree>,
     ) -> Option<Tree> {
         if path
             .and_then(Path::extension)
@@ -137,7 +137,7 @@ impl Parser for MarzanoJsLikeParser {
             .ok()?
             .map(|tree| Tree::new(tree, body))
         } else {
-            self.0.parse_file(body, path, logs, new)
+            self.0.parse_file(body, path, logs, old_tree)
         }
     }
 
@@ -256,7 +256,7 @@ defineProps<{
             snippet,
             Some(Path::new("test.vue")),
             &mut vec![].into(),
-            false,
+            FileOrigin::Fresh,
         )
         .unwrap();
         print_node(&tree.root_node().node);
diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs
index fa7fe9066..c5263c759 100644
--- a/crates/language/src/language.rs
+++ b/crates/language/src/language.rs
@@ -1,7 +1,8 @@
 use anyhow::{Context, Result};
 use enum_dispatch::enum_dispatch;
 use grit_util::{
-    traverse, AnalysisLogBuilder, AnalysisLogs, Ast, AstNode, Language, Order, Parser, SnippetTree,
+    traverse, AnalysisLogBuilder, AnalysisLogs, Ast, AstNode, FileOrigin, Language, Order, Parser,
+    SnippetTree,
 };
 use itertools::Itertools;
 use marzano_util::{cursor_wrapper::CursorWrapper, node_with_source::NodeWithSource};
@@ -9,6 +10,8 @@ use serde_json::Value;
 use std::{borrow::Cow, cmp::max, collections::HashMap, path::Path};
 pub(crate) use tree_sitter::{Language as TSLanguage, Parser as TSParser, Tree as TSTree};
 
+use crate::sourcemap::EmbeddedSourceMap;
+
 pub type SortId = u16;
 pub type FieldId = u16;
 
@@ -181,7 +184,10 @@ pub trait NodeTypes {
 #[derive(Clone, Debug)]
 pub struct Tree {
     tree: TSTree,
+    /// The pure source code of the tree, which does not necessarily match the original file
     pub source: String,
+    /// A source map, if needed
+    pub source_map: Option<EmbeddedSourceMap>,
 }
 
 impl Tree {
@@ -189,6 +195,14 @@ impl Tree {
         Self {
             tree,
             source: source.into(),
+            source_map: None,
+        }
+    }
+
+    pub fn outer_source(&self) -> &str {
+        match &self.source_map {
+            Some(map) => &map.outer_source,
+            None => &self.source,
         }
     }
 }
@@ -230,12 +244,12 @@ impl Parser for MarzanoParser {
         body: &str,
         path: Option<&Path>,
         logs: &mut AnalysisLogs,
-        new: bool,
+        old_tree: FileOrigin<'_, Tree>,
     ) -> Option<Tree> {
         let tree = self.parser.parse(body, None).ok()??;
 
         if let Some(path) = path {
-            let mut errors = file_parsing_error(&tree, path, body, new).ok()?;
+            let mut errors = file_parsing_error(&tree, path, body, !old_tree.is_fresh()).ok()?;
             logs.append(&mut errors);
         }
 
diff --git a/crates/language/src/lib.rs b/crates/language/src/lib.rs
index f1e170376..90057c218 100644
--- a/crates/language/src/lib.rs
+++ b/crates/language/src/lib.rs
@@ -13,6 +13,7 @@ pub mod json;
 pub mod language;
 pub mod markdown_block;
 pub mod markdown_inline;
+mod notebooks;
 pub mod php;
 mod php_like;
 pub mod php_only;
@@ -20,6 +21,7 @@ pub mod python;
 pub mod ruby;
 pub mod rust;
 pub mod solidity;
+mod sourcemap;
 pub mod sql;
 pub mod target_language;
 pub mod toml;
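
With the `source_map` field added to `Tree` above, a tree can expose two views of the same file: `source` is the code the grammar actually parsed, while `outer_source()` returns the enclosing document when a map is present. A small illustrative helper, assuming a `Tree` obtained from one of the parsers in this patch:

```rust
use marzano_language::language::Tree;

// For a plain .py file both strings are identical; for a notebook, `source` is
// the concatenated cell code and `outer_source()` is the original JSON document.
fn describe(tree: &Tree) {
    println!("inner source: {} bytes", tree.source.len());
    println!("outer source: {} bytes", tree.outer_source().len());
}
```
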
diff --git a/crates/language/src/notebooks.rs b/crates/language/src/notebooks.rs
new file mode 100644
index 000000000..e6b47a349
--- /dev/null
+++ b/crates/language/src/notebooks.rs
@@ -0,0 +1,242 @@
+use grit_util::Ast;
+use grit_util::AstCursor;
+use grit_util::AstNode;
+use grit_util::ByteRange;
+use grit_util::FileOrigin;
+
+use std::path::Path;
+
+use grit_util::traverse;
+use grit_util::Order;
+
+use grit_util::{AnalysisLogs, SnippetTree};
+use marzano_util::cursor_wrapper::CursorWrapper;
+
+use crate::sourcemap::EmbeddedSourceMap;
+use crate::sourcemap::SourceMapSection;
+use crate::sourcemap::SourceValueFormat;
+use crate::{
+    json::Json,
+    language::{MarzanoLanguage, MarzanoParser, Tree},
+};
+
+const SUPPORTED_VERSION: i64 = 4;
+
+/// Custom Python parser, to include notebooks
+pub(crate) struct MarzanoNotebookParser(MarzanoParser);
+
+impl MarzanoNotebookParser {
+    pub(crate) fn new<'a>(lang: &impl MarzanoLanguage<'a>) -> Self {
+        Self(MarzanoParser::new(lang))
+    }
+
+    fn parse_file_as_notebook(
+        &mut self,
+        body: &str,
+        path: Option<&Path>,
+        logs: &mut AnalysisLogs,
+    ) -> Option<Tree> {
+        let mut inner_code_body = String::new();
+        let mut source_map = EmbeddedSourceMap::new(body);
+
+        let mut nbformat_version: Option<i64> = None;
+
+        let json = Json::new(None);
+        let mut parser = json.get_parser();
+        let tree = parser.parse_file(body, None, logs, FileOrigin::Fresh)?;
+        let root = tree.root_node().node;
+        let cursor = root.walk();
+
+        for n in traverse(CursorWrapper::new(cursor, body), Order::Pre) {
+            if n.node.kind() == "pair"
+                && n.child_by_field_name("key")
+                    .and_then(|key| key.node.utf8_text(body.as_bytes()).ok())
+                    .map(|key| key == "\"nbformat\"")
+                    .unwrap_or(false)
+            {
+                let value = n
+                    .child_by_field_name("value")
+                    .and_then(|value| value.node.utf8_text(body.as_bytes()).ok())
+                    .map(|value| value.parse::<i64>().unwrap())
+                    .unwrap_or(0);
+                if value != SUPPORTED_VERSION {
+                    logs.add_warning(
+                        path.map(|m| m.into()),
+                        format!("Unsupported version {} found", value),
+                    );
+                    return None;
+                }
+                nbformat_version = Some(value);
+            }
+            if n.node.kind() != "object" {
+                continue;
+            }
+
+            let mut cursor = n.walk();
+
+            let mut is_code_cell = false;
+
+            let mut source_ranges: Option<(String, SourceMapSection)> = None;
+
+            cursor.goto_first_child(); // Enter the object
+            while cursor.goto_next_sibling() {
+                // Iterate over the children of the object
+                let node = cursor.node();
+
+                if node.node.kind() == "pair"
+                    && node
+                        .child_by_field_name("key")
+                        .and_then(|key| key.node.utf8_text(body.as_bytes()).ok())
+                        .map(|key| key == "\"cell_type\"")
+                        .unwrap_or(false)
+                    && node
+                        .child_by_field_name("value")
+                        .and_then(|value| value.node.utf8_text(body.as_bytes()).ok())
+                        .map(|value| value == "\"code\"")
+                        .unwrap_or(false)
+                {
+                    is_code_cell = true;
+                }
+
+                if node.node.kind() == "pair"
+                    && node
+                        .child_by_field_name("key")
+                        .and_then(|key| key.node.utf8_text(body.as_bytes()).ok())
+                        .map(|key| key == "\"source\"")
+                        .unwrap_or(false)
+                {
+                    if let Some(value) = node.child_by_field_name("value") {
+                        let range = value.node.range();
+                        let text = value.node.utf8_text(body.as_bytes()).ok()?;
+                        let value: serde_json::Value = serde_json::from_str(&text).ok()?;
+
+                        let (this_content, format) = match value {
+                            serde_json::Value::Array(value) => (
+                                value
+                                    .iter()
+                                    .map(|v| v.as_str().unwrap_or(""))
+                                    .collect::<Vec<_>>()
+                                    .join(""),
+                                SourceValueFormat::Array,
+                            ),
+                            serde_json::Value::String(s) => (s, SourceValueFormat::String),
+                            _ => {
+                                logs.add_warning(
+                                    path.map(|m| m.into()),
+                                    "Unsupported cell source format, expected a string or array of strings".to_string(),
+                                );
+                                continue;
+                            }
+                        };
+                        let inner_range = ByteRange::new(
+                            inner_code_body.len(),
+                            inner_code_body.len() + this_content.len(),
+                        );
+                        source_ranges = Some((
+                            this_content,
+                            SourceMapSection {
+                                outer_range: ByteRange::new(
+                                    range.start_byte().try_into().unwrap(),
+                                    range.end_byte().try_into().unwrap(),
+                                ),
+                                inner_range,
+                                format,
+                            },
+                        ));
+                    }
+                }
+            }
+
+            if is_code_cell {
+                if let Some(source_range) = source_ranges {
+                    let (content, section) = source_range;
+                    inner_code_body.push_str(&content);
+                    source_map.add_section(section);
+                }
+            }
+
+            cursor.goto_parent(); // Exit the object
+        }
+
+        // Confirm we have a version
+        if nbformat_version.is_none() {
+            logs.add_warning(
+                path.map(|m| m.into()),
+                "No nbformat version found".to_string(),
+            );
+            return None;
+        }
+
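+        // Illustrative walkthrough, not part of the original patch: for a
+        // single cell whose source array is ["print(4)\n", "print(\"3\")\n",
+        // "print(5)"], the strings are joined into a 28-byte inner body, the
+        // section records that inner range 0..28 maps to the byte range of the
+        // JSON value in the outer document, and the inner body is handed to
+        // the plain Python grammar below.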
+        self.0
+            .parser
+            .parse(inner_code_body.clone(), None)
+            .ok()?
+            .map(|tree| {
+                let mut tree = Tree::new(tree, inner_code_body);
+                tree.source_map = Some(source_map);
+                tree
+            })
+    }
+}
+
+impl grit_util::Parser for MarzanoNotebookParser {
+    type Tree = Tree;
+
+    fn parse_file(
+        &mut self,
+        body: &str,
+        path: Option<&Path>,
+        logs: &mut AnalysisLogs,
+        old_tree: FileOrigin<'_, Tree>,
+    ) -> Option<Tree> {
+        if path
+            .and_then(Path::extension)
+            .is_some_and(|ext| ext == "ipynb")
+            && old_tree.is_fresh()
+        {
+            let tree = self.parse_file_as_notebook(body, path, logs);
+            if let Some(tree) = tree {
+                return Some(tree);
+            }
+        }
+
+        self.0.parse_file(body, path, logs, old_tree)
+    }
+
+    fn parse_snippet(
+        &mut self,
+        pre: &'static str,
+        source: &str,
+        post: &'static str,
+    ) -> SnippetTree<Tree> {
+        self.0.parse_snippet(pre, source, post)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use crate::python::Python;
+    use grit_util::Parser as _;
+
+    #[test]
+    fn simple_notebook() {
+        let code = include_str!("../../../crates/cli_bin/fixtures/notebooks/tiny_nb.ipynb");
+        let mut parser = MarzanoNotebookParser::new(&Python::new(None));
+        let tree = parser
+            .parse_file(code, None, &mut AnalysisLogs::default(), FileOrigin::Fresh)
+            .unwrap();
+
+        let cursor = tree.root_node().node.walk();
+
+        for n in traverse(CursorWrapper::new(cursor, code), Order::Pre) {
+            println!("Node kind: {}", n.node.kind());
+            assert!(
+                !n.node.is_error(),
+                "Node is an error: {}",
+                n.node.utf8_text(code.as_bytes()).unwrap()
+            );
+        }
+    }
+}
diff --git a/crates/language/src/python.rs b/crates/language/src/python.rs
index 3f6d5137f..711043010 100644
--- a/crates/language/src/python.rs
+++ b/crates/language/src/python.rs
@@ -1,5 +1,8 @@
-use crate::language::{fields_for_nodes, Field, MarzanoLanguage, NodeTypes, SortId, TSLanguage};
-use grit_util::{Ast, AstNode, CodeRange, Language, Replacement};
+use crate::{
+    language::{fields_for_nodes, Field, MarzanoLanguage, NodeTypes, SortId, TSLanguage, Tree},
+    notebooks::MarzanoNotebookParser,
+};
+use grit_util::{Ast, AstNode, CodeRange, Language, Parser, Replacement};
 use marzano_util::node_with_source::NodeWithSource;
 use std::sync::OnceLock;
 
@@ -143,6 +146,10 @@ impl<'a> MarzanoLanguage<'a> for Python {
     fn metavariable_sort(&self) -> SortId {
         self.metavariable_sort
     }
+
+    fn get_parser(&self) -> Box<dyn Parser<Tree = Tree>> {
+        Box::new(MarzanoNotebookParser::new(self))
+    }
 }
 
 #[cfg(test)]
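
Because `Python::get_parser` now returns a `MarzanoNotebookParser`, anything that goes through `MarzanoLanguage::get_parser` picks up notebook support automatically; the wrapper only engages for fresh `.ipynb` files and defers to the plain Python grammar otherwise. A sketch of that call path, mirroring the unit test above (the file name is arbitrary and error handling is elided):

```rust
use std::path::Path;

use grit_util::{AnalysisLogs, FileOrigin, Parser as _};
use marzano_language::language::{MarzanoLanguage as _, Tree};
use marzano_language::python::Python;

fn parse_notebook(body: &str) -> Option<Tree> {
    let lang = Python::new(None);
    let mut parser = lang.get_parser();
    let mut logs = AnalysisLogs::default();
    // A fresh .ipynb path takes the notebook route; any other input falls
    // through to the ordinary Python parser.
    parser.parse_file(body, Some(Path::new("nb.ipynb")), &mut logs, FileOrigin::Fresh)
}
```
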
+#[derive(Debug, Clone)]
+pub struct EmbeddedSourceMap {
+    sections: Vec<SourceMapSection>,
+    /// This is a bit suboptimal, but we assume nobody has tons of embedded files
+    pub(crate) outer_source: String,
+}
+
+impl EmbeddedSourceMap {
+    pub fn new(outer_source: &str) -> Self {
+        Self {
+            sections: vec![],
+            outer_source: outer_source.to_string(),
+        }
+    }
+
+    pub fn add_section(&mut self, section: SourceMapSection) {
+        self.sections.push(section);
+    }
+
+    pub fn fill_with_inner(&self, new_inner_source: &str) -> Result<String> {
+        let mut outer_source = self.outer_source.clone();
+
+        for section in &self.sections {
+            // TODO: actually get the *updated* range
+            let replacement_code = new_inner_source
+                .get(section.inner_range.start..section.inner_range.end)
+                .ok_or(anyhow::anyhow!("Section range is out of bounds"))?;
+
+            let json = section.as_json(replacement_code);
+
+            outer_source.replace_range(section.outer_range.start..section.outer_range.end, &json);
+        }
+
+        Ok(outer_source)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct SourceMapSection {
+    /// The range of the code within the outer document
+    pub(crate) outer_range: ByteRange,
+    /// The range of the code inside the inner document
+    pub(crate) inner_range: ByteRange,
+    pub(crate) format: SourceValueFormat,
+}
+
+impl SourceMapSection {
+    pub fn as_json(&self, code: &str) -> String {
+        let structure = match self.format {
+            SourceValueFormat::String => serde_json::Value::String(code.to_string()),
+            SourceValueFormat::Array => {
+                json!(vec![code])
+            }
+        };
+        structure.to_string()
+    }
+}
+
+#[derive(Clone, Debug)]
+pub enum SourceValueFormat {
+    String,
+    Array,
+}
diff --git a/crates/language/src/target_language.rs b/crates/language/src/target_language.rs
index 3ca20f9e9..8003eae57 100644
--- a/crates/language/src/target_language.rs
+++ b/crates/language/src/target_language.rs
@@ -165,6 +165,7 @@ impl PatternLanguage {
             Some("inline") => Some(Self::MarkdownInline),
             _ => Some(Self::MarkdownInline),
         },
+        "ipynb" => Some(Self::Python),
         "python" => Some(Self::Python),
         "go" => Some(Self::Go),
         "rust" => Some(Self::Rust),
@@ -197,7 +198,7 @@ impl PatternLanguage {
             PatternLanguage::Json => &["json"],
             PatternLanguage::Java => &["java"],
             PatternLanguage::CSharp => &["cs"],
-            PatternLanguage::Python => &["py"],
+            PatternLanguage::Python => &["py", "ipynb"],
             PatternLanguage::MarkdownBlock => &["md", "mdx", "mdoc"],
             PatternLanguage::MarkdownInline => &["md", "mdx", "mdoc"],
             PatternLanguage::Go => &["go"],
@@ -252,6 +253,7 @@ impl PatternLanguage {
             "json" => Some(Self::Json),
             "java" => Some(Self::Java),
             "cs" => Some(Self::CSharp),
+            "ipynb" => Some(Self::Python),
             "py" => Some(Self::Python),
             "md" | "mdx" | "mdoc" => Some(Self::MarkdownBlock),
             "go" => Some(Self::Go),
@@ -271,7 +273,7 @@ impl PatternLanguage {
         self.get_file_extensions().contains(&ext)
     }
 
-    // slightly inneficient but ensures the names are cosnsistent
+    // slightly inefficient but ensures the names are consistent
    pub fn language_name(self) -> &'static str {
         self.try_into()
             .map(|l: TargetLanguage| l.language_name())
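
After a rewrite, `fill_with_inner` above projects the new inner source back into the notebook: each section's slice of the rewritten code is re-encoded with `as_json` in the cell's recorded format and spliced over the cell's byte range in the outer JSON. A toy model of that splice for a single array-format cell (a standalone sketch, since `sourcemap` is a private module; the offsets are illustrative):

```rust
use serde_json::json;

// Replace the bytes of one "source" value in the outer notebook JSON with the
// rewritten cell code, encoded as a one-element JSON array in the same way as
// SourceValueFormat::Array.
fn splice_cell(outer: &str, outer_start: usize, outer_end: usize, new_code: &str) -> String {
    let mut result = outer.to_string();
    let encoded = json!([new_code]).to_string();
    result.replace_range(outer_start..outer_end, &encoded);
    result
}
```

As the TODO in `fill_with_inner` notes, the stored ranges are not yet recomputed when a rewrite changes lengths, so the round trip currently relies on the recorded offsets still being valid.
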
diff --git a/crates/language/src/vue.rs b/crates/language/src/vue.rs
index fc6163cd5..893cb0172 100644
--- a/crates/language/src/vue.rs
+++ b/crates/language/src/vue.rs
@@ -122,6 +122,7 @@ fn append_code_range(
     let mut cursor = node.walk();
     if let Some(mut attributes) = node
         .child_by_field_name("start_tag")
+        // nb. This typo matches the grammar
         .map(|n| n.children_by_field_name("atributes", &mut cursor))
     {
         if attributes.any(|n| is_lang_attribute(&n, text, name_array)) {
diff --git a/crates/wasm-bindings/src/match_pattern.rs b/crates/wasm-bindings/src/match_pattern.rs
index 50dead7f9..0ceac048a 100644
--- a/crates/wasm-bindings/src/match_pattern.rs
+++ b/crates/wasm-bindings/src/match_pattern.rs
@@ -83,10 +83,8 @@ pub async fn parse_input_files_internal(
         get_parsed_pattern(&pattern, lib_paths, lib_contents, parser).await?;
     let node = tree.root_node();
     let fields = GRIT_NODE_TYPES
-        .get_or_init(|| fields_for_nodes(&GRIT_LANGUAGE.get().unwrap(), NODE_TYPES_STRING));
-    let grit_node_types = GritNodeTypes {
-        node_types: &fields,
-    };
+        .get_or_init(|| fields_for_nodes(GRIT_LANGUAGE.get().unwrap(), NODE_TYPES_STRING));
+    let grit_node_types = GritNodeTypes { node_types: fields };
 
     let parsed_pattern =
         tree_sitter_node_to_json(&node.node, &pattern, &grit_node_types).to_string();
@@ -350,7 +348,7 @@ async fn setup_grit_parser() -> anyhow::Result {
     let new_lang = get_lang(&lang_path).await?;
     let _language_already_set = GRIT_LANGUAGE.set(new_lang);
     let _ = GRIT_NODE_TYPES.set(fields_for_nodes(
-        &GRIT_LANGUAGE.get().unwrap(),
+        GRIT_LANGUAGE.get().unwrap(),
         NODE_TYPES_STRING,
     ));
     GRIT_LANGUAGE