Skip to content

Commit

Permalink
Refactor Nootropic wrapper classes and add cache disabling
Browse files Browse the repository at this point in the history
- Add disable_cache option to BaseWrapper and Nootropic
- Implement shallow copy of messages to avoid modifying originals
- Refactor system message handling for consistency across wrappers
- Add __getattr__ methods to raise NotImplementedError for unimplemented methods
- Update Nootropic class to use a generic _create_wrapper method
- Remove fallback to original client in Nootropic.__getattr__
- Update version to 1.2410.1
  • Loading branch information
brunneis committed Oct 8, 2024
1 parent caaa393 commit 6e9d120
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 46 deletions.
2 changes: 1 addition & 1 deletion nootropic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

from .nootropic import *

__version__ = '1.2410.0'
__version__ = '1.2410.1'
154 changes: 109 additions & 45 deletions nootropic/nootropic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Any, Dict, List, Optional, Callable
from typing import Any, Dict, List, Optional
import uuid
import copy


class BaseWrapper:
Expand All @@ -11,29 +13,71 @@ def __init__(
prefix: Optional[str] = None,
postfix: Optional[str] = None,
system: Optional[str] = None,
disable_cache: bool = False,
):
self._client_attr = client_attr
self._prefix = prefix
self._postfix = postfix
self._system = system
self._disable_cache = disable_cache

def _modify_messages(self, messages: List[Dict[str, str]]) -> None:
for message in messages:
def _shallow_copy_messages(
self,
messages: List[Dict[str, str]],
) -> List[Dict[str, str]]:
copied_messages = list(messages)

for i, message in enumerate(copied_messages):
copied_messages[i] = dict(message)

return copied_messages

def _modify_messages(
self,
messages: List[Dict[str, str]],
) -> List[Dict[str, str]]:
modified_messages = self._shallow_copy_messages(messages)
for message in modified_messages:
if message['role'] == 'user':
message['content'] = self._modify_content(message['content'])

def _handle_system_prompt(self, kwargs: Dict[str, Any]) -> None:
if self._system:
kwargs['system'] = self._system
return modified_messages

def _modify_content(self, content: str) -> str:
if self._disable_cache:
prompt_uid = str(uuid.uuid4())
content = '{0} {1}'.format(
'<ignore>{0}</ignore>\n\n'.format(prompt_uid),
content,
)
if self._prefix:
content = '{0} {1}'.format(self._prefix, content)
if self._postfix:
content = '{0} {1}'.format(content, self._postfix)

return content

def _add_system_message(
self,
messages: List[Dict[str, str]],
) -> List[Dict[str, str]]:
modified_messages = copy.deepcopy(messages)
system_message = {
'role': 'system',
'content': self._system,
}
if modified_messages and modified_messages[0]['role'] == 'system':
modified_messages[0] = system_message
else:
modified_messages.insert(0, system_message)
return modified_messages

def __getattr__(self, name: str) -> Any:
raise NotImplementedError(
'Method {0} is not implemented in BaseWrapper'.format(name),
)


# llmdk
class GenerateWrapper(BaseWrapper):
def __call__(
self,
Expand All @@ -43,8 +87,7 @@ def __call__(
**kwargs: Any,
) -> str:
if messages:
self._modify_messages(messages)
kwargs['messages'] = messages
kwargs['messages'] = self._modify_messages(messages)
elif prompt:
kwargs['prompt'] = self._modify_content(prompt)

Expand All @@ -53,14 +96,21 @@ def __call__(

return self._client_attr(**kwargs)

def __getattr__(self, name: str) -> Any:
raise NotImplementedError(
'Method {0} is not implemented in GenerateWrapper'.format(name),
)


# OpenAI, HuggingFace, Groq, Ollama
class ChatWrapper(BaseWrapper):
def __call__(self, **kwargs: Any) -> Any:
if 'messages' in kwargs:
self._modify_messages(kwargs['messages'])
self._handle_system_prompt(kwargs)

return self._client_attr.chat(**kwargs)
kwargs['messages'] = self._modify_messages(kwargs['messages'])
if self._system:
messages = kwargs.get('messages', [])
kwargs['messages'] = self._add_system_message(messages)
return self._client_attr(**kwargs)

@property
def completions(self) -> 'CompletionsWrapper':
Expand All @@ -69,39 +119,53 @@ def completions(self) -> 'CompletionsWrapper':
self._prefix,
self._postfix,
self._system,
self._disable_cache,
)

def __getattr__(self, name: str) -> Any:
raise NotImplementedError(
'Method {0} is not implemented in ChatWrapper'.format(name),
)


# OpenAI, HuggingFace, Groq
class CompletionsWrapper(BaseWrapper):
def create(self, **kwargs: Any) -> Any:
if 'messages' in kwargs:
self._modify_messages(kwargs['messages'])
kwargs['messages'] = self._modify_messages(kwargs['messages'])
if self._system:
kwargs['messages'] = [
{
'role': 'system',
'content': self._system,
},
] + kwargs.get('messages', [])
messages = kwargs.get('messages', [])
kwargs['messages'] = self._add_system_message(messages)

return self._client_attr.create(**kwargs)

def __getattr__(self, name: str) -> Any:
raise NotImplementedError(
'Method {0} is not implemented in CompletionsWrapper'.format(name),
)


# Anthropic
class MessagesWrapper(BaseWrapper):
def create(self, **kwargs: Any) -> Any:
def _prepare_kwargs(self, kwargs: Dict[str, Any]) -> None:
if 'messages' in kwargs:
self._modify_messages(kwargs['messages'])
self._handle_system_prompt(kwargs)
kwargs['messages'] = self._modify_messages(kwargs['messages'])
if self._system:
kwargs['system'] = self._system

def create(self, **kwargs: Any) -> Any:
self._prepare_kwargs(kwargs)
return self._client_attr.create(**kwargs)

def stream(self, **kwargs: Any) -> Any:
if 'messages' in kwargs:
self._modify_messages(kwargs['messages'])
self._handle_system_prompt(kwargs)

self._prepare_kwargs(kwargs)
return self._client_attr.stream(**kwargs)

def __getattr__(self, name: str) -> Any:
raise NotImplementedError(
'Method {0} is not implemented in MessagesWrapper'.format(name),
)


class Nootropic:
def __init__(
Expand All @@ -110,39 +174,39 @@ def __init__(
prefix: Optional[str] = None,
postfix: Optional[str] = None,
system: Optional[str] = None,
disable_cache: bool = False,
):
self._client = client
self._prefix = prefix
self._postfix = postfix
self._system = system
self._disable_cache = disable_cache

# Fallback to the original client
def __getattr__(self, name: str) -> Any:
return getattr(self._client, name)
raise NotImplementedError(
'Method {0} is not implemented in Nootropic'.format(name),
)

@property
def generate(self) -> Callable:
return GenerateWrapper(
self._client.generate,
def _create_wrapper(self, attr_name: str, wrapper_class: type) -> Any:
return wrapper_class(
getattr(self._client, attr_name, self._client),
self._prefix,
self._postfix,
self._system,
self._disable_cache,
)

# LLMDK
@property
def generate(self) -> GenerateWrapper:
return self._create_wrapper('generate', GenerateWrapper)

# OpenAI, HuggingFace, Groq, Ollama
@property
def chat(self) -> ChatWrapper:
return ChatWrapper(
self._client.chat,
self._prefix,
self._postfix,
self._system,
)
return self._create_wrapper('chat', ChatWrapper)

# Anthropic
@property
def messages(self) -> MessagesWrapper:
return MessagesWrapper(
self._client.messages,
self._prefix,
self._postfix,
self._system,
)
return self._create_wrapper('messages', MessagesWrapper)

0 comments on commit 6e9d120

Please sign in to comment.