Skip to content

Commit

Permalink
fix: removed dependency on gitpython
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Nov 24, 2024
1 parent 1748e4d commit dd2a128
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 77 deletions.
3 changes: 2 additions & 1 deletion gptme/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ def prompt_timeinfo() -> Generator[Message, None, None]:
def get_workspace_prompt(workspace: Path) -> str:
# NOTE: needs to run after the workspace is initialized (i.e. initial prompt is constructed)
# TODO: update this prompt if the files change
# TODO: include `git status/diff/log` summary, and keep it up-to-date
# TODO: include workspace structure from gptme.util.cli.prompts_workspace
# TODO: include git summary from gptme.util.cli.prompts_git
if project := get_project_config(workspace):
files = []
for file in project.files:
Expand Down
170 changes: 94 additions & 76 deletions gptme/util/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import glob
import logging
import os
import subprocess
import sys
from datetime import datetime, timedelta
from pathlib import Path

import click
import git
from rich import print as rich_print
from rich.filesize import decimal
from rich.markup import escape
Expand All @@ -23,6 +23,42 @@
from ..tools.chats import list_chats
from . import console


def run_git(cmd: list[str], check: bool = True, timeout: int = 10) -> tuple[str, bool]:
"""Run a git command and return its output and success status."""
try:
env = os.environ.copy()
env.update(
{
"PAGER": "cat",
"GIT_PAGER": "cat",
"GIT_TERMINAL_PROMPT": "0", # Disable git's terminal prompts
}
)
logger.debug(f"Running git command: {cmd}")
result = subprocess.run(
["git"] + cmd,
capture_output=True,
text=True,
check=check,
env=env,
timeout=timeout,
)
if result.stderr:
logger.debug(f"Git stderr: {result.stderr}")
if result.stdout:
logger.debug(f"Git stdout: {result.stdout}")
return result.stdout.strip(), True
except subprocess.TimeoutExpired:
logger.error(f"Git command timed out after {timeout}s: git {' '.join(cmd)}")
return "", False
except subprocess.CalledProcessError as e:
if check:
logger.error(f"Git command failed: {e}")
logger.error(f"Git stderr: {e.stderr}")
return e.stderr.strip(), False


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -172,105 +208,87 @@ def format_section(title: str, items: list[str]) -> list[str]:
result.append(f"{item_prefix}{item}")
return result

def get_diff_output(repo: git.Repo, staged: bool = False) -> str:
"""Get formatted diff output."""
if staged:
diff_obj = repo.index.diff(repo.head.commit)
diff_str = repo.git.diff("--cached")
else:
diff_obj = repo.index.diff(None)
diff_str = repo.git.diff()

if not diff_str.strip():
return ""

[d.a_path for d in diff_obj]
title = "Staged changes" if staged else "Unstaged changes"

output = [f"### {title}", "\n```diff"]

output.extend([diff_str, "```"])
return "\n".join(output)

try:
repo = git.Repo(".", search_parent_directories=True)
except git.InvalidGitRepositoryError:
# Check if we're in a git repo
logger.debug("Checking if in git repo...")
output, success = run_git(["rev-parse", "--git-dir"])
logger.debug(f"Git repo check result: {success=}, {output=}")
if not success:
logger.error("Not a git repository")
return

sections = []

# Basic repo info
if repo.remotes:
origin = repo.remotes.origin
sections.extend([f"Repository: {origin.url}"])

try:
branch_name = repo.active_branch.name
sections.append(f"Current branch: {branch_name}")
except TypeError:
# Handle detached HEAD state
sections.append(f"HEAD is detached at {repo.head.commit.hexsha[:7]}")
branch_name = None
remote_url, success = run_git(["config", "--get", "remote.origin.url"])
if success and remote_url:
sections.extend([f"Repository: {remote_url}"])

# Get current branch
branch_name, success = run_git(["rev-parse", "--abbrev-ref", "HEAD"])
if success:
if branch_name == "HEAD":
# We're in detached HEAD state
commit_hash, _ = run_git(["rev-parse", "--short", "HEAD"])
sections.append(f"HEAD is detached at {commit_hash}")
branch_name = ""
else:
sections.append(f"Current branch: {branch_name}")

# Recent commits
commits = list(repo.iter_commits(branch or branch_name, max_count=5))
if commits:
commit_items = []
for commit in commits:
date = commit.committed_datetime.strftime("%Y-%m-%d %H:%M")
commit_items.append(f"{commit.hexsha[:7]} ({date}) {commit.summary}")
log_format = "--pretty=format:%h (%ad) %s"
commits_output, success = run_git(
[
"log",
log_format,
"--date=format:%Y-%m-%d %H:%M",
"-n",
"5",
branch or branch_name or "HEAD",
]
)
if success and commits_output:
commit_items = commits_output.split("\n")
sections.extend(format_section("Recent commits", commit_items))

# Changed files
try:
if since:
diff = repo.git.diff(since, name_only=True).split("\n")
else:
diff = [item.a_path for item in repo.index.diff(None)]
if since:
changed_files, success = run_git(["diff", "--name-only", since])
else:
changed_files, success = run_git(["diff", "--name-only"])

if diff and diff[0]:
shown_files = diff[:max_files]
if success and changed_files:
files = changed_files.split("\n")
if files and files[0]:
shown_files = files[:max_files]
sections.extend(format_section("Changed files", shown_files))
if len(diff) > max_files:
sections.append(f"... and {len(diff) - max_files} more changed files")
except git.GitCommandError as e:
logger.error(f"Error getting changed files: {e}")
if len(files) > max_files:
sections.append(f"... and {len(files) - max_files} more changed files")

# Untracked files
try:
untracked = repo.untracked_files
if untracked:
shown_files = untracked[:max_files]
untracked_files, success = run_git(["ls-files", "--others", "--exclude-standard"])
if success and untracked_files:
files = untracked_files.split("\n")
if files and files[0]:
shown_files = files[:max_files]
sections.extend(format_section("Untracked files", shown_files))
if len(untracked) > max_files:
if len(files) > max_files:
sections.append(
f"... and {len(untracked) - max_files} more untracked files"
f"... and {len(files) - max_files} more untracked files"
)
except Exception as e:
logger.error(f"Error getting untracked files: {e}")

# Add stats
try:
stats = repo.git.shortlog("--summary", "--numbered", "--email").split("\n")
if stats and stats[0]:
sections.extend(["\nContributors:"] + [f"• {s.strip()}" for s in stats[:3]])
if len(stats) > 3:
sections.append(f"... and {len(stats) - 3} more contributors")
except git.GitCommandError:
pass # Skip if stats unavailable

# Add diffs if requested
if show_diff:
# Add staged changes
staged_diff = get_diff_output(repo, staged=True)
if staged_diff:
sections.extend(["", staged_diff])
staged_diff, success = run_git(["diff", "--cached"])
if success and staged_diff:
sections.extend(["\n### Staged changes", "\n```diff", staged_diff, "```"])

# Add unstaged changes
unstaged_diff = get_diff_output(repo, staged=False)
if unstaged_diff:
sections.extend(["", unstaged_diff])
unstaged_diff, success = run_git(["diff"])
if success and unstaged_diff:
sections.extend(
["\n### Unstaged changes", "\n```diff", unstaged_diff, "```"]
)

print("\n".join(sections))

Expand Down

0 comments on commit dd2a128

Please sign in to comment.