Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions py/src/braintrust/integrations/cohere/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from braintrust.integrations.base import BaseIntegration

from .patchers import (
CohereAudioTranscriptionsPatcher,
CohereChatPatcher,
CohereChatStreamPatcher,
CohereEmbedPatcher,
Expand All @@ -21,4 +22,5 @@ class CohereIntegration(BaseIntegration):
CohereChatStreamPatcher,
CohereEmbedPatcher,
CohereRerankPatcher,
CohereAudioTranscriptionsPatcher,
)
32 changes: 32 additions & 0 deletions py/src/braintrust/integrations/cohere/patchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from braintrust.integrations.base import CompositeFunctionWrapperPatcher, FunctionWrapperPatcher

from .tracing import (
_async_audio_transcription_wrapper,
_async_chat_stream_wrapper,
_async_chat_wrapper,
_async_embed_wrapper,
_async_rerank_wrapper,
_audio_transcription_wrapper,
_chat_stream_wrapper,
_chat_wrapper,
_embed_wrapper,
Expand Down Expand Up @@ -158,6 +160,28 @@ class AsyncV2RerankPatcher(FunctionWrapperPatcher):
wrapper = _async_rerank_wrapper


# ---------------------------------------------------------------------------
# Audio — transcriptions (added in cohere==6.1.0)
#
# The transcription surface lives in its own module and is exposed on v1
# clients (``client.audio.transcriptions.create``) but not on v2 clients.
# ---------------------------------------------------------------------------


class TranscriptionsCreatePatcher(FunctionWrapperPatcher):
    # Patcher declaration for the synchronous Cohere audio transcription call
    # (``client.audio.transcriptions.create`` on v1 clients).
    # NOTE(review): how ``FunctionWrapperPatcher`` consumes these attributes is
    # defined in ``braintrust.integrations.base`` and is not visible here.

    # Identifier of this individual patcher.
    name = "cohere.audio.transcriptions.create"
    # Dotted path of the Cohere SDK module that holds the class to patch.
    target_module = "cohere.audio.transcriptions.client"
    # ``Class.method`` path, inside ``target_module``, of the patched callable.
    target_path = "TranscriptionsClient.create"
    # Sync tracing wrapper imported from ``.tracing``.
    wrapper = _audio_transcription_wrapper


class AsyncTranscriptionsCreatePatcher(FunctionWrapperPatcher):
    # Patcher declaration for the asynchronous Cohere audio transcription call;
    # the async counterpart of ``TranscriptionsCreatePatcher``.
    # NOTE(review): how ``FunctionWrapperPatcher`` consumes these attributes is
    # defined in ``braintrust.integrations.base`` and is not visible here.

    # Identifier of this individual patcher (distinct from the sync one).
    name = "cohere.audio.transcriptions.async.create"
    # Same SDK module as the sync patcher; both client classes live there.
    target_module = "cohere.audio.transcriptions.client"
    # ``Class.method`` path, inside ``target_module``, of the patched callable.
    target_path = "AsyncTranscriptionsClient.create"
    # Async tracing wrapper imported from ``.tracing``.
    wrapper = _async_audio_transcription_wrapper


# ---------------------------------------------------------------------------
# Composite patchers — group all sync/async variants by execution surface.
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -201,3 +225,11 @@ class CohereRerankPatcher(CompositeFunctionWrapperPatcher):
V2RerankPatcher,
AsyncV2RerankPatcher,
)


class CohereAudioTranscriptionsPatcher(CompositeFunctionWrapperPatcher):
    # Composite patcher grouping the sync and async audio transcription
    # patchers, following the same pattern as the other ``Cohere*Patcher``
    # composites in this module.

    # Identifier under which the integration lists this composite patcher.
    name = "cohere.audio.transcriptions.all"
    # Individual patchers bundled by this composite: sync first, then async.
    sub_patchers = (
        TranscriptionsCreatePatcher,
        AsyncTranscriptionsCreatePatcher,
    )
161 changes: 158 additions & 3 deletions py/src/braintrust/integrations/cohere/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
import inspect
import os
import time
from pathlib import Path

import pytest
from braintrust import logger
from braintrust import Attachment, logger
from braintrust.integrations.cohere import CohereIntegration, wrap_cohere
from braintrust.integrations.cohere.patchers import (
AsyncChatPatcher,
AsyncChatStreamPatcher,
AsyncEmbedPatcher,
AsyncRerankPatcher,
AsyncTranscriptionsCreatePatcher,
AsyncV2ChatPatcher,
AsyncV2ChatStreamPatcher,
AsyncV2EmbedPatcher,
Expand All @@ -21,6 +23,7 @@
ChatStreamPatcher,
EmbedPatcher,
RerankPatcher,
TranscriptionsCreatePatcher,
V2ChatPatcher,
V2ChatStreamPatcher,
V2EmbedPatcher,
Expand All @@ -39,6 +42,8 @@
CHAT_MODEL = "command-a-03-2025"
EMBED_MODEL = "embed-english-v3.0"
RERANK_MODEL = "rerank-english-v3.0"
TRANSCRIBE_MODEL = "cohere-transcribe-03-2026"
TEST_AUDIO_FILE = Path(__file__).resolve().parents[2] / "fixtures" / "test_audio.wav"
COHERE_API_KEY = os.getenv("CO_API_KEY") or os.getenv("COHERE_API_KEY") or "co-test-dummy-api-key-for-vcr-tests"


Expand Down Expand Up @@ -104,6 +109,15 @@ def clean_cohere_methods():
if hasattr(cls, attr):
targets.append((cls, attr))

try:
from cohere.audio.transcriptions.client import AsyncTranscriptionsClient, TranscriptionsClient
except ImportError:
pass
else:
for cls in (TranscriptionsClient, AsyncTranscriptionsClient):
if hasattr(cls, "create"):
targets.append((cls, "create"))

originals = [(cls, attr, inspect.getattr_static(cls, attr)) for cls, attr in targets]
# Also capture patch markers so we can clear them.
marker_attrs = set()
Expand All @@ -124,6 +138,8 @@ def clean_cohere_methods():
AsyncV2ChatStreamPatcher,
AsyncV2EmbedPatcher,
AsyncV2RerankPatcher,
TranscriptionsCreatePatcher,
AsyncTranscriptionsCreatePatcher,
):
marker_attrs.add(patcher.patch_marker_attr())

Expand All @@ -132,14 +148,22 @@ def clean_cohere_methods():
finally:
for cls, attr, original in originals:
setattr(cls, attr, original)
# Clear any patch markers that setup() may have added.
for cls, _, _ in originals:
# Clear any patch markers that setup() may have added. wrapt can forward
# setattr from a FunctionWrapper onto the wrapped function, so the
# restored original may still carry the marker; clear it from both
# class and restored function.
for cls, _, original in originals:
for marker in marker_attrs:
if hasattr(cls, marker):
try:
delattr(cls, marker)
except AttributeError:
pass
if hasattr(original, marker):
try:
delattr(original, marker)
except AttributeError:
pass


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -171,9 +195,32 @@ def test_cohere_integration_available_patchers_ids():
"cohere.chat_stream.all",
"cohere.embed.all",
"cohere.rerank.all",
"cohere.audio.transcriptions.all",
}


def test_audio_transcriptions_patchers_target_sdk_surface():
    """Check that both transcription patchers aim at real Cohere SDK methods.

    Regression guard for
    https://github.com/braintrustdata/braintrust-sdk-python/issues/327: the
    sync and async patchers must reference ``TranscriptionsClient.create`` and
    ``AsyncTranscriptionsClient.create`` in ``cohere.audio.transcriptions.client``
    (available in cohere>=6.1.0), and those methods must exist on the SDK
    classes. The test is skipped when the installed SDK lacks that module.
    """
    try:
        import cohere.audio.transcriptions.client as sdk_module
    except ImportError:
        pytest.skip("cohere SDK does not expose audio.transcriptions")

    # (patcher, SDK class name) pairs; sync first, then async.
    expectations = (
        (TranscriptionsCreatePatcher, "TranscriptionsClient"),
        (AsyncTranscriptionsCreatePatcher, "AsyncTranscriptionsClient"),
    )

    # The patcher declarations must name the expected module and method path.
    for patcher, sdk_class_name in expectations:
        assert patcher.target_module == "cohere.audio.transcriptions.client"
        assert patcher.target_path == f"{sdk_class_name}.create"

    # The SDK classes must actually provide the method being patched.
    for _, sdk_class_name in expectations:
        assert hasattr(getattr(sdk_module, sdk_class_name), "create")


# ---------------------------------------------------------------------------
# VCR-backed integration tests
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -497,6 +544,114 @@ def test_cohere_integration_setup_patches_v2_chat(memory_logger, clean_cohere_me
assert spans[0]["metadata"]["model"] == CHAT_MODEL


@pytest.mark.vcr
def test_wrap_cohere_audio_transcription_sync(memory_logger):
    """A wrapped v1 client logs exactly one span for a sync transcription.

    Verifies that the provider response is returned unchanged and that the
    logged span carries the request metadata, the audio file as an
    ``Attachment`` input, the transcribed text as output, and timing metrics.
    """
    pytest.importorskip("cohere.audio.transcriptions.client")
    assert not memory_logger.pop()

    wrapped_client = wrap_cohere(_v1_client())

    called_at = time.time()
    with open(TEST_AUDIO_FILE, "rb") as audio:
        transcription = wrapped_client.audio.transcriptions.create(
            model=TRANSCRIBE_MODEL,
            language="en",
            file=(TEST_AUDIO_FILE.name, audio, "audio/wav"),
            temperature=0.0,
        )
    returned_at = time.time()

    # Provider behavior preserved: the response still exposes non-empty text.
    assert isinstance(transcription.text, str)
    assert transcription.text

    logged = memory_logger.pop()
    assert len(logged) == 1
    record = logged[0]

    attributes = record["span_attributes"]
    assert attributes["name"] == "cohere.audio.transcriptions.create"
    assert attributes["type"] == "llm"

    metadata = record["metadata"]
    expected_metadata = {
        "provider": "cohere",
        "model": TRANSCRIBE_MODEL,
        "language": "en",
        "temperature": 0.0,
    }
    for key, expected in expected_metadata.items():
        assert metadata[key] == expected

    # The uploaded audio is recorded on the span input as an Attachment.
    attachment = record["input"]["file"]
    assert isinstance(attachment, Attachment)
    reference = attachment.reference
    assert reference["filename"] == TEST_AUDIO_FILE.name
    assert reference["content_type"] == "audio/wav"

    # The span output is the transcribed text itself.
    assert record["output"] == transcription.text

    # Cohere's transcription response does not expose token counts, so only
    # the timing metrics that are always recorded are checked.
    timing = record["metrics"]
    assert called_at <= timing["start"] <= timing["end"] <= returned_at
    assert timing["duration"] >= 0


@pytest.mark.vcr
def test_wrap_cohere_audio_transcription_async(memory_logger):
    """A wrapped async v1 client logs exactly one span for a transcription.

    Runs the async ``create`` call to completion with ``asyncio.run`` and
    checks the span name, request metadata, ``Attachment`` input and output.
    """
    pytest.importorskip("cohere.audio.transcriptions.client")
    assert not memory_logger.pop()

    async def _transcribe():
        # Build the client inside the coroutine so it is created while the
        # event loop started by ``asyncio.run`` is active.
        wrapped_client = wrap_cohere(_v1_async_client())
        with open(TEST_AUDIO_FILE, "rb") as audio:
            return await wrapped_client.audio.transcriptions.create(
                model=TRANSCRIBE_MODEL,
                language="en",
                file=(TEST_AUDIO_FILE.name, audio, "audio/wav"),
            )

    transcription = asyncio.run(_transcribe())
    assert isinstance(transcription.text, str)

    logged = memory_logger.pop()
    assert len(logged) == 1
    record = logged[0]

    assert record["span_attributes"]["name"] == "cohere.audio.transcriptions.create"

    metadata = record["metadata"]
    expected_metadata = {
        "provider": "cohere",
        "model": TRANSCRIBE_MODEL,
        "language": "en",
    }
    for key, expected in expected_metadata.items():
        assert metadata[key] == expected

    # The uploaded audio is recorded on the span input as an Attachment.
    attachment = record["input"]["file"]
    assert isinstance(attachment, Attachment)
    assert attachment.reference["filename"] == TEST_AUDIO_FILE.name

    # The span output is the transcribed text itself.
    assert record["output"] == transcription.text


@pytest.mark.vcr
def test_cohere_integration_setup_patches_audio_transcriptions(memory_logger, clean_cohere_methods):
    """``CohereIntegration.setup()`` alone must enable transcription tracing.

    The client is deliberately not passed through ``wrap_cohere``; the span
    has to come from the patches installed by ``setup()``.
    """
    pytest.importorskip("cohere.audio.transcriptions.client")
    assert not memory_logger.pop()

    # Called twice on purpose: the second call is a no-op but must still
    # report success.
    for _ in range(2):
        assert CohereIntegration.setup() is True

    unwrapped_client = _v1_client()  # NOT manually wrapped
    with open(TEST_AUDIO_FILE, "rb") as audio:
        transcription = unwrapped_client.audio.transcriptions.create(
            model=TRANSCRIBE_MODEL,
            language="en",
            file=(TEST_AUDIO_FILE.name, audio, "audio/wav"),
            temperature=0.0,
        )
    assert isinstance(transcription.text, str)

    logged = memory_logger.pop()
    assert len(logged) == 1
    record = logged[0]
    assert record["span_attributes"]["name"] == "cohere.audio.transcriptions.create"

    metadata = record["metadata"]
    assert metadata["provider"] == "cohere"
    assert metadata["model"] == TRANSCRIBE_MODEL


class TestAutoInstrumentCohere:
def test_auto_instrument_cohere(self):
verify_autoinstrument_script("test_auto_cohere.py")
Loading