Add function calling and device controls to the simple chatbot example by mattieruth · Pull Request #1807 · pipecat-ai/pipecat · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add function calling and device controls to the simple chatbot example #1807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/simple-chatbot/client/javascript/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
</div>
</div>

<div class="device-bar">
<div class="device-controls">
<select id="device-selector"></select>
<button id="mic-toggle-btn">Mute Mic</button>
</div>
</div>

<div class="debug-panel">
<h3>Debug Info</h3>
<div id="debug-log"></div>
Expand Down
105 changes: 102 additions & 3 deletions examples/simple-chatbot/client/javascript/src/app.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,25 @@
* - Browser with WebRTC support
*/

import { RTVIClient, RTVIEvent } from '@pipecat-ai/client-js';
import { LLMHelper, RTVIClient, RTVIEvent } from '@pipecat-ai/client-js';
import { DailyTransport } from '@pipecat-ai/daily-transport';

/**
 * Build a plausible fake weather report for demo use when no
 * OpenWeather API key is configured.
 *
 * @returns {{temperature: number, humidity: number, condition: string,
 *            windSpeed: number, windGusts: number}} randomized reading:
 *   temperature in [-80, 120), humidity in [0, 100), wind speed in
 *   [0, 50) with gusts up to 20 above it.
 */
function _generateRandomWeather() {
  const pick = (choices) => choices[Math.floor(Math.random() * choices.length)];
  const randomIn = (span, offset = 0) => Math.random() * span + offset;

  const windSpeed = randomIn(50);
  return {
    temperature: randomIn(200, -80),
    humidity: randomIn(100),
    condition: pick(['sunny', 'cloudy', 'rainy', 'snowy']),
    windSpeed,
    windGusts: windSpeed + randomIn(20),
  };
}

/**
* ChatbotClient handles the connection and media management for a real-time
* voice and video interaction with an AI bot.
Expand All @@ -28,7 +44,6 @@ class ChatbotClient {
// Initialize client state
this.rtviClient = null;
this.setupDOMElements();
this.setupEventListeners();
this.initializeClientAndTransport();
}

Expand All @@ -42,6 +57,7 @@ class ChatbotClient {
this.statusSpan = document.getElementById('connection-status');
this.debugLog = document.getElementById('debug-log');
this.botVideoContainer = document.getElementById('bot-video-container');
this.deviceSelector = document.getElementById('device-selector');

// Create an audio element for bot's voice output
this.botAudio = document.createElement('audio');
Expand All @@ -56,12 +72,45 @@ class ChatbotClient {
setupEventListeners() {
this.connectBtn.addEventListener('click', () => this.connect());
this.disconnectBtn.addEventListener('click', () => this.disconnect());

// Populate device selector
this.rtviClient.getAllMics().then((mics) => {
console.log('Available mics:', mics);
mics.forEach((device) => {
const option = document.createElement('option');
option.value = device.deviceId;
option.textContent = device.label || `Microphone ${device.deviceId}`;
this.deviceSelector.appendChild(option);
});
});
this.deviceSelector.addEventListener('change', (event) => {
const selectedDeviceId = event.target.value;
console.log('Selected device ID:', selectedDeviceId);
this.rtviClient.updateMic(selectedDeviceId);
});

// Handle mic mute/unmute toggle
const micToggleBtn = document.getElementById('mic-toggle-btn');

micToggleBtn.addEventListener('click', () => {
let micEnabled = this.rtviClient.isMicEnabled;
micToggleBtn.textContent = micEnabled ? 'Unmute Mic' : 'Mute Mic';
this.rtviClient.enableMic(!micEnabled);
// Add logic to mute/unmute the mic
if (micEnabled) {
console.log('Mic muted');
// Add code to mute the mic
} else {
console.log('Mic unmuted');
// Add code to unmute the mic
}
});
}

/**
* Set up the RTVI client and Daily transport
*/
initializeClientAndTransport() {
async initializeClientAndTransport() {
// Initialize the RTVI client with a DailyTransport and our configuration
this.rtviClient = new RTVIClient({
transport: new DailyTransport(),
Expand Down Expand Up @@ -121,6 +170,10 @@ class ChatbotClient {
onMessageError: (error) => {
console.log('Message error:', error);
},
onMicUpdated: (data) => {
console.log('Mic updated:', data);
this.deviceSelector.value = data.deviceId;
},
onError: (error) => {
console.log('Error:', JSON.stringify(error));
},
Expand All @@ -129,6 +182,52 @@ class ChatbotClient {

// Set up listeners for media track events
this.setupTrackListeners();

await this.rtviClient.initDevices();
this.setupEventListeners();

let llmHelper = new LLMHelper({});
llmHelper.handleFunctionCall(async (data) => {
return await this.handleFunctionCall(data.functionName, data.arguments);
});
this.rtviClient.registerHelper('openai', llmHelper);
}

async handleFunctionCall(functionName, args) {
console.log('[EVENT] LLMFunctionCall', functionName, args);
const toolFunctions = {
changeBackgroundColor: ({ color }) => {
console.log('changing background color to', color);
document.body.style.backgroundColor = color;
return { success: true, color };
},
get_current_weather: async (data) => {
console.log('getting weather for', data, data.location);
const key = import.meta.env.VITE_DANGEROUS_OPENWEATHER_API_KEY;
if (!key) {
const ret = { success: true, weather: _generateRandomWeather() };
console.log('returning weather', ret);
return ret;
}
const locationReq = await fetch(
`http://api.openweathermap.org/geo/1.0/direct?q=${location}&limit=1&appid=${key}`
);
const locJson = await locationReq.json();
const loc = { lat: locJson[0].lat, lon: locJson[0].lon };
const exclude = ['minutely', 'hourly', 'daily'].join(',');
const weatherRec = await fetch(
`https://api.openweathermap.org/data/3.0/onecall?lat=${loc.lat}&lon=${loc.lon}&exclude=${exclude}&appid=${key}`
);
const weather = await weatherRec.json();
return { success: true, weather: weather.current };
},
};
const toolFunction = toolFunctions[functionName];
if (toolFunction) {
let result = await toolFunction(args);
console.debug('returning result', result);
return result;
}
}

/**
Expand Down
41 changes: 39 additions & 2 deletions examples/simple-chatbot/client/javascript/src/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ body {
margin: 0 auto;
}

.status-bar {
.status-bar,
.device-bar {
display: flex;
justify-content: space-between;
align-items: center;
Expand All @@ -20,14 +21,47 @@ body {
margin-bottom: 20px;
}

.controls button {
.controls,
.device-controls {
display: flex;
align-items: center;
gap: 10px; /* Adds spacing between elements */
}

.device-controls {
margin-left: auto;
}

.controls button,
.device-controls button {
padding: 8px 16px;
margin-left: 10px;
border: none;
border-radius: 4px;
cursor: pointer;
}

#bot-selector,
#device-selector {
padding: 8px 16px;
padding-right: 40px;
border: none;
border-radius: 4px;
background-color: #6c757d; /* Gray background */
color: white; /* White text */
cursor: pointer;
appearance: none; /* Removes default browser styling for dropdowns */
background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='white'%3E%3Cpath d='M7 10l5 5 5-5z'/%3E%3C/svg%3E"); /* Custom arrow */
background-repeat: no-repeat;
background-position: right 8px center; /* Position the arrow */
}

#bot-selector:focus,
#device-selector:focus {
outline: none;
box-shadow: 0 0 4px rgba(0, 0, 0, 0.3); /* Add a subtle focus effect */
}

#connect-btn {
background-color: #4caf50;
color: white;
Expand All @@ -38,6 +72,9 @@ body {
color: white;
}

/* Intentionally empty: the mic toggle currently inherits the shared
   .device-controls button styling; kept as a hook for future tweaks. */
#mic-toggle-btn {
}

button:disabled {
opacity: 0.5;
cursor: not-allowed;
Expand Down
80 changes: 74 additions & 6 deletions examples/simple-chatbot/server/bot-openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,25 @@
import asyncio
import os
import sys
from typing import Dict

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from PIL import Image
from runner import configure

from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
FunctionCallResultFrame,
OutputImageRawFrame,
SpriteFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
Expand All @@ -42,6 +47,7 @@
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport

Expand Down Expand Up @@ -69,6 +75,49 @@
quiet_frame = sprites[0] # Static frame for when bot is listening
talking_frame = SpriteFrame(images=sprites) # Animation sequence for when bot is talking

#
# RTVI events for Pipecat client UI
#
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))


class WeatherProcessor(FrameProcessor):
    """Handles the ``get_current_weather`` function call flow.

    ``fetch_weather`` is registered with the LLM service as the handler for
    the weather tool: it speaks a short filler phrase, then forwards the
    call to the client UI over RTVI. Results flowing back through the
    pipeline are intercepted in ``process_frame``, where the reported
    condition is overwritten (a demo of server-side post-processing of a
    client-produced tool result).
    """

    def __init__(self):
        super().__init__()
        # Tool calls forwarded to the client that have not yet produced a
        # result, keyed by tool_call_id. Instance-level on purpose: the
        # original class-level dict would have been shared by every
        # WeatherProcessor instance. Currently only tracked, but could be
        # used for timeouts or retries.
        self.waiting_calls: Dict[str, FunctionCallParams] = {}

    async def fetch_weather(self, params: FunctionCallParams):
        """Forward a weather function call to the client over RTVI.

        Args:
            params: The function call parameters from the LLM service.
        """
        logger.debug(f"Fetching weather data... {params}")
        # Give immediate spoken feedback while the client performs the lookup.
        await params.llm.push_frame(TTSSpeakFrame("Let me check on that."))
        await rtvi.handle_function_call(params)
        self.waiting_calls[params.tool_call_id] = params

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process incoming frames and post-process weather results.

        Args:
            frame: The incoming frame to process
            direction: The direction of frame flow in the pipeline
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, FunctionCallResultFrame):
            logger.debug(f"Function call result: {frame.tool_call_id} {frame.result}")
            # Guard the shape explicitly: frame.result may be a plain string,
            # and `"weather" in frame.result` on a str would do substring
            # matching rather than a key lookup.
            if isinstance(frame.result, dict):
                weather = frame.result.get("weather")
                if isinstance(weather, dict) and "condition" in weather:
                    weather["condition"] = "hazy"
            self.waiting_calls.pop(frame.tool_call_id, None)

        await self.push_frame(frame, direction)


class TalkingAnimation(FrameProcessor):
"""Manages the bot's visual animation states.
Expand Down Expand Up @@ -157,6 +206,29 @@ async def main():
# Initialize LLM service
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

# Set up function calling
wp = WeatherProcessor()
llm.register_function("get_current_weather", wp.fetch_weather)
# llm.register_function("get_current_weather", fetch_weather_from_api)

weather_function = FunctionSchema(
name="get_current_weather",
description="Get the current weather",
properties={
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location.",
},
},
required=["format"],
)
tools = ToolsSchema(standard_tools=[weather_function])

messages = [
{
"role": "system",
Expand All @@ -173,16 +245,11 @@ async def main():

# Set up conversation context and management
# The context_aggregator will automatically collect conversation context
context = OpenAILLMContext(messages)
context = OpenAILLMContext(messages, tools)
context_aggregator = llm.create_context_aggregator(context)

ta = TalkingAnimation()

#
# RTVI events for Pipecat client UI
#
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))

pipeline = Pipeline(
[
transport.input(),
Expand All @@ -191,6 +258,7 @@ async def main():
llm,
tts,
ta,
wp,
transport.output(),
context_aggregator.assistant(),
]
Expand Down
0