Commit
Changes the ollama call so it preserves chat history
fccoelho committed May 28, 2024
1 parent f489df6 commit bfe29b4
Showing 2 changed files with 58 additions and 19 deletions.
75 changes: 57 additions & 18 deletions base_agent/llminterface.py
@@ -3,11 +3,56 @@
import ollama
import dotenv
import os
from collections import deque
from typing import List, Dict

dotenv.load_dotenv()

class ChatHistory:
    """
    The ChatHistory class is a FIFO queue that keeps track of chat history.
    Attributes:
        queue (collections.deque): A deque object that stores the chat history.
    """

    def __init__(self, max_size=1000):
        """
        Initialize the ChatHistory class with a maximum size.
        Args:
            max_size (int): The maximum size of the queue. Defaults to 1000.
        """
        self.queue = deque(maxlen=max_size)

    def enqueue(self, item):
        """
        Add a message to the end of the queue.
        Args:
            item: The message to be added to the queue.
        """
        self.queue.append(item)

    def dequeue(self):
        """
        Remove and return a message from the front of the queue.
        Returns:
            The message removed from the front of the queue. If the queue is empty, returns None.
        """
        if len(self.queue) == 0:
            return None
        return self.queue.popleft()

    def get_all(self):
        """
        Return all items in the queue as a list without removing them from the queue.
        Returns:
            list: A list of all items in the queue.
        """
        return list(self.queue)
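Since the diff only shows the class definition, a brief usage sketch (not part of the commit) may help illustrate the FIFO behaviour and the eviction that deque(maxlen=max_size) gives the queue; the import path simply follows the file location shown above.

from base_agent.llminterface import ChatHistory

history = ChatHistory(max_size=2)  # keep at most two messages
history.enqueue({'role': 'user', 'content': 'Hi'})
history.enqueue({'role': 'assistant', 'content': 'Hello!'})
history.enqueue({'role': 'user', 'content': 'How are you?'})

# deque(maxlen=2) silently dropped the oldest message once the queue was full
print(history.get_all())   # the assistant reply and the latest user message
print(history.dequeue())   # pops the assistant reply from the front
print(history.dequeue())   # pops the remaining user message
print(history.dequeue())   # None: the queue is empty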
class LangModel:
    """
    Interface to interact with language models
@@ -23,6 +68,7 @@ def __init__(self, model: str = 'gpt-4o'):

        self.available_models: List = ollama.list()['models']
        self.model = "llama3"
        self.chat_history = ChatHistory()
        self._set_active_model(model)

    def _set_active_model(self, model: str):
@@ -38,9 +84,6 @@ def _set_active_model(self, model: str):
    def get_response(self, question: str, context: str = None) -> str:
        if 'gpt' in self.model:
            return self.get_gpt_response(question, context)
        elif self.model == 'gemma':
            self.model = 'gemma'
            return self.get_gemma_response(question, context)
        else:
            return self.get_ollama_response(question, context)

@@ -64,28 +107,24 @@ def get_gpt_response(self, question: str, context: str) -> str:
        )
        return response.choices[0].message.content

    def get_gemma_response(self, question: str, context: str) -> str:
        response = ollama.generate(
            model=self.model,
            system=context,
            prompt=question,
        )

        return response['response']
        # return '/n'.join([resp['response'] for resp in response ])

    def get_ollama_response(self, question: str, context: str) -> str:
        """
        Get response from any Ollama supported model
        :param question: question to ask
        :param context: context to provide
        :return: model's response
        """
        response = self.llm.generate(
        msg = {
            'role': 'user',
            'content': context + '\n\n' + question
        }
        self.chat_history.enqueue(msg)
        messages = self.chat_history.get_all()
        response = self.llm.chat(
            model=self.model,
            system=context,
            prompt=question,
            messages=messages,
            options={'temperature': 0}
        )
        self.chat_history.enqueue(response['message'])

        return response['response']
        # return '/n'.join([resp['response'] for resp in response ])
        return response['message']['content']
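Because get_ollama_response now appends both the user message and the model reply to chat_history and replays the whole history through the chat endpoint, a follow-up question can refer back to an earlier answer. A hypothetical multi-turn sketch (not part of the commit, and assuming a local Ollama server with the llama3 model pulled and whatever environment variables load_dotenv() expects):

from base_agent.llminterface import LangModel

lm = LangModel('llama3')

# First turn: the context + question message is enqueued, the full history is
# sent to the chat endpoint, and the reply is enqueued as well.
print(lm.get_response('What is the capital of France?', 'You are a concise assistant.'))

# Second turn: the previous exchange is still in lm.chat_history, so the model
# can resolve "its" to the city named in the first answer.
print(lm.get_response('What is its population?', 'You are a concise assistant.'))

# Four messages so far: two user turns and two assistant replies.
print(len(lm.chat_history.get_all()))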
2 changes: 1 addition & 1 deletion tests/test_llms.py
@@ -30,7 +30,7 @@ def test_init_with_unsupported_model(self, mock_ollama, mock_client, mock_openai

    @patch('base_agent.llminterface.LangModel.get_gpt_response')
    def test_get_response_with_gpt_model(self, mock_get_gpt_response):
        lm = LangModel('gpt-4-turbo')
        lm = LangModel('gpt-4o')
        lm.get_response('question', 'context')
        mock_get_gpt_response.assert_called_once_with('question', 'context')
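A hypothetical companion test (not part of this commit) could pin down the FIFO semantics the new ChatHistory class provides; it exercises only the class itself, so no mocking of ollama or OpenAI is needed.

import unittest

from base_agent.llminterface import ChatHistory


class TestChatHistory(unittest.TestCase):
    def test_enqueue_dequeue_order(self):
        history = ChatHistory()
        history.enqueue({'role': 'user', 'content': 'first'})
        history.enqueue({'role': 'assistant', 'content': 'second'})
        self.assertEqual(history.dequeue()['content'], 'first')
        self.assertEqual(history.dequeue()['content'], 'second')
        self.assertIsNone(history.dequeue())

    def test_max_size_evicts_oldest(self):
        history = ChatHistory(max_size=2)
        for content in ('a', 'b', 'c'):
            history.enqueue({'role': 'user', 'content': content})
        self.assertEqual([m['content'] for m in history.get_all()], ['b', 'c'])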
