On-Demand Tracing in Go

Traces are amazing, but they can generate a huge amount of overhead traffic when we want to trace request-heavy services.

By default, OpenTelemetry (otel) gives us three basic samplers: AlwaysSample, NeverSample and TraceIDRatioBased (plus the ParentBased wrapper that composes them). The first two are self-explanatory; the third is documented as follows:

// Excerpt from the OpenTelemetry Go SDK (sdktrace), quoted for its
// documentation only; the function body is elided.
//
// TraceIDRatioBased samples a given fraction of traces. Fractions >= 1 will
// always sample. Fractions < 0 are treated as zero. To respect the
// parent trace's `SampledFlag`, the `TraceIDRatioBased` sampler should be used
// as a delegate of a `Parent` sampler.
func TraceIDRatioBased(fraction float64) Sampler { ... }

If sampling only a fraction of incoming requests doesn’t satisfy you, here is a way to let the client decide whether a request should be traced, by adding a custom header to its HTTP or gRPC request.

Implementing Our Custom Sampler

// ctxKey is an unexported key type, so values stored under it cannot collide
// with context keys defined by other packages.
type ctxKey struct {
	name string
}

// ShouldSampleContextKey is the context key under which the instrumentation
// stores a bool; a true value makes onDemandSampler record and sample spans.
var ShouldSampleContextKey = &ctxKey{name: "should sample"}

// onDemandSampler is an sdktrace.Sampler that records and samples a span only
// when its parent context carries true under ShouldSampleContextKey; every
// other span is dropped. The decision looks only at that context flag, not at
// the parent span's sampled flag. In both cases the parent's TraceState is
// passed through unchanged.
type onDemandSampler struct{}

// Compile-time check that onDemandSampler satisfies sdktrace.Sampler.
var _ sdktrace.Sampler = onDemandSampler{}

// ShouldSample implements sdktrace.Sampler.
func (s onDemandSampler) ShouldSample(p sdktrace.SamplingParameters) sdktrace.SamplingResult {
	result := sdktrace.SamplingResult{
		Decision:   sdktrace.Drop,
		Tracestate: trace.SpanContextFromContext(p.ParentContext).TraceState(),
	}
	// A missing or non-bool value yields requested == false, i.e. Drop.
	if requested, _ := p.ParentContext.Value(ShouldSampleContextKey).(bool); requested {
		result.Decision = sdktrace.RecordAndSample
	}
	return result
}

// Description implements sdktrace.Sampler.
func (s onDemandSampler) Description() string {
	return "OnDemandSampler"
}

Instrumentation

gRPC Instrumentation

// SHOULD_SAMPLE_HEADER_NAME is the HTTP header / gRPC metadata key a client
// sets to "true" to ask for its request to be traced.
// NOTE(review): idiomatic Go would be ShouldSampleHeaderName; the name is kept
// because other code references it.
const SHOULD_SAMPLE_HEADER_NAME = "x-trace"

// NewSamplerOnDemandUnaryInterceptor returns a unary server interceptor that,
// when the incoming metadata contains SHOULD_SAMPLE_HEADER_NAME with the value
// "true", marks the context for sampling (tracing.ShouldSampleContextKey) and
// wraps the handler in a server span named after the full gRPC method.
// Handler errors are recorded on the span. Requests without the header are
// passed to the handler untouched.
func NewSamplerOnDemandUnaryInterceptor(tracer trace.Tracer) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok || !lo.Contains(md.Get(SHOULD_SAMPLE_HEADER_NAME), "true") {
			return handler(ctx, req)
		}

		// The flag must be in ctx before Start so the sampler sees it.
		ctx = context.WithValue(ctx, tracing.ShouldSampleContextKey, true)
		// This span is the entry point of the request; mark it as a server
		// span instead of the default internal kind.
		ctx, span := tracer.Start(ctx, info.FullMethod, trace.WithSpanKind(trace.SpanKindServer))
		defer span.End()

		resp, err := handler(ctx, req)
		if err != nil {
			span.RecordError(err)
		}
		return resp, err
	}
}
// NewSamplerOnDemandStreamInterceptor returns a stream server interceptor that,
// when the incoming metadata contains SHOULD_SAMPLE_HEADER_NAME with the value
// "true", marks the stream's context for sampling
// (tracing.ShouldSampleContextKey) and wraps the handler in a server span named
// after the full gRPC method. Handler errors are recorded on the span. Streams
// without the header are passed to the handler untouched.
func NewSamplerOnDemandStreamInterceptor(tracer trace.Tracer) grpc.StreamServerInterceptor {
	return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		ctx := ss.Context()
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok || !lo.Contains(md.Get(SHOULD_SAMPLE_HEADER_NAME), "true") {
			return handler(srv, ss)
		}

		// The flag must be in ctx before Start so the sampler sees it.
		ctx = context.WithValue(ctx, tracing.ShouldSampleContextKey, true)
		// This span is the entry point of the stream; mark it as a server span.
		ctx, span := tracer.Start(ctx, info.FullMethod, trace.WithSpanKind(trace.SpanKindServer))
		defer span.End()

		// Stream handlers obtain their context from the stream, so hand them a
		// wrapper whose context carries the span and the sampling flag.
		wrapped := middleware.WrapServerStream(ss)
		wrapped.WrappedContext = ctx

		err := handler(srv, wrapped)
		if err != nil {
			span.RecordError(err)
		}
		return err
	}
}

HTTP Instrumentation

// NewSamplerOnDemandHttpMiddleware returns an Echo middleware that, when the
// request header SHOULD_SAMPLE_HEADER_NAME equals "true", marks the request
// context for sampling (tracing.ShouldSampleContextKey) and wraps the rest of
// the chain in a server span named after the request URL path. Errors returned
// by the chain are recorded on the span. Requests without the header pass
// through untouched.
func NewSamplerOnDemandHttpMiddleware(tracer trace.Tracer) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			r := c.Request()
			if r.Header.Get(SHOULD_SAMPLE_HEADER_NAME) != "true" {
				return next(c)
			}

			// The flag must be in ctx before Start so the sampler sees it.
			ctx := context.WithValue(r.Context(), tracing.ShouldSampleContextKey, true)
			// NOTE(review): r.URL.Path may embed IDs and produce
			// high-cardinality span names; consider the matched route instead.
			ctx, span := tracer.Start(ctx, r.URL.Path, trace.WithSpanKind(trace.SpanKindServer))
			defer span.End()
			c.SetRequest(r.WithContext(ctx))

			err := next(c)
			if err != nil {
				span.RecordError(err)
			}
			return err
		}
	}
}

Creating The Tracer

// SamplingMethod selects which sampler New installs on the tracer provider.
type SamplingMethod string

// Supported SamplingMethod values, as spelled in the YAML configuration.
const (
	// AlwaysSample samples every span (sdktrace.AlwaysSample).
	AlwaysSample SamplingMethod = "always"
	// OnDemand samples only requests flagged via ShouldSampleContextKey.
	OnDemand SamplingMethod = "on-demand"
	// Never disables tracing; New returns a no-op tracer.
	Never SamplingMethod = "never"
)

// Config holds the tracing settings, decoded from YAML via the field tags.
type Config struct {
	// Enabled turns tracing on; when false, New returns a no-op tracer.
	Enabled  bool           `yaml:"enabled"`
	// Sampler chooses the sampling method.
	Sampler  SamplingMethod `yaml:"sampler"`
	// Endpoint is the OTLP/gRPC collector address, e.g. "tempo:4317".
	Endpoint string         `yaml:"endpoint"`
}

// New returns a Tracer named name, configured from cfg.
//
// A nil cfg, a disabled cfg, or the Never sampler yields a no-op tracer and no
// exporter is created. Otherwise spans are batched and exported over OTLP/gRPC
// (without TLS) to cfg.Endpoint, using the sampler selected by cfg.Sampler.
// An unsupported cfg.Sampler is reported as an error before any exporter is
// created.
//
// NOTE(review): the underlying TracerProvider is not returned, so callers
// cannot Shutdown/ForceFlush it; spans still buffered in the batch processor
// when the process exits may be lost.
func New(ctx context.Context, name string, cfg *Config) (trace.Tracer, error) {
	if cfg == nil || !cfg.Enabled || cfg.Sampler == Never {
		return noop.NewTracerProvider().Tracer(name), nil
	}

	// Resolve the sampler first so an invalid config fails fast instead of
	// leaving an already-started exporter behind.
	var sampler sdktrace.Sampler
	switch cfg.Sampler {
	case AlwaysSample:
		sampler = sdktrace.AlwaysSample()
	case OnDemand:
		sampler = &onDemandSampler{}
	default:
		return nil, errors.Errorf("sampling method not supported: %s", cfg.Sampler)
	}

	client := otlptracegrpc.NewClient(
		otlptracegrpc.WithEndpoint(cfg.Endpoint),
		otlptracegrpc.WithInsecure(),
	)

	exporter, err := otlptrace.New(ctx, client)
	if err != nil {
		// %v, not %e: %e is a floating-point verb and garbles an error value.
		// NOTE(review): use a wrapping form (%w / Wrap) if this errors package
		// supports it, so callers can inspect the cause.
		return nil, errors.Errorf("creating OTLP trace exporter: %v", err)
	}

	bsp := sdktrace.NewBatchSpanProcessor(exporter)

	tp := sdktrace.NewTracerProvider(
		sdktrace.WithSampler(sampler),
		sdktrace.WithSpanProcessor(bsp),
		sdktrace.WithResource(newResource(name)),
	)

	return tp.Tracer(name), nil
}

// newResource builds the resource attached to every span: the semconv schema
// URL plus the service name attribute set to service.
func newResource(service string) *resource.Resource {
	serviceName := semconv.ServiceName(service)
	return resource.NewWithAttributes(semconv.SchemaURL, serviceName)
}
// main builds an on-demand tracer and installs the sampling interceptors on a
// gRPC server, then serves.
//
// NOTE(review): lis is not defined in this snippet; it is assumed to be a
// net.Listener created elsewhere (e.g. with net.Listen).
func main() {
	ctx := context.Background()

	tracer, err := tracing.New(ctx, "foobar", &tracing.Config{
		Enabled:  true,
		Sampler:  tracing.OnDemand,
		Endpoint: "tempo:4317",
	})
	if err != nil {
		panic(err)
	}

	s := server.New()

	grpcServer := grpc.NewServer(
		grpc.ChainUnaryInterceptor(
			middleware.NewSamplerOnDemandUnaryInterceptor(tracer),
		),
		grpc.ChainStreamInterceptor(
			middleware.NewSamplerOnDemandStreamInterceptor(tracer),
		),
	)

	pb.RegisterServiceServer(grpcServer, s)

	// main cannot return a value, so a Serve failure is surfaced by panicking,
	// matching the tracer-setup error handling above.
	if err := grpcServer.Serve(lis); err != nil {
		panic(err)
	}
}