From 796e3a988acf0edbfae1623b04e3a439a921e3e3 Mon Sep 17 00:00:00 2001 From: Akarsh Gupta <69024958+akarshgupta7@users.noreply.github.com> Date: Wed, 27 Nov 2024 09:43:14 -0800 Subject: [PATCH] Added ragas to compute answer metrics for evaluation. (#1039) * Removed duplicate code in query execution. * Added nltk download logic to support string metrics. * Add ragas to imports. * Added ragas to imports. * Added string metrics for evaluation. * Remove dthe unwanted break statement. * Moves the import to top of file. * Moved the scorer definitions to the init function. * Refactoring. * Removed unused imports. * Moved async calls to the outer function. * Refactor to add error handling and change function names. --- apps/query-eval/poetry.lock | 585 ++++++++++++++++++- apps/query-eval/pyproject.toml | 2 +- apps/query-eval/queryeval/driver.py | 153 ++++- apps/query-eval/queryeval/main.py | 14 +- apps/query-eval/queryeval/queryeval_types.py | 3 +- 5 files changed, 724 insertions(+), 33 deletions(-) diff --git a/apps/query-eval/poetry.lock b/apps/query-eval/poetry.lock index 0b588f9a1..89bed4f6b 100644 --- a/apps/query-eval/poetry.lock +++ b/apps/query-eval/poetry.lock @@ -1,4 +1,29 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. + +[[package]] +name = "absl-py" +version = "2.1.0" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +optional = false +python-versions = ">=3.7" +files = [ + {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, + {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, +] + +[[package]] +name = "adagio" +version = "0.2.6" +description = "The Dag IO Framework for Fugue projects" +optional = false +python-versions = ">=3.8" +files = [ + {file = "adagio-0.2.6-py3-none-any.whl", hash = "sha256:1bb8317d41bfff8b11373bc03c9859ff166c498214bb2b7ce1e21638c0babb2c"}, + {file = "adagio-0.2.6.tar.gz", hash = "sha256:0c32768f3aba0e05273b36f9420a482034f2510f059171040d7e98ba34128d7a"}, +] + +[package.dependencies] +triad = ">=0.6.1" [[package]] name = "aiohappyeyeballs" @@ -1393,6 +1418,40 @@ files = [ marshmallow = ">=3.18.0,<4.0.0" typing-inspect = ">=0.4.0,<1" +[[package]] +name = "datacompy" +version = "0.14.4" +description = "Dataframe comparison in Python" +optional = false +python-versions = ">=3.9.0" +files = [ + {file = "datacompy-0.14.4-py3-none-any.whl", hash = "sha256:e6b3edc21d7ab9ea3bbba5bf5111b47be23841050876e1362770e40f5b68546f"}, + {file = "datacompy-0.14.4.tar.gz", hash = "sha256:8fd0a5bd6146c1efe43baae3495107b68c0cbd74688610ae1a26deba094a4476"}, +] + +[package.dependencies] +fugue = ">=0.8.7,<=0.9.1" +numpy = ">=1.22.0,<=2.1.2" +ordered-set = ">=4.0.2,<=4.1.0" +pandas = ">=0.25.0,<=2.2.3" +polars = ">=0.20.4,<=1.12.0" + +[package.extras] +build = ["build", "twine", "wheel"] +dask = ["fugue[dask]"] +dev = ["datacompy[build]", "datacompy[docs]", "datacompy[duckdb]", "datacompy[qa]", "datacompy[snowflake]", "datacompy[spark]", "datacompy[tests-snowflake]", "datacompy[tests-spark]", "datacompy[tests]"] +dev-no-snowflake = ["datacompy[build]", "datacompy[docs]", "datacompy[duckdb]", "datacompy[qa]", "datacompy[spark]", "datacompy[tests-spark]", "datacompy[tests]"] +docs = ["furo", "myst-parser", "sphinx"] +duckdb = ["fugue[duckdb]"] +edgetest = ["edgetest", "edgetest-conda"] +qa = ["mypy", "pandas-stubs", "pre-commit", "ruff (==0.5.7)"] +ray = ["fugue[ray]"] +snowflake = ["snowflake-connector-python", "snowflake-snowpark-python"] +spark = ["pyspark[connect] (>=3.1.1)", "pyspark[connect] (>=3.4)"] +tests = ["pytest", "pytest-cov"] +tests-snowflake = ["snowflake-snowpark-python[localtest]"] +tests-spark = ["pytest", "pytest-cov", "pytest-spark"] + [[package]] name = "datasets" version = "2.19.2" @@ -1481,6 +1540,23 @@ files = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] +[[package]] +name = "deprecated" +version = "1.2.15" +description = "Python @deprecated decorator to deprecate old python classes, functions or methods." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +files = [ + {file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"}, + {file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"}, +] + +[package.dependencies] +wrapt = ">=1.10,<2" + +[package.extras] +dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "jinja2 (>=3.0.3,<3.1.0)", "setuptools", "sphinx (<2)", "tox"] + [[package]] name = "dill" version = "0.3.8" @@ -1496,6 +1572,17 @@ files = [ graph = ["objgraph (>=1.7.2)"] profile = ["gprof2dot (>=2022.7.29)"] +[[package]] +name = "dirtyjson" +version = "1.0.8" +description = "JSON decoder for Python that can extract data from the muck" +optional = false +python-versions = "*" +files = [ + {file = "dirtyjson-1.0.8-py3-none-any.whl", hash = "sha256:125e27248435a58acace26d5c2c4c11a1c0de0a9c5124c5a94ba78e517d74f53"}, + {file = "dirtyjson-1.0.8.tar.gz", hash = "sha256:90ca4a18f3ff30ce849d100dcf4a003953c79d3a2348ef056f1d9c22231a25fd"}, +] + [[package]] name = "diskcache" version = "5.6.3" @@ -1698,6 +1785,17 @@ docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2. testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] typing = ["typing-extensions (>=4.12.2)"] +[[package]] +name = "filetype" +version = "1.2.0" +description = "Infer file type and MIME type of any file/buffer. No external dependencies." +optional = false +python-versions = "*" +files = [ + {file = "filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25"}, + {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"}, +] + [[package]] name = "fire" version = "0.7.0" @@ -1883,6 +1981,25 @@ files = [ {file = "frozenlist-1.5.0.tar.gz", hash = "sha256:81d5af29e61b9c8348e876d442253723928dce6433e0e76cd925cd83f1b4b817"}, ] +[[package]] +name = "fs" +version = "2.4.16" +description = "Python's filesystem abstraction layer" +optional = false +python-versions = "*" +files = [ + {file = "fs-2.4.16-py2.py3-none-any.whl", hash = "sha256:660064febbccda264ae0b6bace80a8d1be9e089e0a5eb2427b7d517f9a91545c"}, + {file = "fs-2.4.16.tar.gz", hash = "sha256:ae97c7d51213f4b70b6a958292530289090de3a7e15841e108fbe144f069d313"}, +] + +[package.dependencies] +appdirs = ">=1.4.3,<1.5.0" +setuptools = "*" +six = ">=1.10,<2.0" + +[package.extras] +scandir = ["scandir (>=1.5,<2.0)"] + [[package]] name = "fsspec" version = "2024.2.0" @@ -1921,6 +2038,33 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] +[[package]] +name = "fugue" +version = "0.9.1" +description = "An abstraction layer for distributed computation" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fugue-0.9.1-py3-none-any.whl", hash = "sha256:5b91e55e6f243af6e2b901dc37914d954d8f0231627b68007850879f8848a3a3"}, + {file = "fugue-0.9.1.tar.gz", hash = "sha256:fb0f9a4780147ac8438be96efc50593e2d771d1cbf528ac56d3bcecd39915b50"}, +] + +[package.dependencies] +adagio = ">=0.2.4" +triad = ">=0.9.7" + +[package.extras] +all = ["dask-sql", "dask[dataframe,distributed] (>=2023.5.0)", "duckdb (>=0.5.0)", "fugue-sql-antlr (>=0.2.0)", "ibis-framework", "ipython (>=7.10.0)", "jinja2", "jupyterlab", "notebook", "pandas (>=2.0.2,<2.2)", "polars", "pyarrow (>=6.0.1)", "pyspark (>=3.1.1)", "qpd (>=0.4.4)", "ray[data] (>=2.5.0)", "sqlglot"] +cpp-sql-parser = ["fugue-sql-antlr[cpp] (>=0.2.0)"] +dask = ["dask[dataframe,distributed] (>=2023.5.0)", "dask[dataframe,distributed] (>=2024.4.0)", "pandas (>=2.0.2)", "pyarrow (>=7.0.0)"] +duckdb = ["duckdb (>=0.5.0)", "fugue-sql-antlr (>=0.2.0)", "jinja2", "numpy", "qpd (>=0.4.4)", "sqlglot"] +ibis = ["fugue-sql-antlr (>=0.2.0)", "ibis-framework", "jinja2", "pandas (<2.2)", "qpd (>=0.4.4)", "sqlglot"] +notebook = ["ipython (>=7.10.0)", "jupyterlab", "notebook"] +polars = ["polars"] +ray = ["duckdb (>=0.5.0)", "pandas (<2.2)", "pyarrow (>=7.0.0)", "ray[data] (>=2.5.0)"] +spark = ["pyspark (>=3.1.1)"] +sql = ["fugue-sql-antlr (>=0.2.0)", "jinja2", "qpd (>=0.4.4)", "sqlglot"] + [[package]] name = "google-api-core" version = "2.21.0" @@ -3093,6 +3237,297 @@ dev = ["black", "flake8", "isort", "pre-commit", "pyproject-flake8"] doc = ["myst-parser", "sphinx", "sphinx-book-theme"] test = ["coverage", "pytest", "pytest-cov"] +[[package]] +name = "llama-cloud" +version = "0.1.5" +description = "" +optional = false +python-versions = "<4,>=3.8" +files = [ + {file = "llama_cloud-0.1.5-py3-none-any.whl", hash = "sha256:15605022520d04bd6ef6a46c0cbde833f301d652286d34fca02b4c44e2a7a2aa"}, + {file = "llama_cloud-0.1.5.tar.gz", hash = "sha256:8ce1db36754a6a46c8511561dbc040a2e89ba4ca1cf4edfb6ce382a5240f6cb6"}, +] + +[package.dependencies] +httpx = ">=0.20.0" +pydantic = ">=1.10" + +[[package]] +name = "llama-index" +version = "0.11.22" +description = "Interface between LLMs and your data" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index-0.11.22-py3-none-any.whl", hash = "sha256:bda98d925dfbab4b76c07cc61b59bb5920e15e685efd9fbf3a0cd33f1f465f10"}, + {file = "llama_index-0.11.22.tar.gz", hash = "sha256:8d8a7838a7fcc733fc7a262ef3709df001c3021cb42843c8e9da8d244e5355e1"}, +] + +[package.dependencies] +llama-index-agent-openai = ">=0.3.4,<0.4.0" +llama-index-cli = ">=0.3.1,<0.4.0" +llama-index-core = ">=0.11.22,<0.12.0" +llama-index-embeddings-openai = ">=0.2.4,<0.3.0" +llama-index-indices-managed-llama-cloud = ">=0.3.0" +llama-index-legacy = ">=0.9.48,<0.10.0" +llama-index-llms-openai = ">=0.2.10,<0.3.0" +llama-index-multi-modal-llms-openai = ">=0.2.0,<0.3.0" +llama-index-program-openai = ">=0.2.0,<0.3.0" +llama-index-question-gen-openai = ">=0.2.0,<0.3.0" +llama-index-readers-file = ">=0.2.0,<0.3.0" +llama-index-readers-llama-parse = ">=0.3.0" +nltk = ">3.8.1" + +[[package]] +name = "llama-index-agent-openai" +version = "0.3.4" +description = "llama-index agent openai integration" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_agent_openai-0.3.4-py3-none-any.whl", hash = "sha256:3720ce9bb12417a99a3fe84e52cce23e762b13f88a2dfc4292c76f4df9b26b4a"}, + {file = "llama_index_agent_openai-0.3.4.tar.gz", hash = "sha256:80e3408d97121bebca3fa3ffd14b51285870c1c3c73d4ee04d3d18cfe6040466"}, +] + +[package.dependencies] +llama-index-core = ">=0.11.0,<0.12.0" +llama-index-llms-openai = ">=0.2.9,<0.3.0" +openai = ">=1.14.0" + +[[package]] +name = "llama-index-cli" +version = "0.3.1" +description = "llama-index cli" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_cli-0.3.1-py3-none-any.whl", hash = "sha256:2111fbb6973f5b1eabce0d6cca3986499f0f2f625b13d7f48269a49c64c027d4"}, + {file = "llama_index_cli-0.3.1.tar.gz", hash = "sha256:1890dd687cf440f3651365a549e303363162c167b8efbd87a3aa10058d6d5c77"}, +] + +[package.dependencies] +llama-index-core = ">=0.11.0,<0.12.0" +llama-index-embeddings-openai = ">=0.2.0,<0.3.0" +llama-index-llms-openai = ">=0.2.0,<0.3.0" + +[[package]] +name = "llama-index-core" +version = "0.11.23" +description = "Interface between LLMs and your data" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_core-0.11.23-py3-none-any.whl", hash = "sha256:25a0cb4a055bfb348655ca4acd1b475529bd8537a7b81874ef14ed13f56e06c1"}, + {file = "llama_index_core-0.11.23.tar.gz", hash = "sha256:e150859696a0eae169fe19323f46e9a31af2c12c3182012e4d0353ea8eb06d24"}, +] + +[package.dependencies] +aiohttp = ">=3.8.6,<4.0.0" +dataclasses-json = "*" +deprecated = ">=1.2.9.3" +dirtyjson = ">=1.0.8,<2.0.0" +filetype = ">=1.2.0,<2.0.0" +fsspec = ">=2023.5.0" +httpx = "*" +nest-asyncio = ">=1.5.8,<2.0.0" +networkx = ">=3.0" +nltk = ">3.8.1" +numpy = "<2.0.0" +pillow = ">=9.0.0" +pydantic = ">=2.7.0,<3.0.0" +PyYAML = ">=6.0.1" +requests = ">=2.31.0" +SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]} +tenacity = ">=8.2.0,<8.4.0 || >8.4.0,<9.0.0" +tiktoken = ">=0.3.3" +tqdm = ">=4.66.1,<5.0.0" +typing-extensions = ">=4.5.0" +typing-inspect = ">=0.8.0" +wrapt = "*" + +[[package]] +name = "llama-index-embeddings-openai" +version = "0.2.5" +description = "llama-index embeddings openai integration" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_embeddings_openai-0.2.5-py3-none-any.whl", hash = "sha256:823c8311e556349ba19dda408a64a314fa3dafe0e5759709c54d33a0269aa6ba"}, + {file = "llama_index_embeddings_openai-0.2.5.tar.gz", hash = "sha256:0047dd71d747068645ed728c29312aa91b65bbe4c6142180034c64dfc5c6f6e8"}, +] + +[package.dependencies] +llama-index-core = ">=0.11.0,<0.12.0" +openai = ">=1.1.0" + +[[package]] +name = "llama-index-indices-managed-llama-cloud" +version = "0.6.0" +description = "llama-index indices llama-cloud integration" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "llama_index_indices_managed_llama_cloud-0.6.0-py3-none-any.whl", hash = "sha256:18a3bbb386c4fbda8883cf40339bde402637e4cd5e06bcf3870d8c174b9baa3a"}, + {file = "llama_index_indices_managed_llama_cloud-0.6.0.tar.gz", hash = "sha256:fe32aecb87ffd81eb824fc64509cc991c3cde574455e53e73a4dbe30961c4f21"}, +] + +[package.dependencies] +llama-cloud = ">=0.1.5" +llama-index-core = ">=0.11.13.post1,<0.12.0" + +[[package]] +name = "llama-index-legacy" +version = "0.9.48.post4" +description = "Interface between LLMs and your data" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_legacy-0.9.48.post4-py3-none-any.whl", hash = "sha256:4b817d7c343fb5f7f00c4410eff519f320013b8d5f24c4fedcf270c471f92038"}, + {file = "llama_index_legacy-0.9.48.post4.tar.gz", hash = "sha256:f8a9764e7e134a52bfef5e53d2d62561bfc01fc09874c51cc001df6f5302ae30"}, +] + +[package.dependencies] +aiohttp = ">=3.8.6,<4.0.0" +dataclasses-json = "*" +deprecated = ">=1.2.9.3" +dirtyjson = ">=1.0.8,<2.0.0" +fsspec = ">=2023.5.0" +httpx = "*" +nest-asyncio = ">=1.5.8,<2.0.0" +networkx = ">=3.0" +nltk = ">=3.8.1" +numpy = "*" +openai = ">=1.1.0" +pandas = "*" +requests = ">=2.31.0" +SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]} +tenacity = ">=8.2.0,<9.0.0" +tiktoken = ">=0.3.3" +typing-extensions = ">=4.5.0" +typing-inspect = ">=0.8.0" + +[package.extras] +gradientai = ["gradientai (>=1.4.0)"] +html = ["beautifulsoup4 (>=4.12.2,<5.0.0)"] +langchain = ["langchain (>=0.0.303)"] +local-models = ["optimum[onnxruntime] (>=1.13.2,<2.0.0)", "sentencepiece (>=0.1.99,<0.2.0)", "transformers[torch] (>=4.33.1,<5.0.0)"] +postgres = ["asyncpg (>=0.28.0,<0.29.0)", "pgvector (>=0.1.0,<0.2.0)", "psycopg2-binary (>=2.9.9,<3.0.0)"] +query-tools = ["guidance (>=0.0.64,<0.0.65)", "jsonpath-ng (>=1.6.0,<2.0.0)", "lm-format-enforcer (>=0.4.3,<0.5.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "scikit-learn", "spacy (>=3.7.1,<4.0.0)"] + +[[package]] +name = "llama-index-llms-openai" +version = "0.2.16" +description = "llama-index llms openai integration" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_llms_openai-0.2.16-py3-none-any.whl", hash = "sha256:413466acbb894bd81f8dab2037f595e92392d869eec6d8274a16d43123cac8b6"}, + {file = "llama_index_llms_openai-0.2.16.tar.gz", hash = "sha256:7c666dd27056c278a079ff45d53f1fbfc8ed363764aa7baeee2e03df47f9072a"}, +] + +[package.dependencies] +llama-index-core = ">=0.11.7,<0.12.0" +openai = ">=1.40.0,<2.0.0" + +[[package]] +name = "llama-index-multi-modal-llms-openai" +version = "0.2.3" +description = "llama-index multi-modal-llms openai integration" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_multi_modal_llms_openai-0.2.3-py3-none-any.whl", hash = "sha256:96b36beb2c3fca4faca80c59ecf7c6c6629ecdb96c288ef89777b592ec43f872"}, + {file = "llama_index_multi_modal_llms_openai-0.2.3.tar.gz", hash = "sha256:8eb9b7f1ff3956ef0979e21bc83e6a885e40987b7199f195e46525d06e3ae402"}, +] + +[package.dependencies] +llama-index-core = ">=0.11.0,<0.12.0" +llama-index-llms-openai = ">=0.2.11,<0.3.0" + +[[package]] +name = "llama-index-program-openai" +version = "0.2.0" +description = "llama-index program openai integration" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_program_openai-0.2.0-py3-none-any.whl", hash = "sha256:2e10d0c8f21af2e9443eb79e81bb31e7b73835b7c7bbd7ddf20e0a9c846cd368"}, + {file = "llama_index_program_openai-0.2.0.tar.gz", hash = "sha256:4139935541c011257fbfeb9662b3bf1237b729ef4b1c8f4ddf5b6789d2374ac4"}, +] + +[package.dependencies] +llama-index-agent-openai = ">=0.3.0,<0.4.0" +llama-index-core = ">=0.11.0,<0.12.0" +llama-index-llms-openai = ">=0.2.0,<0.3.0" + +[[package]] +name = "llama-index-question-gen-openai" +version = "0.2.0" +description = "llama-index question_gen openai integration" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_question_gen_openai-0.2.0-py3-none-any.whl", hash = "sha256:a16e68fc5434e9a793f1dfd0cc0354ee19afd167f1d499403b0085b11c5406c0"}, + {file = "llama_index_question_gen_openai-0.2.0.tar.gz", hash = "sha256:3dde1cecbd651000639c20031d7ea23334276aabb181cac40ff424f35e10465e"}, +] + +[package.dependencies] +llama-index-core = ">=0.11.0,<0.12.0" +llama-index-llms-openai = ">=0.2.0,<0.3.0" +llama-index-program-openai = ">=0.2.0,<0.3.0" + +[[package]] +name = "llama-index-readers-file" +version = "0.2.2" +description = "llama-index readers file integration" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_readers_file-0.2.2-py3-none-any.whl", hash = "sha256:ffec878771c1e7575afb742887561059bcca77b97a81c1c1be310ebb73f10f46"}, + {file = "llama_index_readers_file-0.2.2.tar.gz", hash = "sha256:48459f90960b863737147b66ed83afec9ce8984f8eda2561b6d2500214365db2"}, +] + +[package.dependencies] +beautifulsoup4 = ">=4.12.3,<5.0.0" +llama-index-core = ">=0.11.0,<0.12.0" +pandas = "*" +pypdf = ">=4.0.1,<5.0.0" +striprtf = ">=0.0.26,<0.0.27" + +[package.extras] +pymupdf = ["pymupdf (>=1.23.21,<2.0.0)"] + +[[package]] +name = "llama-index-readers-llama-parse" +version = "0.3.0" +description = "llama-index readers llama-parse integration" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_index_readers_llama_parse-0.3.0-py3-none-any.whl", hash = "sha256:1973cc710dbd5e110c7500c9983ecb45787ad1ff92e6b2113f94a57cf48f3038"}, + {file = "llama_index_readers_llama_parse-0.3.0.tar.gz", hash = "sha256:a5feada0895714dcc41d65dd512c1c38cf70d8ae19947cff82b80d58e6aa367e"}, +] + +[package.dependencies] +llama-index-core = ">=0.11.0,<0.12.0" +llama-parse = ">=0.5.0" + +[[package]] +name = "llama-parse" +version = "0.5.15" +description = "Parse files into RAG-Optimized formats." +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "llama_parse-0.5.15-py3-none-any.whl", hash = "sha256:7a3506c7d3ae5a8e68c70a457a7213d2698e26abcef1d7a989eb9771cd73ae60"}, + {file = "llama_parse-0.5.15.tar.gz", hash = "sha256:ecb009f71c8b4c657085ca81808a922c80785810e38b10f3b46f03cfd29ba92a"}, +] + +[package.dependencies] +click = ">=8.1.7,<9.0.0" +llama-index-core = ">=0.11.0" +pydantic = "!=2.10" + [[package]] name = "lmdb" version = "1.5.1" @@ -3965,6 +4400,31 @@ files = [ [package.extras] test = ["codecov (>=2.0.5)", "coverage (>=4.2)", "flake8 (>=3.0.4)", "pytest (>=4.5.0)", "pytest-cov (>=2.7.1)", "pytest-runner (>=5.1)", "pytest-virtualenv (>=1.7.0)", "virtualenv (>=15.0.3)"] +[[package]] +name = "nltk" +version = "3.9.1" +description = "Natural Language Toolkit" +optional = false +python-versions = ">=3.8" +files = [ + {file = "nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1"}, + {file = "nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868"}, +] + +[package.dependencies] +click = "*" +joblib = "*" +regex = ">=2021.8.3" +tqdm = "*" + +[package.extras] +all = ["matplotlib", "numpy", "pyparsing", "python-crfsuite", "requests", "scikit-learn", "scipy", "twython"] +corenlp = ["requests"] +machine-learning = ["numpy", "python-crfsuite", "scikit-learn", "scipy"] +plot = ["matplotlib"] +tgrep = ["pyparsing"] +twitter = ["twython"] + [[package]] name = "numpy" version = "1.26.4" @@ -4718,6 +5178,47 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "polars" +version = "1.12.0" +description = "Blazingly fast DataFrame library" +optional = false +python-versions = ">=3.9" +files = [ + {file = "polars-1.12.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:8f3c4e4e423c373dda07b4c8a7ff12aa02094b524767d0ca306b1eba67f2d99e"}, + {file = "polars-1.12.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:aa6f9862f0cec6353243920d9b8d858c21ec8f25f91af203dea6ff91980e140d"}, + {file = "polars-1.12.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afb03647b5160737d2119532ee8ffe825de1d19d87f81bbbb005131786f7d59b"}, + {file = "polars-1.12.0-cp39-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:ea96aba5eb3dab8f0e6abf05ab3fc2136b329261860ef8661d20f5456a2d78e0"}, + {file = "polars-1.12.0-cp39-abi3-win_amd64.whl", hash = "sha256:a228a4b320a36d03a9ec9dfe7241b6d80a2f119b2dceb1da953166655e4cf43c"}, + {file = "polars-1.12.0.tar.gz", hash = "sha256:fb5c92de1a8f7d0a3f923fe48ea89eb518bdf55315ae917012350fa072bd64f4"}, +] + +[package.extras] +adbc = ["adbc-driver-manager[dbapi]", "adbc-driver-sqlite[dbapi]"] +all = ["polars[async,cloudpickle,database,deltalake,excel,fsspec,graph,iceberg,numpy,pandas,plot,pyarrow,pydantic,style,timezone]"] +async = ["gevent"] +calamine = ["fastexcel (>=0.9)"] +cloudpickle = ["cloudpickle"] +connectorx = ["connectorx (>=0.3.2)"] +database = ["nest-asyncio", "polars[adbc,connectorx,sqlalchemy]"] +deltalake = ["deltalake (>=0.15.0)"] +excel = ["polars[calamine,openpyxl,xlsx2csv,xlsxwriter]"] +fsspec = ["fsspec"] +gpu = ["cudf-polars-cu12"] +graph = ["matplotlib"] +iceberg = ["pyiceberg (>=0.5.0)"] +numpy = ["numpy (>=1.16.0)"] +openpyxl = ["openpyxl (>=3.0.0)"] +pandas = ["pandas", "polars[pyarrow]"] +plot = ["altair (>=5.4.0)"] +pyarrow = ["pyarrow (>=7.0.0)"] +pydantic = ["pydantic"] +sqlalchemy = ["polars[pandas]", "sqlalchemy"] +style = ["great-tables (>=0.8.0)"] +timezone = ["backports-zoneinfo", "tzdata"] +xlsx2csv = ["xlsx2csv (>=0.8.0)"] +xlsxwriter = ["xlsxwriter"] + [[package]] name = "prometheus-client" version = "0.21.0" @@ -5649,30 +6150,40 @@ files = [ [[package]] name = "ragas" -version = "0.1.21" +version = "0.2.6" description = "" optional = false python-versions = "*" files = [ - {file = "ragas-0.1.21-py3-none-any.whl", hash = "sha256:c5be4dbe3d4a90a62298889aaef8941516c41e0708a8fe942c9ffa9395cf244d"}, - {file = "ragas-0.1.21.tar.gz", hash = "sha256:4ded52375d3710a2a2dd23d550fd09d0e26485fc435b14a413174f9d7b45d252"}, + {file = "ragas-0.2.6-py3-none-any.whl", hash = "sha256:2d40a6af196df7346486e2eeb203bb0a542efa0827e839812f6c66123fd3319f"}, + {file = "ragas-0.2.6.tar.gz", hash = "sha256:877e723e4bbf29eab8e1b12f7bf6f63bb2145d63ea4c3ce21620b14f9dbfb421"}, ] [package.dependencies] appdirs = "*" +datacompy = {version = "*", optional = true, markers = "extra == \"all\""} datasets = "*" -langchain = "<0.3" -langchain-community = "<0.3" -langchain-core = "<0.3" +langchain = "*" +langchain-community = "*" +langchain-core = "*" langchain-openai = "*" +llama-index = {version = "*", optional = true, markers = "extra == \"all\""} nest-asyncio = "*" +nltk = {version = "*", optional = true, markers = "extra == \"all\""} numpy = "*" openai = ">1" +pandas = {version = "*", optional = true, markers = "extra == \"all\""} +pydantic = ">=2" pysbd = ">=0.3.4" +rapidfuzz = {version = "*", optional = true, markers = "extra == \"all\""} +rouge-score = {version = "*", optional = true, markers = "extra == \"all\""} +sentence-transformers = {version = "*", optional = true, markers = "extra == \"all\""} tiktoken = "*" +transformers = {version = "*", optional = true, markers = "extra == \"all\""} [package.extras] -all = ["sentence-transformers", "transformers"] +all = ["datacompy", "llama-index", "nltk", "pandas", "rapidfuzz", "rouge-score", "sentence-transformers", "transformers"] +docs = ["mkdocs (>=1.6.1)", "mkdocs-autorefs", "mkdocs-gen-files", "mkdocs-git-committers-plugin-2", "mkdocs-git-revision-date-localized-plugin", "mkdocs-glightbox", "mkdocs-literate-nav", "mkdocs-material", "mkdocs-material[imaging]", "mkdocs-section-index", "mkdocstrings[python]"] [[package]] name = "rapidfuzz" @@ -6031,6 +6542,22 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.1 [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "rouge-score" +version = "0.1.2" +description = "Pure python implementation of ROUGE-1.5.5." +optional = false +python-versions = ">=3.7" +files = [ + {file = "rouge_score-0.1.2.tar.gz", hash = "sha256:c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04"}, +] + +[package.dependencies] +absl-py = "*" +nltk = "*" +numpy = "*" +six = ">=1.14.0" + [[package]] name = "rpds-py" version = "0.20.0" @@ -6453,6 +6980,11 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -6789,7 +7321,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version < \"3.13\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} +greenlet = {version = "!=0.4.17", optional = true, markers = "python_version < \"3.13\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\") or extra == \"asyncio\""} typing-extensions = ">=4.6.0" [package.extras] @@ -6836,6 +7368,17 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "striprtf" +version = "0.0.26" +description = "A simple library to convert rtf to text" +optional = false +python-versions = "*" +files = [ + {file = "striprtf-0.0.26-py3-none-any.whl", hash = "sha256:8c8f9d32083cdc2e8bfb149455aa1cc5a4e0a035893bedc75db8b73becb3a1bb"}, + {file = "striprtf-0.0.26.tar.gz", hash = "sha256:fdb2bba7ac440072d1c41eab50d8d74ae88f60a8b6575c6e2c7805dc462093aa"}, +] + [[package]] name = "structlog" version = "24.4.0" @@ -7483,6 +8026,28 @@ torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] +[[package]] +name = "triad" +version = "0.9.8" +description = "A collection of python utils for Fugue projects" +optional = false +python-versions = ">=3.8" +files = [ + {file = "triad-0.9.8-py3-none-any.whl", hash = "sha256:2c0ba7d83977c6d4e7b59e3cc70727f858014ef7676c62d184aa8e63f7bef5de"}, + {file = "triad-0.9.8.tar.gz", hash = "sha256:5b67673124891981daf8afbab44b2e6358932ca35ef3ff38a25bc3e0f6f03f17"}, +] + +[package.dependencies] +fs = "*" +fsspec = ">=2022.5.0" +numpy = "*" +pandas = ">=1.3.5" +pyarrow = ">=6.0.1" +six = "*" + +[package.extras] +ciso8601 = ["ciso8601"] + [[package]] name = "triton" version = "2.3.1" @@ -8115,4 +8680,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.9.7 || >3.9.7,<3.13" -content-hash = "0579a28ee82c298ceb1dde7db59b7f470be3d0c0b8d54d9e7a45e34910e2b083" +content-hash = "be9d089a12a911ac9f4aaa145b64789e448c43e6ebad260479dd16b716eb620d" diff --git a/apps/query-eval/pyproject.toml b/apps/query-eval/pyproject.toml index d81e46996..89fb9971e 100644 --- a/apps/query-eval/pyproject.toml +++ b/apps/query-eval/pyproject.toml @@ -8,7 +8,7 @@ packages = [{include = "queryeval"}] [tool.poetry.dependencies] python = ">=3.9,<3.9.7 || >3.9.7,<3.13" sycamore-ai = { path = "../../lib/sycamore", develop = true, extras = ["opensearch", "local-inference"] } -ragas = "^0.1.5" +ragas = { version = "^0.2.6", extras = ["all"] } pydantic-yaml = "^1.3.0" pydantic = "^2.8.2" rich = "^13.7.1" diff --git a/apps/query-eval/queryeval/driver.py b/apps/query-eval/queryeval/driver.py index 6f1217371..4472d5ff3 100644 --- a/apps/query-eval/queryeval/driver.py +++ b/apps/query-eval/queryeval/driver.py @@ -23,9 +23,39 @@ DocSetSummary, ) +import asyncio +from ragas.dataset_schema import SingleTurnSample +from ragas.metrics import BleuScore, RougeScore, SemanticSimilarity +from ragas.embeddings.base import HuggingfaceEmbeddings, LangchainEmbeddingsWrapper + console = Console() +def compute_text_metrics( + sample: SingleTurnSample, + rouge_scorer: RougeScore, + bleu_scorer: BleuScore, + semantic_similarity_scorer: SemanticSimilarity, +): + d = {} + try: + d["rouge"] = asyncio.run(rouge_scorer.single_turn_ascore(sample)) + except Exception: + tb = traceback.format_exc() + console.print(f"[red]Error computing ROUGE score: {tb}") + try: + d["bleu"] = asyncio.run(bleu_scorer.single_turn_ascore(sample)) + except Exception: + tb = traceback.format_exc() + console.print(f"[red]Error computing BLEU score: {tb}") + try: + d["semantic_similarity"] = asyncio.run(semantic_similarity_scorer.single_turn_ascore(sample)) + except Exception: + tb = traceback.format_exc() + console.print(f"[red]Error computing semantic similarity score: {tb}") + return d + + class QueryEvalDriver: """Class to run Sycamore Query evaluations. @@ -91,7 +121,10 @@ def __init__( # Configure logging. if self.config.config.log_file: - os.makedirs(os.path.dirname(os.path.abspath(self.config.config.log_file)), exist_ok=True) + os.makedirs( + os.path.dirname(os.path.abspath(self.config.config.log_file)), + exist_ok=True, + ) configure_logging(logfile=self.config.config.log_file, log_level=logging.INFO) if not self.config.config.index: @@ -101,7 +134,10 @@ def __init__( if self.config.config.results_file: console.print(f"Writing results to: {self.config.config.results_file}") - os.makedirs(os.path.dirname(os.path.abspath(self.config.config.results_file)), exist_ok=True) + os.makedirs( + os.path.dirname(os.path.abspath(self.config.config.results_file)), + exist_ok=True, + ) # Read results file if it exists. if ( @@ -145,6 +181,14 @@ def __init__( if self.config.examples: self.examples = self.config.examples + # Define scorers. + self.bleu_scorer = BleuScore() + self.rouge_scorer = RougeScore() + self.semantic_similarity_scorer = SemanticSimilarity() + self.semantic_similarity_scorer.embeddings = LangchainEmbeddingsWrapper( + HuggingfaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") + ) + @staticmethod def read_input_file(input_file_path: str) -> QueryEvalInputFile: """Read the given input file.""" @@ -317,18 +361,9 @@ def do_query(self, query: QueryEvalQuery, result: QueryEvalResult) -> QueryEvalR console.print(f":white_check_mark: Result: {result.result}") return result - def do_eval(self, query: QueryEvalQuery, result: QueryEvalResult) -> QueryEvalResult: - """Run query evaluation.""" - if self.config.config: - if self.config.config.dry_run: - console.print("[yellow]:point_right: Dry run: skipping eval") - return result - elif not self._check_tags_match(query): - console.print("[yellow]:point_right: Skipping query due to tag mismatch") - return result - - metrics = result.metrics or QueryEvalMetrics() - # Evalute query plans + def get_query_plan_metrics( + self, query: QueryEvalQuery, result: QueryEvalResult, metrics: QueryEvalMetrics + ) -> QueryEvalMetrics: if not query.expected_plan: console.print("[yellow]:construction: No expected query plan found, skipping.. ") elif not result.plan: @@ -350,11 +385,15 @@ def do_eval(self, query: QueryEvalQuery, result: QueryEvalResult) -> QueryEvalRe console.print(f"Actual node: [red]{diff.node_b!r}") console.print() metrics.plan_similarity = max( - 0.0, (len(query.expected_plan.nodes) - len(plan_diff)) / len(query.expected_plan.nodes) + 0.0, + (len(query.expected_plan.nodes) - len(plan_diff)) / len(query.expected_plan.nodes), ) metrics.plan_diff_count = len(plan_diff) + return metrics - # Evaluate doc retrieval + def get_retrieval_metrics( + self, query: QueryEvalQuery, result: QueryEvalResult, metrics: QueryEvalMetrics + ) -> QueryEvalMetrics: if not query.expected_docs: console.print("[yellow]:construction: No expected document list found, skipping.. ") elif not result.retrieved_docs: @@ -368,11 +407,65 @@ def do_eval(self, query: QueryEvalQuery, result: QueryEvalResult) -> QueryEvalRe console.print("[red]:x: Documents retrieved don't include all expected docs") console.print(f"Missing docs: {expected_doc_set - retrieved_doc_set})") metrics.doc_retrieval_recall = len(retrieved_doc_set & expected_doc_set) / len(expected_doc_set) - metrics.doc_retrieval_precision = len(retrieved_doc_set & expected_doc_set) / len(result.retrieved_docs) + metrics.doc_retrieval_precision = len(retrieved_doc_set & expected_doc_set) / len(retrieved_doc_set) + return metrics + + def get_answer_metrics( + self, query: QueryEvalQuery, result: QueryEvalResult, metrics: QueryEvalMetrics + ) -> QueryEvalMetrics: + if not query.expected: + console.print("[yellow]:construction: No expected response found, skipping.. ") + elif not result.result: + console.print( + "[yellow] No query execution result available, skipping..", + style="italic", + ) + else: + if isinstance(query.expected, str) and isinstance(result.result, str): + sample = SingleTurnSample( + response=result.result, + reference=query.expected, + ) + scores = compute_text_metrics( + sample, + self.rouge_scorer, + self.bleu_scorer, + self.semantic_similarity_scorer, + ) + metrics.bleu_score = scores.get("bleu", None) + metrics.rouge_score = scores.get("rouge", None) + metrics.similarity_score = scores.get("semantic_similarity", None) + console.print("[green]✔ String metrics computed.") + elif isinstance(query.expected, str) and isinstance(result.result, DocSetSummary): + pass + else: + console.print("[red]:x: Unsupported expected/response type, skipping.. ") + return metrics + + def do_eval( + self, + query: QueryEvalQuery, + result: QueryEvalResult, + ) -> QueryEvalResult: + """Run query evaluation.""" + if self.config.config: + if self.config.config.dry_run: + console.print("[yellow]:point_right: Dry run: skipping eval") + return result + elif not self._check_tags_match(query): + console.print("[yellow]:point_right: Skipping query due to tag mismatch") + return result + + metrics = result.metrics or QueryEvalMetrics() - # Evaluate result - if not result.result: - console.print("[yellow] No query execution result available, skipping..", style="italic") + # Evalute query plans + metrics = self.get_query_plan_metrics(query, result, metrics) + + # Evaluate doc retrieval + metrics = self.get_retrieval_metrics(query, result, metrics) + + # Evaluate string metrics + metrics = self.get_answer_metrics(query, result, metrics) result.metrics = metrics @@ -434,6 +527,7 @@ def print_metrics_summary(self): / len(self.results_map) ) ) + # Evaluate doc retrieval correct_retrievals = sum( 1 for result in self.results_map.values() if result.metrics.doc_retrieval_recall == 1.0 @@ -446,9 +540,27 @@ def print_metrics_summary(self): if result.metrics.doc_retrieval_precision ) / expected_retrievals + if expected_retrievals + else 0 ) console.print(f"Successful doc retrievals: {correct_retrievals}/{expected_retrievals}") console.print(f"Average precision: {average_precision}") + + # String metrics + bleu_scores = [result.metrics.bleu_score for result in self.results_map.values() if result.metrics.bleu_score] + rouge_scores = [ + result.metrics.rouge_score for result in self.results_map.values() if result.metrics.rouge_score + ] + similarity_scores = [ + result.metrics.similarity_score for result in self.results_map.values() if result.metrics.similarity_score + ] + avg_bleu_score = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0 + avg_rouge_score = sum(rouge_scores) / len(rouge_scores) if rouge_scores else 0 + avg_similarity_score = sum(similarity_scores) / len(similarity_scores) if similarity_scores else 0 + console.print(f"Avg. BLEU score: {avg_bleu_score}") + console.print(f"Avg. ROUGE score: {avg_rouge_score}") + console.print(f"Avg. Semantic similarity score: {avg_similarity_score}") + # TODO: Query execution metrics console.print("Query result correctness: not implemented") @@ -488,4 +600,5 @@ def run(self): console.print(f"[red]Error: {tb}") result.error = f"Error: {tb}" self.write_results_file() + self.print_metrics_summary() console.print(":tada: Done!") diff --git a/apps/query-eval/queryeval/main.py b/apps/query-eval/queryeval/main.py index 5a06bce44..c8fd618ed 100644 --- a/apps/query-eval/queryeval/main.py +++ b/apps/query-eval/queryeval/main.py @@ -20,10 +20,19 @@ from sycamore.llms import MODELS from queryeval.driver import QueryEvalDriver +import nltk console = Console() +try: + nltk.data.find("tokenizers/punkt_tab") + console.print("The 'punkt_tab' tokenizer data is already downloaded.") +except LookupError: + console.print("The 'punkt_tab' tokenizer data is not found. Downloading now...") + nltk.download("punkt_tab") + console.print("The 'punkt_tab' tokenizer data has been downloaded.") + @click.group() @click.argument("config-file", type=click.Path(exists=True)) @@ -38,7 +47,10 @@ @click.option("--llm", help="LLM model name", type=click.Choice(list(MODELS.keys()))) @click.option("--tags", help="Filter queries by the given tags", multiple=True) @click.option( - "--raw-output", help="Output should be a raw DocSet, rather than natural language", is_flag=True, default=False + "--raw-output", + help="Output should be a raw DocSet, rather than natural language", + is_flag=True, + default=False, ) @click.pass_context def cli( diff --git a/apps/query-eval/queryeval/queryeval_types.py b/apps/query-eval/queryeval/queryeval_types.py index 9657dac88..23b41c724 100644 --- a/apps/query-eval/queryeval/queryeval_types.py +++ b/apps/query-eval/queryeval/queryeval_types.py @@ -64,7 +64,8 @@ class QueryEvalMetrics(BaseModel): query_time: Optional[float] = None # String answer metrics - correctness_score: Optional[float] = None + bleu_score: Optional[float] = None + rouge_score: Optional[float] = None similarity_score: Optional[float] = None