diff --git a/config/client_aux.yaml b/config/client_aux.yaml index fe3e332..a82834b 100644 --- a/config/client_aux.yaml +++ b/config/client_aux.yaml @@ -1,28 +1,53 @@ huri_url: ws://localhost:8000/session -topic_list: [question] +interface_path: src.interfaces.cli_interface:cli_interface senders: audio: name: audio + topic: audio + args: + sample_rate: 16000 + frame_duration: 0.030 + text: + name: text + topic: question args: sample_rate: 16000 frame_duration: 0.030 +hooks: + audio: + name: audio + topics: [audio] + args: + incoming_sample_rate: ${senders.audio.args.sample_rate} + sample_rate: 44100 + modules: mic: name: mic args: vad_agressiveness: 3 silence_duration: 1.5 - block_duration: ${inputs.audio.args.frame_duration} - logging: INFO + block_duration: ${senders.audio.args.frame_duration} stt: name: stt args: - language: "en" - block_duration: ${inputs.audio.args.frame_duration} - logging: INFO + language: en + block_duration: ${senders.audio.args.frame_duration} tag: name: tag - logging: INFO + emo: + name: emo + args: + block_duration: ${senders.audio.args.frame_duration} + eag: + name: eag + qag: + name: qag + rag: + name: rag + args: + language: en + tone: formal diff --git a/config/client_aux2.yaml b/config/client_aux2.yaml deleted file mode 100644 index 7d7b601..0000000 --- a/config/client_aux2.yaml +++ /dev/null @@ -1,32 +0,0 @@ -huri_url: ws://localhost:8000/session - -topic_list: [transcript, question, rag_response] - -senders: - audio: - name: audio - args: - sample_rate: 16000 - frame_duration: 0.030 - -modules: - mic: - name: mic - args: - vad_agressiveness: 3 - silence_duration: 1.5 - block_duration: ${senders.audio.args.frame_duration} - stt: - name: stt - args: - language: en - block_duration: ${senders.audio.args.frame_duration} - logging: INFO - tag: - name: tag - logging: INFO - rag: - name: rag - args: - language: en - tone: formal diff --git a/config/client_auxio.yaml b/config/client_auxio.yaml deleted file mode 100644 index 8fa2a91..0000000 --- a/config/client_auxio.yaml +++ /dev/null @@ -1,25 +0,0 @@ -huri_url: ws://localhost:8000/session - -topic_list: [question] - -senders: - text: - name: text - -modules: - mic: - name: mic - args: - vad_agressiveness: 3 - silence_duration: 1.5 - block_duration: ${senders.audio.args.frame_duration} - logging: INFO - stt: - name: stt - args: - language: en - block_duration: ${senders.audio.args.frame_duration} - logging: INFO - tag: - name: tag - logging: INFO diff --git a/config/client_template.yaml b/config/client_template.yaml index cf1627d..441f3c5 100644 --- a/config/client_template.yaml +++ b/config/client_template.yaml @@ -1,19 +1,34 @@ # HuRI websocket server url huri_url: ws://localhost:8000/session -# List of event topic the client will receive -topic_list: [topic1, topic2] +# Define interface to be used's import path +interface_path: src.interfaces.cli_interface:cli_interface # Define senders to be used and their custom args senders: # sender tag can be anything example: - # sender name must be in the list of available ClientSender in Client instance (src.client_sender:get_senders) + # sender name must be in the list of available ClientSender in chosen Interface (Interface.get_senders) name: my_sender + # topic the sender will send to HuRI, it must match output_type event data structure + topic: my_event # if my_sender init with "model", "sample_rate" and "refresh_rate" params, they can be customized here args: refresh_rate: infinite +# Define hooks to be used and their custom args +hooks: + # hook tag can be anything + example: + # hook name must be in the list of available ClientHook in chosen Interface (Interface.get_senders) + name: my_hook + # topics the hook will process from HuRI, it must match input_type event data structure + topics: [my_event, llm_response] + # if my_hook init with "model", "sample_rate" and "refresh_rate" params, they can be customized here + args: + sample_rate: 0 + no: beat + # Define module to be used and their custom args modules: # module tag can be anything diff --git a/config/client_text.yaml b/config/client_text.yaml index 8ddcaab..d2fb26f 100644 --- a/config/client_text.yaml +++ b/config/client_text.yaml @@ -1,10 +1,16 @@ huri_url: ws://localhost:8000/session -topic_list: [question, rag_response] +interface_path: src.interfaces.cli_interface:cli_interface senders: text: name: text + topic: question + +hooks: + text: + name: text + topics: [rag_response] modules: rag: diff --git a/src/core/client.py b/src/core/client.py index 085a0b8..6927565 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,14 +1,80 @@ import asyncio +import importlib import json import os +import struct +from collections import defaultdict from dataclasses import asdict -from typing import Dict, List, Optional, Type +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar import websockets from src.core.dataclasses.config import ClientConfig +from src.core.events import EventData -from .client_senders import ClientSender, get_senders +T = TypeVar("T", bound=EventData | bytes) + + +class ClientSender(Generic[T]): + """This class abstract sending data to HuRI. + + output_type: is the event data structure that the ClientSender will send. + It can be EventData or bytes, and must match event topic it send. + + Class derived from ClientSender must implement input_loop, + and use ClientSender.send to send data to HuRI. + + `singletton` is available to access shared ressources. + """ + + output_type: Type[T] + + def __init__(self, topic: str, singletton: Any, **_): + """ + :topic: topic sent to HuRI + :singletton: allow to get shared ressources""" + self.topic = topic + self.singletton = singletton + + async def input_loop(self, ws: websockets.ClientConnection): + raise NotImplementedError + + async def _send_bytes(self, ws: websockets.ClientConnection, data: bytes): + topic_bytes = self.topic.encode() + packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data + + await ws.send(packet) + + async def _send_event_data(self, ws: websockets.ClientConnection, data: EventData): + packet = json.dumps({"topic": self.topic, "data": asdict(data)}) + + await ws.send(packet) + + async def send(self, ws: websockets.ClientConnection, data: T): + if isinstance(data, bytes): + await self._send_bytes(ws, data) + else: + await self._send_event_data(ws, data) + + +class ClientHook(Generic[T]): + """This class abstract processing data from HuRI. + + input_type: is the event data structure that the ClientHook will process. + It can be EventData or bytes, and must match event topic it react to. + + Class derived from ClientHook must implement hook. + + `singletton` is available to access and modifies shared ressources. + """ + + input_type: Type[T] + + def __init__(self, singletton: Any, **_): + self.singletton = singletton + + async def hook(self, data: T): + raise NotImplementedError class Client: @@ -18,11 +84,33 @@ def __init__( self, config: ClientConfig, user_id_file: str = os.path.expanduser("~/.huri_user_id"), - senders_dict: Dict[str, Type[ClientSender]] = get_senders(), ): self.config = config + + module_path, object_name = self.config.interface_path.split(":", 1) + + module = importlib.import_module(module_path) + interface = getattr(module, object_name) + + available_senders = interface.get_senders() + self.senders: List[ClientSender] = [ + available_senders[sender.name]( + topic=sender.topic, singletton=interface.singletton, **sender.args + ) + for sender in self.config.senders.values() + ] + + available_hooks = interface.get_hooks() + self.hooks: Dict[str, List[ClientHook]] = defaultdict(list) + for hook in self.config.hooks.values(): + for topic in hook.topics: + self.hooks[topic].append( + available_hooks[hook.name]( + singletton=interface.singletton, **hook.args + ) + ) + self.user_id_file = user_id_file - self.senders_dict = senders_dict def _load_user_id(self) -> Optional[str]: if os.path.exists(self.user_id_file): @@ -37,9 +125,22 @@ def _save_user_id(self, _user_id: str): async def _receive_loop(self, ws: websockets.ClientConnection): try: while True: - text = await ws.recv() - print("<<", text) - await asyncio.sleep(0.1) + msg = await ws.recv() + + if isinstance(msg, bytes): + topic_len = struct.unpack("!H", msg[:2])[0] + + topic = msg[2 : 2 + topic_len].decode() + data = msg[2 + topic_len :] + else: + event = json.loads(msg) + topic = event["topic"] + data = event["data"] + + for hook in self.hooks[topic]: + if not isinstance(data, bytes): + data = hook.input_type(**data) + asyncio.create_task(hook.hook(data)) except (asyncio.CancelledError, websockets.ConnectionClosedOK): pass @@ -50,11 +151,6 @@ async def run(self): self.config.user_id = self._load_user_id() - senders: List[ClientSender] = [ - self.senders_dict[config.name](ws=ws, **config.args) - for config in self.config.senders.values() - ] - await ws.send(json.dumps(asdict(self.config))) init_msg = json.loads(await ws.recv()) @@ -63,9 +159,9 @@ async def run(self): self._save_user_id(user_id) print(f"Session started with _user_id: {user_id}") - receive_task = asyncio.create_task(self._receive_loop(ws)) + receive_task = asyncio.create_task(self._receive_loop(ws=ws)) await asyncio.gather( - *(sender.input_loop() for sender in senders), + *(sender.input_loop(ws=ws) for sender in self.senders), ) receive_task.cancel() diff --git a/src/core/client_senders.py b/src/core/client_senders.py deleted file mode 100644 index 03301a6..0000000 --- a/src/core/client_senders.py +++ /dev/null @@ -1,102 +0,0 @@ -import asyncio -import json -import struct -from dataclasses import asdict -from typing import Dict, Type - -import numpy as np -import sounddevice as sd -import websockets -from prompt_toolkit import PromptSession -from prompt_toolkit.patch_stdout import patch_stdout - -from src.core.events import EventData -from src.modules.speech_to_text.events import Sentence - - -class ClientSender: - """This class abstract sending data to HuRI. - - output_type: is the topic that the ClientSender will send. - Data structure must match event topic. - - Class derived from ClientSender must implement input_loop, - and use ClientSender.send to send data to HuRI. It can be EventData or bytes - """ - - output_type: str - - def __init__(self, ws: websockets.ClientConnection): - self.ws = ws - - async def input_loop(self): - raise NotImplementedError - - async def send(self, topic: str, data: EventData | bytes): - packet: str | bytes - if isinstance(data, EventData): - packet = json.dumps({"topic": topic, "data": asdict(data)}) - else: - topic_bytes = topic.encode() - - packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data - - await self.ws.send(packet) - - -class AudioSender(ClientSender): - output_type = "audio" - - def __init__( - self, sample_rate: int = 16000, frame_duration: float = 0.030, **kwargs - ): - super().__init__(**kwargs) - - self.sample_rate = sample_rate - self.frame_size = int(sample_rate * frame_duration) - - async def input_loop(self): - loop = asyncio.get_running_loop() - - queue: asyncio.Queue[np.ndarray] = asyncio.Queue() - - def callback(indata: np.ndarray, frames, time, status): - loop.call_soon_threadsafe(queue.put_nowait, indata.copy()) - - with sd.InputStream( - samplerate=self.sample_rate, - channels=1, - dtype="int16", - callback=callback, - blocksize=self.frame_size, - ): - while True: - chunk = await queue.get() - await self.send(self.output_type, chunk.tobytes()) - - -class TextSender(ClientSender): - output_type = "question" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - async def input_loop(self): - print("'\\exit' or CTRL+D/C to exit.") - session: PromptSession = PromptSession() - try: - while True: - with patch_stdout(): - text = await session.prompt_async(">> ") - if text == "\\exit": - return - await self.send(self.output_type, Sentence(text)) - - except (EOFError, KeyboardInterrupt): - pass - finally: - print("TextSender Exited...") - - -def get_senders() -> Dict[str, Type[ClientSender]]: - return {"audio": AudioSender, "text": TextSender} diff --git a/src/core/dataclasses/config.py b/src/core/dataclasses/config.py index aea111f..f515026 100644 --- a/src/core/dataclasses/config.py +++ b/src/core/dataclasses/config.py @@ -15,15 +15,32 @@ def from_dict(self, raw: dict) -> "ModuleConfig": ) +@dataclass +class ClientHookConfig: + name: str + topics: List[str] + args: Mapping[str, Any] + + @classmethod + def from_dict(self, raw: dict) -> "ClientHookConfig": + return self( + name=raw["name"], + topics=raw["topics"], + args=raw.get("args", {}), + ) + + @dataclass class ClientSenderConfig: name: str + topic: str args: Mapping[str, Any] @classmethod def from_dict(self, raw: dict) -> "ClientSenderConfig": return self( name=raw["name"], + topic=raw["topic"], args=raw.get("args", {}), ) @@ -32,15 +49,20 @@ def from_dict(self, raw: dict) -> "ClientSenderConfig": class ClientConfig: user_id: Optional[str] huri_url: str - topic_list: List[str] + interface_path: str + hooks: Dict[str, ClientHookConfig] senders: Dict[str, ClientSenderConfig] modules: Dict[str, ModuleConfig] @classmethod def from_dict(cls, raw: Dict) -> "ClientConfig": + hooks = { + hook_id: ClientHookConfig.from_dict(hok_raw) + for hook_id, hok_raw in raw.get("hooks", {}).items() + } senders = { - sender_id: ClientSenderConfig.from_dict(mod_raw) - for sender_id, mod_raw in raw.get("senders", {}).items() + sender_id: ClientSenderConfig.from_dict(snd_raw) + for sender_id, snd_raw in raw.get("senders", {}).items() } modules = { module_id: ModuleConfig.from_dict(mod_raw) @@ -49,7 +71,8 @@ def from_dict(cls, raw: Dict) -> "ClientConfig": return cls( user_id=None, huri_url=raw["huri_url"], - topic_list=raw["topic_list"], + interface_path=raw["interface_path"], + hooks=hooks, senders=senders, modules=modules, ) diff --git a/src/core/huri.py b/src/core/huri.py index 5fa8038..f5e4eeb 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -32,7 +32,7 @@ def __init__( self, modules: Dict[str, Type[Module]], handles: Dict[str, handle.DeploymentHandle], - events: Dict[str, Type[EventData]], + events: Dict[str, Type[EventData | bytes]], ) -> None: self.module_factory = ModuleFactory(handles) self.event_factory = EventDataFactory() @@ -80,9 +80,12 @@ async def run_session(self, ws: WebSocket): user_id = client_config_raw.get("user_id") or str(uuid.uuid4()) - senders: List[Module] = [ - Sender(ws, topic) for topic in client_config.topic_list + topic_list = [ + topic + for hook_config in client_config.hooks.values() + for topic in hook_config.topics ] + senders: List[Module] = [Sender(ws, topic) for topic in topic_list] modules: List[Module] = ( self.module_factory.create_from_config(user_id, client_config.modules) + senders @@ -112,7 +115,7 @@ async def receive_loop(session: Session, ws: WebSocket): msg_text = msg["text"] event = json.loads(msg_text) topic = event["topic"] - data = event["data"] + data = event["data"] # TODO client/server one function data = self.event_factory.create(topic, data) diff --git a/src/core/interface.py b/src/core/interface.py new file mode 100644 index 0000000..fb2b7ac --- /dev/null +++ b/src/core/interface.py @@ -0,0 +1,22 @@ +from typing import Any, Dict, Type + +from .client import ClientHook, ClientSender + + +class Interface: + """This class abstract defining specific Client senders and hooks. + + `self.singletton`: allow hooks to modifies shared ressources, + and comes from the used interface. + + Class derived from Interface must implement get_senders and get_hooks. + """ + + def __init__(self, singletton: Any): + self.singletton = singletton + + def get_senders(self) -> Dict[str, Type[ClientSender]]: + raise NotImplementedError + + def get_hooks(self) -> Dict[str, Type[ClientHook]]: + raise NotImplementedError diff --git a/src/interfaces/__init__.py b/src/interfaces/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/interfaces/cli_interface.py b/src/interfaces/cli_interface.py new file mode 100644 index 0000000..212f4ff --- /dev/null +++ b/src/interfaces/cli_interface.py @@ -0,0 +1,124 @@ +import asyncio +from typing import Dict, Type + +import numpy as np +import sounddevice as sd +from prompt_toolkit import PromptSession +from prompt_toolkit.patch_stdout import patch_stdout +from scipy.signal import resample + +from src.core.client import ClientHook, ClientSender +from src.core.interface import Interface +from src.modules.rag.events import RAGResult +from src.modules.speech_to_text.events import Transcript + + +class AudioSender(ClientSender[bytes]): + def __init__( + self, sample_rate: int = 16000, frame_duration: float = 0.030, **kwargs + ): + super().__init__(**kwargs) + + self.sample_rate = sample_rate + self.frame_size = int(sample_rate * frame_duration) + + async def input_loop(self, ws): + loop = asyncio.get_running_loop() + + queue: asyncio.Queue[np.ndarray] = asyncio.Queue() + + def callback(indata: np.ndarray, frames, time, status): + loop.call_soon_threadsafe(queue.put_nowait, indata.copy()) + + with sd.InputStream( + samplerate=self.sample_rate, + channels=1, + dtype="int16", + callback=callback, + blocksize=self.frame_size, + ): + while True: + chunk = await queue.get() + await self.send(ws, chunk.tobytes()) + + +class TextSender(ClientSender[Transcript]): + output_type = Transcript + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def input_loop(self, ws): + print("'\\exit' or CTRL+D/C to exit.") + session: PromptSession = PromptSession() + try: + while True: + with patch_stdout(): + text = await session.prompt_async(">> ") + if text == "\\exit": + return + await self.send(ws, Transcript(text, end=True)) + + except (EOFError, KeyboardInterrupt): + pass + finally: + print("TextSender Exited...") + + +class AudioHook(ClientHook[bytes]): + input_type = bytes + + def __init__(self, sample_rate=48000, incoming_sample_rate=16000, **kwargs): + super().__init__(**kwargs) + + print("Speaker:", sd.query_devices(kind="output")) + + self.incoming_sample_rate = incoming_sample_rate + self.sample_rate = sample_rate + self.stream = sd.OutputStream( + samplerate=sample_rate, + channels=1, + dtype="int16", + ) + self.stream.start() + + self.resample_function = ( + self._resample if sample_rate != incoming_sample_rate else lambda x: x + ) + + def _resample(self, audio: np.ndarray): + return resample( + audio, + int(len(audio) * self.sample_rate / self.incoming_sample_rate), + ).astype(np.int16) + + async def hook(self, data: bytes): + audio = np.frombuffer(data, dtype=np.int16) + + audio = self.resample_function(audio) + + self.stream.write(audio.reshape(-1, 1)) + + +class TextHook(ClientHook[RAGResult]): + input_type = RAGResult + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def hook(self, data: RAGResult): + print("<<", data.answer) + + +class CLIInterface(Interface): + def __init__(self): + super().__init__(singletton=None) + + def get_senders(self) -> Dict[str, Type[ClientSender]]: + return {"audio": AudioSender, "text": TextSender} + + def get_hooks(self) -> Dict[str, Type[ClientHook]]: + return {"audio": AudioHook, "text": TextHook} + + +cli_interface = CLIInterface() diff --git a/src/modules/__init__.py b/src/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/emotion/__init__.py b/src/modules/emotion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/emotion/emotion_aggregator.py b/src/modules/emotion/emotion_aggregator.py new file mode 100644 index 0000000..f3144fa --- /dev/null +++ b/src/modules/emotion/emotion_aggregator.py @@ -0,0 +1,73 @@ +from collections import defaultdict +from typing import Dict, Optional + +from src.core.module import Module +from src.modules.rag.events import PartialQuestion + +from .events import Emotion + + +class EAG(Module): + """EAG Module + + Aggregate all emotions and send when voice end. + + input: emotion, + output: partial_question + + :ema_alpha: if not None, the aggragation will use ema computation instead + of average. Recents emotion will have stronger impact on the final score. + Lower alpha will make impact lower, and higher alpha will make it higher. \ + Default alpha would be ~0.3. + """ + + input_type = "emotion" + output_type = "partial_question" + + def __init__(self, ema_alpha: Optional[float] = None): + super().__init__() + + self.scores: Dict[str, float] = defaultdict(float) + self.count: int = 0 + + self.ema_alpha = ema_alpha + + def _finalize(self) -> Emotion: + avg_scores = ( + {label: score / self.count for label, score in self.scores.items()} + if self.ema_alpha is None + else self.scores + ) + + best_label = max(avg_scores, key=lambda label: avg_scores[label]) + + result = Emotion( + label=best_label, + confidence=avg_scores[best_label], + scores=avg_scores, + end=True, + ) + + self.scores.clear() + self.count = 0 + + return result + + async def process(self, emotion: Emotion) -> Optional[PartialQuestion]: + if self.ema_alpha is not None: + for label, score in emotion.scores.items(): + self.scores[label] = ( + self.ema_alpha * score + (1 - self.ema_alpha) * self.scores[label] + ) + else: + for label, score in emotion.scores.items(): + self.scores[label] += score + + self.count += 1 + + if emotion.end: + emotion = self._finalize() + + return PartialQuestion(transcript=None, emotion=emotion) + + return None diff --git a/src/modules/emotion/events.py b/src/modules/emotion/events.py new file mode 100644 index 0000000..4c29c37 --- /dev/null +++ b/src/modules/emotion/events.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Dict + +from src.core.events import EventData + + +@dataclass +class Emotion(EventData): + label: str + confidence: float + scores: Dict[str, float] + end: bool diff --git a/src/modules/emotion/prosody_analysis.py b/src/modules/emotion/prosody_analysis.py new file mode 100644 index 0000000..cd50b3d --- /dev/null +++ b/src/modules/emotion/prosody_analysis.py @@ -0,0 +1,109 @@ +import asyncio +from typing import List, Optional + +import numpy as np +import torch +from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor + +from src.core.module import Module +from src.modules.speech_to_text.events import Voice + +from .events import Emotion + + +class EMO(Module): + """EMO Module + + Prosody Analysis of user voice speech. + + input: voice, + output: emotion + + :model_name: name of the Emotion Analysis model. + :sample_rate: size of received voice audio. Usually 8000, 16000 or 48000. + :block_duration: size of received voice audio (in s). + :analysis_window: duration of audio per analysis (in s). + """ + + input_type = "voice" + output_type = "emotion" + + def __init__( + self, + model_name: str = "superb/hubert-large-superb-er", + sample_rate: int = 16000, + block_duration: float = 0.020, # s + analysis_window: float = 4.0, # s + ): + super().__init__() + + self.model = AutoModelForAudioClassification.from_pretrained(model_name) + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name) + + self.sample_rate = sample_rate + self.window_size = int(analysis_window / block_duration) + + self.buffer: List[np.ndarray] = [] + + self.silence: bool = True + + self.running = False + self.lock: asyncio.Lock = asyncio.Lock() + + def _predict_emotion(self, audio_np: np.ndarray): + inputs = self.feature_extractor( + audio_np, sampling_rate=self.sample_rate, return_tensors="pt", padding=True + ) + + with torch.no_grad(): + logits = self.model(**inputs).logits + probs = torch.softmax(logits, dim=-1)[0] + + predicted_id = int(torch.argmax(probs).item()) + + labels = self.model.config.id2label + + return { + "label": labels[predicted_id], + "confidence": float(probs[predicted_id]), + "scores": {labels[i]: float(probs[i]) for i in range(len(labels))}, + } + + async def process(self, voice: Voice) -> Optional[Emotion]: + if voice.data is None: + self.silence = True + else: + self.silence = False + async with self.lock: + self.buffer.append(voice.data) + + async with self.lock: + if self.running: + return None + self.running = True + + async with self.lock: + buffer_size = len(self.buffer) + if buffer_size == 0 or ( + self.silence is False and buffer_size < self.window_size + ): + self.running = False + return None + processing_chunks = self.buffer[: self.window_size] + + processing_audio = np.concatenate(processing_chunks, axis=0) + + emotion_result = await asyncio.to_thread( + self._predict_emotion, audio_np=processing_audio + ) + + async with self.lock: + self.buffer = self.buffer[self.window_size :] + self.running = False + + return Emotion( + emotion_result["label"], + emotion_result["confidence"], + emotion_result["scores"], + self.silence, + ) diff --git a/src/modules/events.py b/src/modules/events.py index 43f6c71..387bdab 100644 --- a/src/modules/events.py +++ b/src/modules/events.py @@ -1,8 +1,9 @@ from typing import Dict, Type from src.core.events import EventData -from src.modules.rag.events import RAGResult -from src.modules.speech_to_text.events import Sentence, Transcript, Voice +from src.modules.emotion.events import Emotion +from src.modules.rag.events import PartialQuestion, RAGQuestion, RAGResult +from src.modules.speech_to_text.events import Transcript, Voice def get_events() -> Dict[str, Type[EventData | bytes]]: @@ -10,6 +11,8 @@ def get_events() -> Dict[str, Type[EventData | bytes]]: "audio": bytes, "voice": Voice, "transcript": Transcript, - "question": Sentence, + "emotion": Emotion, + "partial_question": PartialQuestion, + "question": RAGQuestion, "rag_response": RAGResult, } diff --git a/src/modules/modules.py b/src/modules/modules.py index 8fbc53c..ce60461 100644 --- a/src/modules/modules.py +++ b/src/modules/modules.py @@ -1,5 +1,8 @@ from typing import Dict, Type +from src.modules.emotion.emotion_aggregator import EAG +from src.modules.emotion.prosody_analysis import EMO +from src.modules.rag.question_aggregator import QAG from src.modules.rag.rag import RAG from src.modules.speech_to_text.microphone_vad import MIC from src.modules.speech_to_text.speech_to_text import STT @@ -9,4 +12,12 @@ def get_modules() -> Dict[str, Type[Module]]: - return {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG} + return { + "mic": MIC, + "stt": STT, + "tag": TAG, + "emo": EMO, + "eag": EAG, + "qag": QAG, + "rag": RAG, + } diff --git a/src/modules/rag/__init__.py b/src/modules/rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/rag/events.py b/src/modules/rag/events.py index 5d237d2..24f883f 100644 --- a/src/modules/rag/events.py +++ b/src/modules/rag/events.py @@ -1,6 +1,9 @@ from dataclasses import dataclass, field +from typing import Optional from src.core.events import EventData +from src.modules.emotion.events import Emotion +from src.modules.speech_to_text.events import Transcript @dataclass @@ -9,3 +12,19 @@ class RAGResult(EventData): answer: str sources: list[dict] = field(default_factory=list) + + +@dataclass +class PartialQuestion(EventData): + """Partial question used to aggregate a sentence to an emotion.""" + + transcript: Optional[Transcript] + emotion: Optional[Emotion] + + +@dataclass +class RAGQuestion(EventData): + """Fully aggregated question to send to the RAG.""" + + transcript: Transcript + emotion: Optional[Emotion] diff --git a/src/modules/rag/question_aggregator.py b/src/modules/rag/question_aggregator.py new file mode 100644 index 0000000..8fd864b --- /dev/null +++ b/src/modules/rag/question_aggregator.py @@ -0,0 +1,47 @@ +from typing import Optional + +from src.core.module import Module +from src.modules.emotion.events import Emotion +from src.modules.speech_to_text.events import Transcript + +from .events import PartialQuestion, RAGQuestion + + +class QAG(Module): + """QAG Module + + Aggregate sentence and emotion into a RAGQuestion. + + input: partial_question, + output: question + + :use_emotion: default True. + Set to False if you do not analyze emotion or do not need it. + """ + + input_type = "partial_question" + output_type = "question" + + def __init__(self, use_emotion: bool = True): + super().__init__() + + self.current_transcript: Optional[Transcript] = None + self.current_emotion: Optional[Emotion] = None + + self.use_emotion = use_emotion + + async def process(self, partial_question: PartialQuestion) -> Optional[RAGQuestion]: + if partial_question.emotion is not None: + self.current_emotion = partial_question.emotion + + if partial_question.transcript is not None: + self.current_transcript = partial_question.transcript + + if self.current_transcript is not None: + if self.use_emotion: + if self.current_emotion is not None: + return RAGQuestion(self.current_transcript, self.current_emotion) + else: + return RAGQuestion(self.current_transcript, None) + + return None diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 6b9744d..e14e7c8 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -9,9 +9,8 @@ from sentence_transformers import SentenceTransformer from src.core.module import ModuleWithHandle, ModuleWithId -from src.modules.speech_to_text.events import Sentence -from .events import RAGResult +from .events import RAGQuestion, RAGResult @dataclass @@ -300,12 +299,12 @@ def __init__( "extra_instructions": extra_instructions, } - async def process(self, data: Sentence) -> Optional[RAGResult]: + async def process(self, data: RAGQuestion) -> Optional[RAGResult]: """ Called when a "question" event arrives through the event bus. Packages _user_id + question, sends to the stateless RAGHandle. """ - question_text = data.text + question_text = data.transcript.text query = RAGQuery( _user_id=self._user_id if self._user_id else "anonymous", diff --git a/src/modules/speech_to_text/events.py b/src/modules/speech_to_text/events.py index e80f522..fc04674 100644 --- a/src/modules/speech_to_text/events.py +++ b/src/modules/speech_to_text/events.py @@ -15,8 +15,3 @@ class Transcript(EventData): @dataclass class Voice(EventData): data: Optional[np.ndarray] - - -@dataclass -class Sentence(EventData): - text: str diff --git a/src/modules/speech_to_text/speech_to_text.py b/src/modules/speech_to_text/speech_to_text.py index 1300dd3..6531e87 100644 --- a/src/modules/speech_to_text/speech_to_text.py +++ b/src/modules/speech_to_text/speech_to_text.py @@ -25,6 +25,8 @@ class STT(Module): as "en" or "fr". :sample_rate: size of received voice audio. Usually 8000, 16000 or 48000. :block_duration: size of received voice audio (in s). + :transcribe_window: duration of audio per transcription (in s). + :transcribe_step: overlap between consecutive transcription windows (in s). """ input_type = "voice" @@ -52,9 +54,6 @@ def __init__( self.silence: bool = True - self.prev_text: str = "" - self.stable_text: str = "" - self.running = False self.lock: asyncio.Lock = asyncio.Lock() @@ -80,16 +79,17 @@ async def process(self, voice: Voice) -> Optional[Transcript]: return None processing_chunks = self.buffer[: self.window_size] - self.pending_silence = False processing_audio = np.concatenate(processing_chunks, axis=0) - segments, _ = self.model_faster.transcribe( - processing_audio, - language=self.language, - beam_size=1, # faster for realtime - ) + def transcribe_text(): + segments, _ = self.model_faster.transcribe( + processing_audio, + language=self.language, + beam_size=1, + ) + return " ".join(seg.text for seg in segments).strip() - current_text = " ".join([seg.text for seg in segments]).strip() + current_text = await asyncio.to_thread(transcribe_text) processed_size = self.window_size - self.step_size async with self.lock: diff --git a/src/modules/speech_to_text/text_aggregator.py b/src/modules/speech_to_text/text_aggregator.py index 72760af..b9ef1ce 100644 --- a/src/modules/speech_to_text/text_aggregator.py +++ b/src/modules/speech_to_text/text_aggregator.py @@ -2,8 +2,9 @@ from typing import Optional from src.core.module import Module +from src.modules.rag.events import PartialQuestion -from .events import Sentence, Transcript +from .events import Transcript class TAG(Module): @@ -12,11 +13,11 @@ class TAG(Module): Aggregate all transcriptions and send when transcript end. input: transcript, - output: question + output: partial_question """ input_type = "transcript" - output_type = "question" + output_type = "partial_question" def __init__( self, @@ -36,7 +37,7 @@ def _merge(self, current: str, new: str) -> str: self.prev_index = len(current) return current + new - async def process(self, transcript: Transcript) -> Optional[Sentence]: + async def process(self, transcript: Transcript) -> Optional[PartialQuestion]: text = transcript.text if text != "": @@ -46,9 +47,9 @@ async def process(self, transcript: Transcript) -> Optional[Sentence]: self.sentence = self._merge(self.sentence, text) if transcript.end and self.sentence != "": - sentence = Sentence(self.sentence) + transcript = Transcript(self.sentence, True) self.sentence = "" self.prev_index = 0 - return sentence + return PartialQuestion(transcript=transcript, emotion=None) else: return None diff --git a/src/modules/utils/sender.py b/src/modules/utils/sender.py index f09b0ba..a9fc2fa 100644 --- a/src/modules/utils/sender.py +++ b/src/modules/utils/sender.py @@ -1,3 +1,4 @@ +import struct from dataclasses import asdict from fastapi import WebSocket @@ -23,8 +24,8 @@ def __init__(self, ws: WebSocket, type: str): async def process(self, data: EventData | bytes): if isinstance(data, bytes): - await self.ws.send_bytes(data) - elif isinstance(data, EventData): - await self.ws.send_json(asdict(data)) + topic_bytes = self.input_type.encode() + packet = struct.pack("!H", len(topic_bytes)) + topic_bytes + data + await self.ws.send_bytes(packet) else: - await self.ws.send_text(data) + await self.ws.send_json({"topic": self.input_type, "data": asdict(data)})