Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow to override config from project #399

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 10 additions & 31 deletions gptme/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import re
import sys
import termios
from typing import cast
import urllib.parse
from collections.abc import Generator
from pathlib import Path
Expand All @@ -25,7 +24,7 @@
get_tools,
execute_msg,
ConfirmFunc,
set_tool_format,
get_tool_format,
)
from .tools.browser import read_url
from .util import console, path_with_tilde, print_bell
Expand All @@ -41,15 +40,11 @@
def chat(
prompt_msgs: list[Message],
initial_msgs: list[Message],
logdir: Path,
model: str | None,
stream: bool = True,
no_confirm: bool = False,
interactive: bool = True,
show_hidden: bool = False,
workspace: Path | None = None,
tool_allowlist: list[str] | None = None,
tool_format: ToolFormat | None = None,
) -> None:
"""
Run the chat loop.
Expand All @@ -60,8 +55,7 @@ def chat(

Callable from other modules.
"""
# init
init(model, interactive, tool_allowlist)
init(model, interactive)

if not get_model().supports_streaming and stream:
logger.info(
Expand All @@ -71,36 +65,21 @@ def chat(
)
stream = False

config = get_config()

logdir = config.get_log_dir()
console.log(f"Using logdir {path_with_tilde(logdir)}")
manager = LogManager.load(logdir, initial_msgs=initial_msgs, create=True)

config = get_config()
tool_format_with_default: ToolFormat = tool_format or cast(
ToolFormat, config.get_env("TOOL_FORMAT", "markdown")
)

# By defining the tool_format at the last moment we ensure we can use the
# configuration for subagent
set_tool_format(tool_format_with_default)
# log_workspace = logdir / "workspace"
workspace = config.get_workspace_dir()

# change to workspace directory
# use if exists, create if @log, or use given path
# TODO: move this into LogManager? then just os.chdir(manager.workspace)
log_workspace = logdir / "workspace"
if log_workspace.exists():
assert not workspace or (
workspace == log_workspace
), f"Workspace already exists in {log_workspace}, wont override."
workspace = log_workspace.resolve()
else:
if not workspace:
workspace = Path.cwd()
log_workspace.symlink_to(workspace, target_is_directory=True)
assert workspace.exists(), f"Workspace path {workspace} does not exist"
console.log(f"Using workspace at {path_with_tilde(workspace)}")
os.chdir(workspace)

workspace_prompt = get_workspace_prompt(workspace)

# FIXME: this is hacky
# NOTE: needs to run after the workspace is set
# check if message is already in log, such as upon resume
Expand Down Expand Up @@ -142,7 +121,7 @@ def confirm_func(msg) -> bool:
manager.log,
stream,
confirm_func,
tool_format=tool_format_with_default,
tool_format=get_tool_format(),
workspace=workspace,
)
)
Expand Down Expand Up @@ -194,7 +173,7 @@ def confirm_func(msg) -> bool:
manager.log,
stream,
confirm_func,
tool_format=tool_format_with_default,
tool_format=get_tool_format(),
workspace=workspace,
): # pragma: no cover
manager.append(msg)
Expand Down
84 changes: 31 additions & 53 deletions gptme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
from datetime import datetime
from itertools import islice
from pathlib import Path
from typing import Literal

import click
from pick import pick


from .chat import chat
from .config import get_config
from .commands import _gen_help
from .constants import MULTIPROMPT_SEPARATOR
from .dirs import get_logs_dir
Expand Down Expand Up @@ -175,28 +173,6 @@ def main(
if no_confirm:
logger.warning("Skipping all confirmation prompts.")

if tool_allowlist:
# split comma-separated values
tool_allowlist = [tool for tools in tool_allowlist for tool in tools.split(",")]

config = get_config()

selected_tool_format: ToolFormat = (
tool_format or config.get_env("TOOL_FORMAT") or "markdown" # type: ignore
)

# early init tools to generate system prompt
init_tools(frozenset(tool_allowlist) if tool_allowlist else None)

# get initial system prompt
initial_msgs = [
get_prompt(
prompt_system,
interactive=interactive,
tool_format=selected_tool_format,
)
]

# if stdin is not a tty, we might be getting piped input, which we should include in the prompt
was_piped = False
piped_input = None
Expand Down Expand Up @@ -247,7 +223,7 @@ def inject_stdin(prompt_msgs, piped_input: str | None) -> list[Message]:
return prompt_msgs

if resume:
logdir = get_logdir_resume()
name = get_name_resume()
prompt_msgs = inject_stdin(prompt_msgs, piped_input)
# don't run pick in tests/non-interactive mode, or if the user specifies a name
elif (
Expand All @@ -257,17 +233,33 @@ def inject_stdin(prompt_msgs, piped_input: str | None) -> list[Message]:
and not was_piped
and sys.stdin.isatty()
):
logdir = pick_log()
name = pick_existing_session()
else:
logdir = get_logdir(name)
prompt_msgs = inject_stdin(prompt_msgs, piped_input)

if workspace == "@log":
workspace_path: Path | None = logdir / "workspace"
assert workspace_path # mypy not smart enough to see its not None
workspace_path.mkdir(parents=True, exist_ok=True)
else:
workspace_path = Path(workspace) if workspace else None
os.environ["SESSION_NAME"] = name

if tool_format is not None:
os.environ["TOOL_FORMAT"] = tool_format

if tool_allowlist:
# split comma-separated values
tool_allowlist = [tool for tools in tool_allowlist for tool in tools.split(",")]
os.environ["TOOL_ALLOWLIST"] = ",".join(tool_allowlist)

if workspace:
os.environ["WORKSPACE"] = workspace

# Early init tools to generate system prompt
init_tools()

# get initial system prompt
initial_msgs = [
get_prompt(
prompt_system,
interactive=interactive,
)
]

# register a handler for Ctrl-C
set_interruptible() # prepare, user should be able to Ctrl+C until user prompt ready
Expand All @@ -277,15 +269,11 @@ def inject_stdin(prompt_msgs, piped_input: str | None) -> list[Message]:
chat(
prompt_msgs,
initial_msgs,
logdir,
model,
stream,
no_confirm,
interactive,
show_hidden,
workspace_path,
tool_allowlist,
selected_tool_format,
)
except RuntimeError as e:
logger.error(e)
Expand Down Expand Up @@ -323,7 +311,7 @@ def get_name(name: str) -> str:
return name


def pick_log(limit=20) -> Path: # pragma: no cover
def pick_existing_session(limit=20) -> str: # pragma: no cover
# let user select between starting a new conversation and loading a previous one
# using the library
title = "New conversation or load previous? "
Expand Down Expand Up @@ -361,26 +349,16 @@ def pick_log(limit=20) -> Path: # pragma: no cover
index: int
_, index = pick(options, title) # type: ignore
if index == 0:
return get_logdir("random")
return get_name("random")
elif index == len(options) - 1:
return pick_log(limit + 100)
return pick_existing_session(limit + 100)
else:
return get_logdir(convs[index - 1].name)


def get_logdir(logdir: Path | str | Literal["random"]) -> Path:
if logdir == "random":
logdir = get_logs_dir() / get_name("random")
elif isinstance(logdir, str):
logdir = get_logs_dir() / logdir

logdir.mkdir(parents=True, exist_ok=True)
return logdir
return convs[index - 1].name


def get_logdir_resume() -> Path:
def get_name_resume() -> str:
if conv := next(get_user_conversations(), None):
return Path(conv.path).parent
return conv.name
else:
raise ValueError("No previous conversations to resume")

Expand Down
Loading
Loading