feat(js): openai completions (#94)
axiomofjoy authored Jan 11, 2024
1 parent 54bc6d0 commit 1925aad
Showing 6 changed files with 222 additions and 23 deletions.
5 changes: 5 additions & 0 deletions js/.changeset/purple-cherries-judge.md
@@ -0,0 +1,5 @@
---
"@arizeai/openinference-instrumentation-openai": minor
---

add support for legacy completions api
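
For orientation, a minimal usage sketch of what this change enables, assuming the instrumentation is registered before the openai module is in use and a tracer provider is already configured; the registration boilerplate below is illustrative and not part of this commit:

import OpenAI from "openai";
import { registerInstrumentations } from "@opentelemetry/instrumentation";
import { OpenAIInstrumentation } from "@arizeai/openinference-instrumentation-openai";

// Patch the OpenAI module so client calls are traced
registerInstrumentations({
  instrumentations: [new OpenAIInstrumentation()],
});

const openai = new OpenAI();

// Legacy completions calls are now captured as "OpenAI Completions" spans
const completion = await openai.completions.create({
  model: "gpt-3.5-turbo-instruct",
  prompt: "Say this is a test",
});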
5 changes: 5 additions & 0 deletions js/.changeset/swift-grapes-joke.md
@@ -0,0 +1,5 @@
---
"@arizeai/openinference-semantic-conventions": minor
---

add llm.prompts semantic convention
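
A sketch of how the new constant could be consumed; the exact attribute shape for prompts is not pinned down by this change, so the array value used below is an assumption:

import { trace } from "@opentelemetry/api";
import { SemanticConventions } from "@arizeai/openinference-semantic-conventions";

const tracer = trace.getTracer("example");
const span = tracer.startSpan("OpenAI Completions");
// SemanticConventions.LLM_PROMPTS resolves to "llm.prompts"
span.setAttribute(SemanticConventions.LLM_PROMPTS, ["Say this is a test"]);
span.end();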
1 change: 1 addition & 0 deletions js/.eslintrc.js
@@ -29,5 +29,6 @@ module.exports = {
varsIgnorePattern: "^_",
},
], // ignore unused variables starting with underscore
eqeqeq: ["error", "always"],
},
};
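
The new eqeqeq rule requires strict equality throughout the package; a small illustration of what it flags (the corresponding == to === updates appear in instrumentation.ts below):

const input: string | string[] = "Say this is a test";
// Flagged by eqeqeq: `typeof input == "string"` uses loose equality
const isStringInput = typeof input === "string";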
175 changes: 154 additions & 21 deletions js/packages/openinference-instrumentation-openai/src/instrumentation.ts
@@ -26,8 +26,10 @@ import {
ChatCompletionChunk,
ChatCompletionCreateParamsBase,
} from "openai/resources/chat/completions";
import { CompletionCreateParamsBase } from "openai/resources/completions";
import { Stream } from "openai/streaming";
import {
Completion,
CreateEmbeddingResponse,
EmbeddingCreateParams,
} from "openai/resources";
@@ -65,20 +67,19 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
}
// eslint-disable-next-line @typescript-eslint/no-this-alias
const instrumentation: OpenAIInstrumentation = this;
type CompletionCreateType =

type ChatCompletionCreateType =
typeof module.OpenAI.Chat.Completions.prototype.create;

// Patch create chat completions
this._wrap(
module.OpenAI.Chat.Completions.prototype,
"create",
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(original: CompletionCreateType): any => {
(original: ChatCompletionCreateType): any => {
return function patchedCreate(
this: unknown,
...args: Parameters<
typeof module.OpenAI.Chat.Completions.prototype.create
>
...args: Parameters<ChatCompletionCreateType>
) {
const body = args[0];
const { messages: _messages, ...invocationParameters } = body;
@@ -100,7 +101,7 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
);
const execContext = trace.setSpan(context.active(), span);
const execPromise = safeExecuteInTheMiddle<
ReturnType<CompletionCreateType>
ReturnType<ChatCompletionCreateType>
>(
() => {
return context.with(execContext, () => {
@@ -127,7 +128,7 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
[SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.JSON,
// Override the model from the value sent by the server
[SemanticConventions.LLM_MODEL_NAME]: result.model,
...getLLMOutputMessagesAttributes(result),
...getChatCompletionLLMOutputMessagesAttributes(result),
...getUsageAttributes(result),
});
span.setStatus({ code: SpanStatusCode.OK });
@@ -148,6 +149,74 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
},
);

// Patch create completions
type CompletionsCreateType =
typeof module.OpenAI.Completions.prototype.create;

this._wrap(
module.OpenAI.Completions.prototype,
"create",
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(original: CompletionsCreateType): any => {
return function patchedCreate(
this: unknown,
...args: Parameters<CompletionsCreateType>
) {
const body = args[0];
const { prompt: _prompt, ...invocationParameters } = body;
const span = instrumentation.tracer.startSpan(`OpenAI Completions`, {
kind: SpanKind.INTERNAL,
attributes: {
[SemanticConventions.OPENINFERENCE_SPAN_KIND]:
OpenInferenceSpanKind.LLM,
[SemanticConventions.LLM_MODEL_NAME]: body.model,
[SemanticConventions.LLM_INVOCATION_PARAMETERS]:
JSON.stringify(invocationParameters),
...getCompletionInputValueAndMimeType(body),
},
});
const execContext = trace.setSpan(context.active(), span);
const execPromise = safeExecuteInTheMiddle<
ReturnType<CompletionsCreateType>
>(
() => {
return context.with(execContext, () => {
return original.apply(this, args);
});
},
(error) => {
// Push the error to the span
if (error) {
span.recordException(error);
span.setStatus({
code: SpanStatusCode.ERROR,
message: error.message,
});
span.end();
}
},
);
const wrappedPromise = execPromise.then((result) => {
if (isCompletionResponse(result)) {
// Record the results
span.setAttributes({
[SemanticConventions.OUTPUT_VALUE]: JSON.stringify(result),
[SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.JSON,
// Override the model from the value sent by the server
[SemanticConventions.LLM_MODEL_NAME]: result.model,
...getCompletionOutputValueAndMimeType(result),
...getUsageAttributes(result),
});
span.setStatus({ code: SpanStatusCode.OK });
span.end();
}
return result;
});
return context.bind(execContext, wrappedPromise);
};
},
);

// Patch embeddings
type EmbeddingsCreateType =
typeof module.OpenAI.Embeddings.prototype.create;
@@ -158,11 +227,11 @@ export class OpenAIInstrumentation extends InstrumentationBase<typeof openai> {
(original: EmbeddingsCreateType): any => {
return function patchedEmbeddingCreate(
this: unknown,
...args: Parameters<typeof module.OpenAI.Embeddings.prototype.create>
...args: Parameters<EmbeddingsCreateType>
) {
const body = args[0];
const { input } = body;
const isStringInput = typeof input == "string";
const isStringInput = typeof input === "string";
const span = instrumentation.tracer.startSpan(`OpenAI Embeddings`, {
kind: SpanKind.INTERNAL,
attributes: {
@@ -239,7 +308,27 @@ function isChatCompletionResponse(
}

/**
* Converts the body of the request to LLM input messages
* type-guard that checks if the response is a completion response
*/
function isCompletionResponse(
response: Stream<Completion> | Completion,
): response is Completion {
return "choices" in response;
}

/**
* type-guard that checks if completion prompt attribute is an array of strings
*/
function isPromptStringArray(
prompt: CompletionCreateParamsBase["prompt"],
): prompt is Array<string> {
return (
Array.isArray(prompt) && prompt.every((item) => typeof item === "string")
);
}

/**
* Converts the body of a chat completions request to LLM input messages
*/
function getLLMInputMessagesAttributes(
body: ChatCompletionCreateParamsBase,
@@ -257,9 +346,36 @@ function getLLMInputMessagesAttributes(
}

/**
* Get Usage attributes
* Converts the body of a completions request to input attributes
*/
function getCompletionInputValueAndMimeType(
body: CompletionCreateParamsBase,
): Attributes {
if (typeof body.prompt === "string") {
return {
[SemanticConventions.INPUT_VALUE]: body.prompt,
[SemanticConventions.INPUT_MIME_TYPE]: MimeType.TEXT,
};
} else if (isPromptStringArray(body.prompt)) {
const prompt = body.prompt[0]; // Only single prompts are currently supported
if (prompt === undefined) {
return {};
}
return {
[SemanticConventions.INPUT_VALUE]: prompt,
[SemanticConventions.INPUT_MIME_TYPE]: MimeType.TEXT,
};
}
// Other cases in which the prompt is a token or array of tokens are currently unsupported
return {};
}

/**
* Get usage attributes
*/
function getUsageAttributes(completion: ChatCompletion): Attributes {
function getUsageAttributes(
completion: ChatCompletion | Completion,
): Attributes {
if (completion.usage) {
return {
[SemanticConventions.LLM_TOKEN_COUNT_COMPLETION]:
@@ -274,41 +390,58 @@ function getUsageAttributes(completion: ChatCompletion): Attributes {
}

/**
* Converts the result to LLM output attributes
* Converts the chat completion result to LLM output attributes
*/
function getLLMOutputMessagesAttributes(
completion: ChatCompletion,
function getChatCompletionLLMOutputMessagesAttributes(
chatCompletion: ChatCompletion,
): Attributes {
// Right now support just the first choice
const choice = completion.choices[0];
const choice = chatCompletion.choices[0];
if (!choice) {
return {};
}
return [choice.message].reduce((acc, message, index) => {
const index_prefix = `${SemanticConventions.LLM_OUTPUT_MESSAGES}.${index}`;
acc[`${index_prefix}.${SemanticConventions.MESSAGE_CONTENT}`] = String(
const indexPrefix = `${SemanticConventions.LLM_OUTPUT_MESSAGES}.${index}`;
acc[`${indexPrefix}.${SemanticConventions.MESSAGE_CONTENT}`] = String(
message.content,
);
acc[`${index_prefix}.${SemanticConventions.MESSAGE_ROLE}`] = message.role;
acc[`${indexPrefix}.${SemanticConventions.MESSAGE_ROLE}`] = message.role;
return acc;
}, {} as Attributes);
}

/**
* Converts the completion result to output attributes
*/
function getCompletionOutputValueAndMimeType(
completion: Completion,
): Attributes {
// Right now support just the first choice
const choice = completion.choices[0];
if (!choice) {
return {};
}
return {
[SemanticConventions.OUTPUT_VALUE]: String(choice.text),
[SemanticConventions.OUTPUT_MIME_TYPE]: MimeType.TEXT,
};
}

/**
* Converts the embedding result payload to embedding attributes
*/
function getEmbeddingTextAttributes(
request: EmbeddingCreateParams,
): Attributes {
if (typeof request.input == "string") {
if (typeof request.input === "string") {
return {
[`${SemanticConventions.EMBEDDING_EMBEDDINGS}.0.${SemanticConventions.EMBEDDING_TEXT}`]:
request.input,
};
} else if (
Array.isArray(request.input) &&
request.input.length > 0 &&
typeof request.input[0] == "string"
typeof request.input[0] === "string"
) {
return request.input.reduce((acc, input, index) => {
const index_prefix = `${SemanticConventions.EMBEDDING_EMBEDDINGS}.${index}`;
@@ -102,6 +102,52 @@ describe("OpenAIInstrumentation", () => {
}
`);
});
it("creates a span for completions", async () => {
const response = {
id: "cmpl-8fZu1H3VijJUWev9asnxaYyQvJTC9",
object: "text_completion",
created: 1704920149,
model: "gpt-3.5-turbo-instruct",
choices: [
{
text: "This is a test",
index: 0,
logprobs: null,
finish_reason: "stop",
},
],
usage: { prompt_tokens: 12, completion_tokens: 5, total_tokens: 17 },
};
// Mock out the completions endpoint
jest.spyOn(openai, "post").mockImplementation(
// @ts-expect-error the response type is not correct - this is just for testing
async (): Promise<unknown> => {
return response;
},
);
await openai.completions.create({
prompt: "Say this is a test",
model: "gpt-3.5-turbo-instruct",
});
const spans = memoryExporter.getFinishedSpans();
expect(spans.length).toBe(1);
const span = spans[0];
expect(span.name).toBe("OpenAI Completions");
expect(span.attributes).toMatchInlineSnapshot(`
{
"input.mime_type": "text/plain",
"input.value": "Say this is a test",
"llm.invocation_parameters": "{"model":"gpt-3.5-turbo-instruct"}",
"llm.model_name": "gpt-3.5-turbo-instruct",
"llm.token_count.completion": 5,
"llm.token_count.prompt": 12,
"llm.token_count.total": 17,
"openinference.span.kind": "llm",
"output.mime_type": "text/plain",
"output.value": "This is a test",
}
`);
});
it("creates a span for embedding create", async () => {
const response = {
object: "list",
@@ -93,12 +93,20 @@ export const OUTPUT_MIME_TYPE =
`${SemanticAttributePrefixes.output}.mime_type` as const;
/**
* The messages sent to the LLM for completions
* Typically seen in openAI chat completions
* Typically seen in OpenAI chat completions
* @see https://beta.openai.com/docs/api-reference/completions/create
*/
export const LLM_INPUT_MESSAGES =
`${SemanticAttributePrefixes.llm}.${LLMAttributePostfixes.input_messages}` as const;

/**
* The prompts sent to the LLM for completions
* Typically seen in OpenAI legacy completions
* @see https://beta.openai.com/docs/api-reference/completions/create
*/
export const LLM_PROMPTS =
`${SemanticAttributePrefixes.llm}.${LLMAttributePostfixes.prompts}` as const;

/**
* The JSON representation of the parameters passed to the LLM
*/
@@ -107,7 +115,7 @@ export const LLM_INVOCATION_PARAMETERS =

/**
* The messages received from the LLM for completions
* Typically seen in openAI chat completions
* Typically seen in OpenAI chat completions
* @see https://platform.openai.com/docs/api-reference/chat/object#choices-message
*/
export const LLM_OUTPUT_MESSAGES =
@@ -224,6 +232,7 @@ export const SemanticConventions = {
LLM_INPUT_MESSAGES,
LLM_OUTPUT_MESSAGES,
LLM_MODEL_NAME,
LLM_PROMPTS,
LLM_INVOCATION_PARAMETERS,
LLM_TOKEN_COUNT_COMPLETION,
LLM_TOKEN_COUNT_PROMPT,
